GLM: Improve error messages and coding style.
[pspp] / src / language / stats / glm.c
index 51232fe160bcd5f64bfd39423ada691d627ee9ed..004ae97feef474389de5015fbc8de71edad4f75f 100644 (file)
 #define _(msgid) gettext (msgid)
 
 struct glm_spec
-{
-  size_t n_dep_vars;
-  const struct variable **dep_vars;
+  {
+    const struct variable **dep_vars;
+    size_t n_dep_vars;
 
-  size_t n_factor_vars;
-  const struct variable **factor_vars;
+    const struct variable **factor_vars;
+    size_t n_factor_vars;
 
-  size_t n_interactions;
-  struct interaction **interactions;
+    struct interaction **interactions;
+    size_t n_interactions;
 
-  enum mv_class exclude;
+    enum mv_class exclude;
 
-  /* The weight variable */
-  const struct variable *wv;
+    const struct variable *wv;    /* The weight variable */
 
-  const struct dictionary *dict;
+    const struct dictionary *dict;
 
-  int ss_type;
-  bool intercept;
+    int ss_type;
+    bool intercept;
 
-  double alpha;
+    double alpha;
 
-  bool dump_coding;
-};
+    bool dump_coding;
+  };
 
 struct glm_workspace
-{
-  double total_ssq;
-  struct moments *totals;
+  {
+    double total_ssq;
+    struct moments *totals;
 
-  struct categoricals *cats;
+    struct categoricals *cats;
 
-  /*
-     Sums of squares due to different variables. Element 0 is the SSE
-     for the entire model. For i > 0, element i is the SS due to
-     variable i.
-   */
-  gsl_vector *ssq;
-};
+    /*
+      Sums of squares due to different variables. Element 0 is the SSE
+      for the entire model. For i > 0, element i is the SS due to
+      variable i.
+    */
+    gsl_vector *ssq;
+  };
 
 
 /* Default design: all possible interactions */
 static void
 design_full (struct glm_spec *glm)
 {
-  int sz;
-  int i = 0;
   glm->n_interactions = (1 << glm->n_factor_vars) - 1;
-
   glm->interactions = xcalloc (glm->n_interactions, sizeof *glm->interactions);
 
   /* All subsets, with exception of the empty set, of [0, glm->n_factor_vars) */
-  for (sz = 1; sz <= glm->n_factor_vars; ++sz)
+  size_t i = 0;
+  for (size_t sz = 1; sz <= glm->n_factor_vars; ++sz)
     {
       gsl_combination *c = gsl_combination_calloc (glm->n_factor_vars, sz);
 
@@ -110,7 +107,7 @@ design_full (struct glm_spec *glm)
        {
          struct interaction *iact = interaction_create (NULL);
          int e;
-         for (e = 0 ; e < gsl_combination_k (c); ++e)
+         for (e = 0; e < gsl_combination_k (c); ++e)
            interaction_add_variable (iact, glm->factor_vars [gsl_combination_get (c, e)]);
 
          glm->interactions[i++] = iact;
@@ -133,30 +130,26 @@ static bool parse_design_spec (struct lexer *lexer, struct glm_spec *glm);
 int
 cmd_glm (struct lexer *lexer, struct dataset *ds)
 {
-  int i;
   struct const_var_set *factors = NULL;
-  struct glm_spec glm;
   bool design = false;
-  glm.dict = dataset_dict (ds);
-  glm.n_dep_vars = 0;
-  glm.n_factor_vars = 0;
-  glm.n_interactions = 0;
-  glm.interactions = NULL;
-  glm.dep_vars = NULL;
-  glm.factor_vars = NULL;
-  glm.exclude = MV_ANY;
-  glm.intercept = true;
-  glm.wv = dict_get_weight (glm.dict);
-  glm.alpha = 0.05;
-  glm.dump_coding = false;
-  glm.ss_type = 3;
-
+  struct dictionary *dict = dataset_dict (ds);
+  struct glm_spec glm = {
+    .dict = dict,
+    .exclude = MV_ANY,
+    .intercept = true,
+    .wv = dict_get_weight (dict),
+    .alpha = 0.05,
+    .ss_type = 3,
+  };
+
+  int dep_vars_start = lex_ofs (lexer);
   if (!parse_variables_const (lexer, glm.dict,
                              &glm.dep_vars, &glm.n_dep_vars,
                              PV_NO_DUPLICATE | PV_NUMERIC))
     goto error;
+  int dep_vars_end = lex_ofs (lexer) - 1;
 
-  if (! lex_force_match (lexer, T_BY))
+  if (!lex_force_match (lexer, T_BY))
     goto error;
 
   if (!parse_variables_const (lexer, glm.dict,
@@ -166,12 +159,12 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
 
   if (glm.n_dep_vars > 1)
     {
-      msg (ME, _("Multivariate analysis is not yet implemented"));
+      lex_ofs_error (lexer, dep_vars_start, dep_vars_end,
+                     _("Multivariate analysis is not yet implemented"));
       return CMD_FAILURE;
     }
 
-  factors =
-    const_var_set_create_from_array (glm.factor_vars, glm.n_factor_vars);
+  factors = const_var_set_create_from_array (glm.factor_vars, glm.n_factor_vars);
 
   while (lex_token (lexer) != T_ENDCMD)
     {
@@ -184,16 +177,12 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
                 && lex_token (lexer) != T_SLASH)
            {
              if (lex_match_id (lexer, "INCLUDE"))
-               {
-                 glm.exclude = MV_SYSTEM;
-               }
+                glm.exclude = MV_SYSTEM;
              else if (lex_match_id (lexer, "EXCLUDE"))
-               {
-                 glm.exclude = MV_ANY;
-               }
+                glm.exclude = MV_ANY;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
                  goto error;
                }
            }
@@ -205,16 +194,12 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
                 && lex_token (lexer) != T_SLASH)
            {
              if (lex_match_id (lexer, "INCLUDE"))
-               {
-                 glm.intercept = true;
-               }
+                glm.intercept = true;
              else if (lex_match_id (lexer, "EXCLUDE"))
-               {
-                 glm.intercept = false;
-               }
+                glm.intercept = false;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
                  goto error;
                }
            }
@@ -222,124 +207,74 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
       else if (lex_match_id (lexer, "CRITERIA"))
        {
          lex_match (lexer, T_EQUALS);
-         if (lex_match_id (lexer, "ALPHA"))
-           {
-             if (lex_force_match (lexer, T_LPAREN))
-               {
-                 if (! lex_force_num (lexer))
-                   {
-                     lex_error (lexer, NULL);
-                     goto error;
-                   }
-
-                 glm.alpha = lex_number (lexer);
-                 lex_get (lexer);
-                 if (! lex_force_match (lexer, T_RPAREN))
-                   {
-                     lex_error (lexer, NULL);
-                     goto error;
-                   }
-               }
-           }
-         else
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
+         if (!lex_force_match_phrase (lexer, "ALPHA(")
+              || !lex_force_num (lexer))
+            goto error;
+          glm.alpha = lex_number (lexer);
+          lex_get (lexer);
+          if (!lex_force_match (lexer, T_RPAREN))
+            goto error;
        }
       else if (lex_match_id (lexer, "METHOD"))
        {
          lex_match (lexer, T_EQUALS);
-         if (!lex_force_match_id (lexer, "SSTYPE"))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
-
-         if (! lex_force_match (lexer, T_LPAREN))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
-
-         if (! lex_force_int (lexer))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
+         if (!lex_force_match_phrase (lexer, "SSTYPE(")
+              || !lex_force_int_range (lexer, "SSTYPE", 1, 3))
+            goto error;
 
          glm.ss_type = lex_integer (lexer);
-         if (1 > glm.ss_type  ||  3 < glm.ss_type)
-           {
-             msg (ME, _("Only types 1, 2 & 3 sums of squares are currently implemented"));
-             goto error;
-           }
-
          lex_get (lexer);
 
-         if (! lex_force_match (lexer, T_RPAREN))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
+         if (!lex_force_match (lexer, T_RPAREN))
+            goto error;
        }
       else if (lex_match_id (lexer, "DESIGN"))
        {
          lex_match (lexer, T_EQUALS);
 
-         if (! parse_design_spec (lexer, &glm))
+         if (!parse_design_spec (lexer, &glm))
            goto error;
 
          if (glm.n_interactions > 0)
            design = true;
        }
       else if (lex_match_id (lexer, "SHOWCODES"))
-       /* Undocumented debug option */
        {
-         lex_match (lexer, T_EQUALS);
-
+          /* Undocumented debug option */
          glm.dump_coding = true;
        }
       else
        {
-         lex_error (lexer, NULL);
+         lex_error_expecting (lexer, "MISSING", "INTERCEPT", "CRITERIA",
+                               "METHOD", "DESIGN");
          goto error;
        }
     }
 
-  if (! design)
-    {
-      design_full (&glm);
-    }
+  if (!design)
+    design_full (&glm);
 
-  {
-    struct casegrouper *grouper;
-    struct casereader *group;
-    bool ok;
-
-    grouper = casegrouper_create_splits (proc_open (ds), glm.dict);
-    while (casegrouper_get_next_group (grouper, &group))
-      run_glm (&glm, group, ds);
-    ok = casegrouper_destroy (grouper);
-    ok = proc_commit (ds) && ok;
-  }
+  struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), glm.dict);
+  struct casereader *group;
+  while (casegrouper_get_next_group (grouper, &group))
+    run_glm (&glm, group, ds);
+  bool ok = casegrouper_destroy (grouper);
+  ok = proc_commit (ds) && ok;
 
   const_var_set_destroy (factors);
   free (glm.factor_vars);
