GLM: Improve error messages and coding style.
authorBen Pfaff <blp@cs.stanford.edu>
Sun, 20 Nov 2022 02:57:53 +0000 (18:57 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sun, 20 Nov 2022 02:57:53 +0000 (18:57 -0800)
src/language/stats/glm.c

index 92a1959b659c68c4050481876113fe74dafb0fbd..004ae97feef474389de5015fbc8de71edad4f75f 100644 (file)
 #define _(msgid) gettext (msgid)
 
 struct glm_spec
-{
-  size_t n_dep_vars;
-  const struct variable **dep_vars;
+  {
+    const struct variable **dep_vars;
+    size_t n_dep_vars;
 
-  size_t n_factor_vars;
-  const struct variable **factor_vars;
+    const struct variable **factor_vars;
+    size_t n_factor_vars;
 
-  size_t n_interactions;
-  struct interaction **interactions;
+    struct interaction **interactions;
+    size_t n_interactions;
 
-  enum mv_class exclude;
+    enum mv_class exclude;
 
-  /* The weight variable */
-  const struct variable *wv;
+    const struct variable *wv;    /* The weight variable */
 
-  const struct dictionary *dict;
+    const struct dictionary *dict;
 
-  int ss_type;
-  bool intercept;
+    int ss_type;
+    bool intercept;
 
-  double alpha;
+    double alpha;
 
-  bool dump_coding;
-};
+    bool dump_coding;
+  };
 
 struct glm_workspace
-{
-  double total_ssq;
-  struct moments *totals;
+  {
+    double total_ssq;
+    struct moments *totals;
 
-  struct categoricals *cats;
+    struct categoricals *cats;
 
-  /*
-     Sums of squares due to different variables. Element 0 is the SSE
-     for the entire model. For i > 0, element i is the SS due to
-     variable i.
-   */
-  gsl_vector *ssq;
-};
+    /*
+      Sums of squares due to different variables. Element 0 is the SSE
+      for the entire model. For i > 0, element i is the SS due to
+      variable i.
+    */
+    gsl_vector *ssq;
+  };
 
 
 /* Default design: all possible interactions */
 static void
 design_full (struct glm_spec *glm)
 {
-  int sz;
-  int i = 0;
   glm->n_interactions = (1 << glm->n_factor_vars) - 1;
-
   glm->interactions = xcalloc (glm->n_interactions, sizeof *glm->interactions);
 
   /* All subsets, with exception of the empty set, of [0, glm->n_factor_vars) */
-  for (sz = 1; sz <= glm->n_factor_vars; ++sz)
+  size_t i = 0;
+  for (size_t sz = 1; sz <= glm->n_factor_vars; ++sz)
     {
       gsl_combination *c = gsl_combination_calloc (glm->n_factor_vars, sz);
 
@@ -110,7 +107,7 @@ design_full (struct glm_spec *glm)
        {
          struct interaction *iact = interaction_create (NULL);
          int e;
-         for (e = 0 ; e < gsl_combination_k (c); ++e)
+         for (e = 0; e < gsl_combination_k (c); ++e)
            interaction_add_variable (iact, glm->factor_vars [gsl_combination_get (c, e)]);
 
          glm->interactions[i++] = iact;
@@ -133,23 +130,17 @@ static bool parse_design_spec (struct lexer *lexer, struct glm_spec *glm);
 int
 cmd_glm (struct lexer *lexer, struct dataset *ds)
 {
-  int i;
   struct const_var_set *factors = NULL;
-  struct glm_spec glm;
   bool design = false;
-  glm.dict = dataset_dict (ds);
-  glm.n_dep_vars = 0;
-  glm.n_factor_vars = 0;
-  glm.n_interactions = 0;
-  glm.interactions = NULL;
-  glm.dep_vars = NULL;
-  glm.factor_vars = NULL;
-  glm.exclude = MV_ANY;
-  glm.intercept = true;
-  glm.wv = dict_get_weight (glm.dict);
-  glm.alpha = 0.05;
-  glm.dump_coding = false;
-  glm.ss_type = 3;
+  struct dictionary *dict = dataset_dict (ds);
+  struct glm_spec glm = {
+    .dict = dict,
+    .exclude = MV_ANY,
+    .intercept = true,
+    .wv = dict_get_weight (dict),
+    .alpha = 0.05,
+    .ss_type = 3,
+  };
 
   int dep_vars_start = lex_ofs (lexer);
   if (!parse_variables_const (lexer, glm.dict,
@@ -158,7 +149,7 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
     goto error;
   int dep_vars_end = lex_ofs (lexer) - 1;
 
-  if (! lex_force_match (lexer, T_BY))
+  if (!lex_force_match (lexer, T_BY))
     goto error;
 
   if (!parse_variables_const (lexer, glm.dict,
@@ -173,8 +164,7 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
       return CMD_FAILURE;
     }
 
-  factors =
-    const_var_set_create_from_array (glm.factor_vars, glm.n_factor_vars);
+  factors = const_var_set_create_from_array (glm.factor_vars, glm.n_factor_vars);
 
   while (lex_token (lexer) != T_ENDCMD)
     {
@@ -187,16 +177,12 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
                 && lex_token (lexer) != T_SLASH)
            {
              if (lex_match_id (lexer, "INCLUDE"))
-               {
-                 glm.exclude = MV_SYSTEM;
-               }
+                glm.exclude = MV_SYSTEM;
              else if (lex_match_id (lexer, "EXCLUDE"))
-               {
-                 glm.exclude = MV_ANY;
-               }
+                glm.exclude = MV_ANY;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
                  goto error;
                }
            }
@@ -208,16 +194,12 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
                 && lex_token (lexer) != T_SLASH)
            {
              if (lex_match_id (lexer, "INCLUDE"))
-               {
-                 glm.intercept = true;
-               }
+                glm.intercept = true;
              else if (lex_match_id (lexer, "EXCLUDE"))
-               {
-                 glm.intercept = false;
-               }
+                glm.intercept = false;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
                  goto error;
                }
            }
@@ -225,118 +207,74 @@ cmd_glm (struct lexer *lexer, struct dataset *ds)
       else if (lex_match_id (lexer, "CRITERIA"))
        {
          lex_match (lexer, T_EQUALS);
-         if (lex_match_id (lexer, "ALPHA"))
-           {
-             if (lex_force_match (lexer, T_LPAREN))
-               {
-                 if (! lex_force_num (lexer))
-                   {
-                     lex_error (lexer, NULL);
-                     goto error;
-                   }
-
-                 glm.alpha = lex_number (lexer);
-                 lex_get (lexer);
-                 if (! lex_force_match (lexer, T_RPAREN))
-                   {
-                     lex_error (lexer, NULL);
-                     goto error;
-                   }
-               }
-           }
-         else
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
+         if (!lex_force_match_phrase (lexer, "ALPHA(")
+              || !lex_force_num (lexer))
+            goto error;
+          glm.alpha = lex_number (lexer);
+          lex_get (lexer);
+          if (!lex_force_match (lexer, T_RPAREN))
+            goto error;
        }
       else if (lex_match_id (lexer, "METHOD"))
        {
          lex_match (lexer, T_EQUALS);
-         if (!lex_force_match_id (lexer, "SSTYPE"))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
-
-         if (! lex_force_match (lexer, T_LPAREN))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
-
-         if (!lex_force_int_range (lexer, "SSTYPE", 1, 3))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
+         if (!lex_force_match_phrase (lexer, "SSTYPE(")
+              || !lex_force_int_range (lexer, "SSTYPE", 1, 3))
+            goto error;
 
          glm.ss_type = lex_integer (lexer);
          lex_get (lexer);
 
-         if (! lex_force_match (lexer, T_RPAREN))
-           {
-             lex_error (lexer, NULL);
-             goto error;
-           }
+         if (!lex_force_match (lexer, T_RPAREN))
+            goto error;
        }
       else if (lex_match_id (lexer, "DESIGN"))
        {
          lex_match (lexer, T_EQUALS);
 
-         if (! parse_design_spec (lexer, &glm))
+         if (!parse_design_spec (lexer, &glm))
            goto error;
 
          if (glm.n_interactions > 0)
            design = true;
        }
       else if (lex_match_id (lexer, "SHOWCODES"))
-       /* Undocumented debug option */
        {
-         lex_match (lexer, T_EQUALS);
-
+          /* Undocumented debug option */
          glm.dump_coding = true;
        }
       else
        {
-         lex_error (lexer, NULL);
+         lex_error_expecting (lexer, "MISSING", "INTERCEPT", "CRITERIA",
+                               "METHOD", "DESIGN");
          goto error;
        }
     }
 
-  if (! design)
-    {
-      design_full (&glm);
-    }
+  if (!design)
+    design_full (&glm);
 
-  {
-    struct casegrouper *grouper;
-    struct casereader *group;
-    bool ok;
-
-    grouper = casegrouper_create_splits (proc_open (ds), glm.dict);
-    while (casegrouper_get_next_group (grouper, &group))
-      run_glm (&glm, group, ds);
-    ok = casegrouper_destroy (grouper);
-    ok = proc_commit (ds) && ok;
-  }
+  struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), glm.dict);
+  struct casereader *group;
+  while (casegrouper_get_next_group (grouper, &group))
+    run_glm (&glm, group, ds);
+  bool ok = casegrouper_destroy (grouper);
+  ok = proc_commit (ds) && ok;
 
   const_var_set_destroy (factors);
   free (glm.factor_vars);
-  for (i = 0 ; i < glm.n_interactions; ++i)
+  for (size_t i = 0; i < glm.n_interactions; ++i)
     interaction_destroy (glm.interactions[i]);
 
   free (glm.interactions);
   free (glm.dep_vars);
 
-
   return CMD_SUCCESS;
 
 error:
-
   const_var_set_destroy (factors);
   free (glm.factor_vars);
-  for (i = 0 ; i < glm.n_interactions; ++i)
+  for (size_t i = 0; i < glm.n_interactions; ++i)
     interaction_destroy (glm.interactions[i]);
 
   free (glm.interactions);
@@ -348,7 +286,7 @@ error:
 static inline bool
 not_dropped (size_t j, const bool *ff)
 {
-  return ! ff[j];
+  return !ff[j];
 }
 
 static void
@@ -843,15 +781,15 @@ static bool
 parse_nested_variable (struct lexer *lexer, struct glm_spec *glm)
 {
   const struct variable *v = NULL;
-  if (! lex_match_variable (lexer, glm->dict, &v))
+  if (!lex_match_variable (lexer, glm->dict, &v))
     return false;
 
   if (lex_match (lexer, T_LPAREN))
     {
-      if (! parse_nested_variable (lexer, glm))
+      if (!parse_nested_variable (lexer, glm))
        return false;
 
-      if (! lex_force_match (lexer, T_RPAREN))
+      if (!lex_force_match (lexer, T_RPAREN))
        return false;
     }
 
@@ -890,7 +828,7 @@ parse_design_spec (struct lexer *lexer, struct glm_spec *glm)
   if  (lex_token (lexer) == T_ENDCMD || lex_token (lexer) == T_SLASH)
     return true;
 
-  if (! parse_design_term (lexer, glm))
+  if (!parse_design_term (lexer, glm))
     return false;
 
   lex_match (lexer, T_COMMA);