-  for (i = 0 ; i < glm.n_interactions; ++i)
+  for (size_t i = 0; i < glm.n_interactions; ++i)
     interaction_destroy (glm.interactions[i]);
 
   free (glm.interactions);
   free (glm.dep_vars);
 
-
   return CMD_SUCCESS;
 
 error:
-
   const_var_set_destroy (factors);
   free (glm.factor_vars);
-  for (i = 0 ; i < glm.n_interactions; ++i)
+  for (size_t i = 0; i < glm.n_interactions; ++i)
     interaction_destroy (glm.interactions[i]);
 
   free (glm.interactions);
@@ -351,7 +286,7 @@ error:
 static inline bool
 not_dropped (size_t j, const bool *ff)
 {
-  return ! ff[j];
+  return !ff[j];
 }
 
 static void
@@ -392,8 +327,8 @@ ssq_type1 (struct covariance *cov, gsl_vector *ssq, const struct glm_spec *cmd)
   const gsl_matrix *cm = covariance_calculate_unnormalized (cov);
   size_t i;
   size_t k;
-  bool *model_dropped = xcalloc (covariance_dim (cov), sizeof (*model_dropped));
-  bool *submodel_dropped = xcalloc (covariance_dim (cov), sizeof (*submodel_dropped));
+  bool *model_dropped = XCALLOC (covariance_dim (cov), bool);
+  bool *submodel_dropped = XCALLOC (covariance_dim (cov), bool);
   const struct categoricals *cats = covariance_get_categoricals (cov);
 
   size_t n_dropped_model = 0;
@@ -461,8 +396,8 @@ ssq_type2 (struct covariance *cov, gsl_vector *ssq, const struct glm_spec *cmd)
   const gsl_matrix *cm = covariance_calculate_unnormalized (cov);
   size_t i;
   size_t k;
-  bool *model_dropped = xcalloc (covariance_dim (cov), sizeof (*model_dropped));
-  bool *submodel_dropped = xcalloc (covariance_dim (cov), sizeof (*submodel_dropped));
+  bool *model_dropped = XCALLOC (covariance_dim (cov), bool);
+  bool *submodel_dropped = XCALLOC (covariance_dim (cov), bool);
   const struct categoricals *cats = covariance_get_categoricals (cov);
 
   for (k = 0; k < cmd->n_interactions; k++)
@@ -524,8 +459,8 @@ ssq_type3 (struct covariance *cov, gsl_vector *ssq, const struct glm_spec *cmd)
   const gsl_matrix *cm = covariance_calculate_unnormalized (cov);
   size_t i;
   size_t k;
-  bool *model_dropped = xcalloc (covariance_dim (cov), sizeof (*model_dropped));
-  bool *submodel_dropped = xcalloc (covariance_dim (cov), sizeof (*submodel_dropped));
+  bool *model_dropped = XCALLOC (covariance_dim (cov), bool);
+  bool *submodel_dropped = XCALLOC (covariance_dim (cov), bool);
   const struct categoricals *cats = covariance_get_categoricals (cov);
 
   double ss0;
@@ -624,8 +559,7 @@ run_glm (struct glm_spec *cmd, struct casereader *input,
       double weight = dict_get_case_weight (dict, c, &warn_bad_weight);
 
       for (v = 0; v < cmd->n_dep_vars; ++v)
-       moments_pass_one (ws.totals, case_data (c, cmd->dep_vars[v])->f,
-                         weight);
+       moments_pass_one (ws.totals, case_num (c, cmd->dep_vars[v]), weight);
 
       covariance_accumulate_pass1 (cov, c);
     }
@@ -642,8 +576,7 @@ run_glm (struct glm_spec *cmd, struct casereader *input,
       double weight = dict_get_case_weight (dict, c, &warn_bad_weight);
 
       for (v = 0; v < cmd->n_dep_vars; ++v)
-       moments_pass_two (ws.totals, case_data (c, cmd->dep_vars[v])->f,
-                         weight);
+       moments_pass_two (ws.totals, case_num (c, cmd->dep_vars[v]), weight);
 
       covariance_accumulate_pass2 (cov, c);
     }
@@ -848,20 +781,20 @@ static bool
 parse_nested_variable (struct lexer *lexer, struct glm_spec *glm)
 {
   const struct variable *v = NULL;
-  if (! lex_match_variable (lexer, glm->dict, &v))
+  if (!lex_match_variable (lexer, glm->dict, &v))
     return false;
 
   if (lex_match (lexer, T_LPAREN))
     {
-      if (! parse_nested_variable (lexer, glm))
+      if (!parse_nested_variable (lexer, glm))
        return false;
 
-      if (! lex_force_match (lexer, T_RPAREN))
+      if (!lex_force_match (lexer, T_RPAREN))
        return false;
     }
 
-  lex_error (lexer, "Nested variables are not yet implemented"); return false;
-  return true;
+  lex_error (lexer, "Nested variables are not yet implemented");
+  return false;
 }
 
 /* A design term is an interaction OR a nested variable */
@@ -872,7 +805,7 @@ parse_design_term (struct lexer *lexer, struct glm_spec *glm)
   if (parse_design_interaction (lexer, glm->dict, &iact))
     {
       /* Interaction parsing successful.  Add to list of interactions */
-      glm->interactions = xrealloc (glm->interactions, sizeof *glm->interactions * ++glm->n_interactions);
+      glm->interactions = xrealloc (glm->interactions, sizeof (*glm->interactions) * ++glm->n_interactions);
       glm->interactions[glm->n_interactions - 1] = iact;
       return true;
     }
@@ -895,11 +828,10 @@ parse_design_spec (struct lexer *lexer, struct glm_spec *glm)
   if  (lex_token (lexer) == T_ENDCMD || lex_token (lexer) == T_SLASH)
     return true;
 
-  if (! parse_design_term (lexer, glm))
+  if (!parse_design_term (lexer, glm))
     return false;
 
   lex_match (lexer, T_COMMA);
 
   return parse_design_spec (lexer, glm);
 }
-