ROC: Improve error messages and coding style.
authorBen Pfaff <blp@cs.stanford.edu>
Sun, 20 Nov 2022 22:14:04 +0000 (14:14 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Mon, 21 Nov 2022 00:37:15 +0000 (16:37 -0800)
src/language/stats/roc.c
tests/language/stats/roc.at

index 8b5cc3d0495e2abc4e16f1541ea1cb40a117a2bf..adfc5d0c8890b4f359ef1b21c32f49f80013df31 100644 (file)
@@ -63,8 +63,8 @@ struct cmd_roc
                      should be used */
   enum mv_class exclude;
 
-  bool invert ; /* True iff a smaller test result variable indicates
-                  a positive result */
+  bool invert; /* True iff a smaller test result variable indicates
+                  a positive result */
 
   double pos;
   double neg;
@@ -72,60 +72,43 @@ struct cmd_roc
   double neg_weighted;
 };
 
-static int run_roc (struct dataset *ds, struct cmd_roc *roc);
+static int run_roc (struct dataset *, struct cmd_roc *);
+static void do_roc (struct cmd_roc *, struct casereader *, struct dictionary *);
+
 
 int
 cmd_roc (struct lexer *lexer, struct dataset *ds)
 {
-  struct cmd_roc roc ;
   const struct dictionary *dict = dataset_dict (ds);
 
-  roc.vars = NULL;
-  roc.n_vars = 0;
-  roc.print_se = false;
-  roc.print_coords = false;
-  roc.exclude = MV_ANY;
-  roc.curve = true;
-  roc.reference = false;
-  roc.ci = 95;
-  roc.bi_neg_exp = false;
-  roc.invert = false;
-  roc.pos = roc.pos_weighted = 0;
-  roc.neg = roc.neg_weighted = 0;
-  roc.dict = dataset_dict (ds);
-  roc.state_var = NULL;
-  roc.state_var_width = -1;
+  struct cmd_roc roc = {
+    .exclude = MV_ANY,
+    .curve = true,
+    .ci = 95,
+    .dict = dict,
+    .state_var_width = -1,
+  };
 
   lex_match (lexer, T_SLASH);
   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
                              PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
     goto error;
 
-  if (! lex_force_match (lexer, T_BY))
-    {
-      goto error;
-    }
+  if (!lex_force_match (lexer, T_BY))
+    goto error;
 
   roc.state_var = parse_variable (lexer, dict);
-  if (! roc.state_var)
-    {
-      goto error;
-    }
+  if (!roc.state_var)
+    goto error;
 
   if (!lex_force_match (lexer, T_LPAREN))
-    {
-      goto error;
-    }
+    goto error;
 
   roc.state_var_width = var_get_width (roc.state_var);
   value_init (&roc.state_value, roc.state_var_width);
-  parse_value (lexer, &roc.state_value, roc.state_var);
-
-
-  if (!lex_force_match (lexer, T_RPAREN))
-    {
-      goto error;
-    }
+  if (!parse_value (lexer, &roc.state_value, roc.state_var)
+      || !lex_force_match (lexer, T_RPAREN))
+    goto error;
 
   while (lex_token (lexer) != T_ENDCMD)
     {
@@ -136,16 +119,12 @@ cmd_roc (struct lexer *lexer, struct dataset *ds)
           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
             {
              if (lex_match_id (lexer, "INCLUDE"))
-               {
-                 roc.exclude = MV_SYSTEM;
-               }
+                roc.exclude = MV_SYSTEM;
              else if (lex_match_id (lexer, "EXCLUDE"))
-               {
-                 roc.exclude = MV_ANY;
-               }
+                roc.exclude = MV_ANY;
              else
                {
-                  lex_error (lexer, NULL);
+                  lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
                  goto error;
                }
            }
@@ -159,19 +138,16 @@ cmd_roc (struct lexer *lexer, struct dataset *ds)
              if (lex_match (lexer, T_LPAREN))
                {
                  roc.reference = true;
-                 if (! lex_force_match_id (lexer, "REFERENCE"))
-                   goto error;
-                 if (! lex_force_match (lexer, T_RPAREN))
+                 if (!lex_force_match_id (lexer, "REFERENCE")
+                      || !lex_force_match (lexer, T_RPAREN))
                    goto error;
                }
            }
          else if (lex_match_id (lexer, "NONE"))
-           {
-             roc.curve = false;
-           }
+            roc.curve = false;
          else
            {
-             lex_error (lexer, NULL);
+             lex_error_expecting (lexer, "CURVE", "NONE");
              goto error;
            }
        }
@@ -181,16 +157,12 @@ cmd_roc (struct lexer *lexer, struct dataset *ds)
           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
            {
              if (lex_match_id (lexer, "SE"))
-               {
-                 roc.print_se = true;
-               }
+                roc.print_se = true;
              else if (lex_match_id (lexer, "COORDINATES"))
-               {
-                 roc.print_coords = true;
-               }
+                roc.print_coords = true;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "SE", "COORDINATES");
                  goto error;
                }
            }
@@ -202,49 +174,41 @@ cmd_roc (struct lexer *lexer, struct dataset *ds)
            {
              if (lex_match_id (lexer, "CUTOFF"))
                {
-                 if (! lex_force_match (lexer, T_LPAREN))
+                 if (!lex_force_match (lexer, T_LPAREN))
                    goto error;
                  if (lex_match_id (lexer, "INCLUDE"))
-                   {
-                     roc.exclude = MV_SYSTEM;
-                   }
+                    roc.exclude = MV_SYSTEM;
                  else if (lex_match_id (lexer, "EXCLUDE"))
-                   {
-                     roc.exclude = MV_USER | MV_SYSTEM;
-                   }
+                    roc.exclude = MV_USER | MV_SYSTEM;
                  else
                    {
-                     lex_error (lexer, NULL);
+                     lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
                      goto error;
                    }
-                 if (! lex_force_match (lexer, T_RPAREN))
+                 if (!lex_force_match (lexer, T_RPAREN))
                    goto error;
                }
              else if (lex_match_id (lexer, "TESTPOS"))
                {
-                 if (! lex_force_match (lexer, T_LPAREN))
+                 if (!lex_force_match (lexer, T_LPAREN))
                    goto error;
                  if (lex_match_id (lexer, "LARGE"))
-                   {
-                     roc.invert = false;
-                   }
+                    roc.invert = false;
                  else if (lex_match_id (lexer, "SMALL"))
-                   {
-                     roc.invert = true;
-                   }
+                    roc.invert = true;
                  else
                    {
-                     lex_error (lexer, NULL);
+                     lex_error_expecting (lexer, "LARGE", "SMALL");
                      goto error;
                    }
-                 if (! lex_force_match (lexer, T_RPAREN))
+                 if (!lex_force_match (lexer, T_RPAREN))
                    goto error;
                }
              else if (lex_match_id (lexer, "CI"))
                {
                  if (!lex_force_match (lexer, T_LPAREN))
                    goto error;
-                 if (! lex_force_num (lexer))
+                 if (!lex_force_num (lexer))
                    goto error;
                  roc.ci = lex_number (lexer);
                  lex_get (lexer);
@@ -256,16 +220,12 @@ cmd_roc (struct lexer *lexer, struct dataset *ds)
                  if (!lex_force_match (lexer, T_LPAREN))
                    goto error;
                  if (lex_match_id (lexer, "FREE"))
-                   {
-                     roc.bi_neg_exp = false;
-                   }
+                    roc.bi_neg_exp = false;
                  else if (lex_match_id (lexer, "NEGEXPO"))
-                   {
-                     roc.bi_neg_exp = true;
-                   }
+                    roc.bi_neg_exp = true;
                  else
                    {
-                     lex_error (lexer, NULL);
+                     lex_error_expecting (lexer, "FREE", "NEGEXPO");
                      goto error;
                    }
                  if (!lex_force_match (lexer, T_RPAREN))
@@ -273,19 +233,20 @@ cmd_roc (struct lexer *lexer, struct dataset *ds)
                }
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "CUTOFF", "TESTPOS", "CI",
+                                       "DISTRIBUTION");
                  goto error;
                }
            }
        }
       else
        {
-         lex_error (lexer, NULL);
-         break;
+         lex_error_expecting (lexer, "MISSING", "PLOT", "PRINT", "CRITERIA");
+         goto error;
        }
     }
 
-  if (! run_roc (ds, &roc))
+  if (!run_roc (ds, &roc))
     goto error;
 
   if (roc.state_var)
@@ -300,28 +261,18 @@ cmd_roc (struct lexer *lexer, struct dataset *ds)
   return CMD_FAILURE;
 }
 
-
-
-
-static void
-do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
-
-
 static int
 run_roc (struct dataset *ds, struct cmd_roc *roc)
 {
   struct dictionary *dict = dataset_dict (ds);
-  bool ok;
   struct casereader *group;
 
   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
   while (casegrouper_get_next_group (grouper, &group))
-    {
-      do_roc (roc, group, dataset_dict (ds));
-    }
-  ok = casegrouper_destroy (grouper);
-  ok = proc_commit (ds) && ok;
+    do_roc (roc, group, dataset_dict (ds));
 
+  bool ok = casegrouper_destroy (grouper);
+  ok = proc_commit (ds) && ok;
   return ok;
 }
 
@@ -334,8 +285,7 @@ dump_casereader (struct casereader *reader)
 
   for (; (c = casereader_read (r)); case_unref (c))
     {
-      int i;
-      for (i = 0 ; i < case_get_n_values (c); ++i)
+      for (size_t i = 0; i < case_get_n_values (c); ++i)
         printf ("%g ", case_num_idx (c, i));
       printf ("\n");
     }
@@ -541,7 +491,6 @@ process_group (const struct variable *var, struct casereader *reader,
       casereader_destroy (r2);
     }
 
-
   casereader_destroy (r1);
   casereader_destroy (rclone);
 
@@ -629,167 +578,128 @@ append_cutpoint (struct casewriter *writer, double cutpoint)
   casewriter_write (writer, cc);
 }
 
-
 /*
-   Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
-   be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
-   reader will be populated with its final number of cases.
-   However on exit from this function, only ROC_CUTPOINT entries will be set to their final
-   value.  The other entries will be initialised to zero.
+   Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the
+   readers will be created with width 5, ready to take the values (cutpoint,
+   ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the reader will be populated with its
+   final number of cases.  However on exit from this function, only
+   ROC_CUTPOINT entries will be set to their final value.  The other entries
+   will be initialised to zero.
 */
-static void
-prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
+static struct roc_state *
+prepare_cutpoints (struct cmd_roc *roc, struct casereader *input)
 {
-  int i;
   struct casereader *r = casereader_clone (input);
   struct ccase *c;
 
-  {
-    struct caseproto *proto = caseproto_create ();
-    struct subcase ordering;
-    subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
+  struct subcase ordering;
+  subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
 
-    proto = caseproto_add_width (proto, 0); /* cutpoint */
-    proto = caseproto_add_width (proto, 0); /* ROC_TP */
-    proto = caseproto_add_width (proto, 0); /* ROC_FN */
-    proto = caseproto_add_width (proto, 0); /* ROC_TN */
-    proto = caseproto_add_width (proto, 0); /* ROC_FP */
+  struct caseproto *proto = caseproto_create ();
+  proto = caseproto_add_width (proto, 0); /* cutpoint */
+  proto = caseproto_add_width (proto, 0); /* ROC_TP */
+  proto = caseproto_add_width (proto, 0); /* ROC_FN */
+  proto = caseproto_add_width (proto, 0); /* ROC_TN */
+  proto = caseproto_add_width (proto, 0); /* ROC_FP */
 
-    for (i = 0 ; i < roc->n_vars; ++i)
-      {
-       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
-       rs[i].prev_result = SYSMIS;
-       rs[i].max = -DBL_MAX;
-       rs[i].min = DBL_MAX;
-      }
+  struct roc_state *rs = xnmalloc (roc->n_vars, sizeof *rs);
+  for (size_t i = 0; i < roc->n_vars; ++i)
+    rs[i] = (struct roc_state) {
+      .cutpoint_wtr = sort_create_writer (&ordering, proto),
+      .prev_result = SYSMIS,
+      .max = -DBL_MAX,
+      .min = DBL_MAX,
+    };
 
-    caseproto_unref (proto);
-    subcase_uninit (&ordering);
-  }
+  caseproto_unref (proto);
+  subcase_uninit (&ordering);
 
   for (; (c = casereader_read (r)) != NULL; case_unref (c))
-    {
-      for (i = 0 ; i < roc->n_vars; ++i)
-       {
-         const union value *v = case_data (c, roc->vars[i]);
-         const double result = v->f;
+    for (size_t i = 0; i < roc->n_vars; ++i)
+      {
+        const union value *v = case_data (c, roc->vars[i]);
+        const double result = v->f;
 
-         if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v)
-              & roc->exclude)
-           continue;
+        if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v)
+            & roc->exclude)
+          continue;
 
-         minimize (&rs[i].min, result);
-         maximize (&rs[i].max, result);
+        minimize (&rs[i].min, result);
+        maximize (&rs[i].max, result);
 
-         if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
-           {
-             const double mean = (result + rs[i].prev_result) / 2.0;
-             append_cutpoint (rs[i].cutpoint_wtr, mean);
-           }
+        if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
+          {
+            const double mean = (result + rs[i].prev_result) / 2.0;
+            append_cutpoint (rs[i].cutpoint_wtr, mean);
+          }
 
-         rs[i].prev_result = result;
-       }
-    }
+        rs[i].prev_result = result;
+      }
   casereader_destroy (r);
 
-
   /* Append the min and max cutpoints */
-  for (i = 0 ; i < roc->n_vars; ++i)
+  for (size_t i = 0; i < roc->n_vars; ++i)
     {
       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
 
       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
     }
+
+  return rs;
 }
 
 static void
 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
 {
-  int i;
+  struct casereader *input = casereader_create_filter_missing (
+    reader, roc->vars, roc->n_vars, roc->exclude, NULL, NULL);
+  input = casereader_create_filter_missing (
+    input, &roc->state_var, 1, roc->exclude, NULL, NULL);
 
-  struct roc_state *rs = XCALLOC (roc->n_vars,  struct roc_state);
+  struct roc_state *rs = prepare_cutpoints (roc, input);
 
-  struct casereader *negatives = NULL;
-  struct casereader *positives = NULL;
+  /* Separate the positive actual state cases from the negative ones */
+  struct casewriter *neg_wtr
+    = autopaging_writer_create (casereader_get_proto (input));
+  struct casereader *positives = casereader_create_filter_func (
+    input, match_positives, NULL, roc, neg_wtr);
 
-  struct caseproto *n_proto = NULL;
+  struct caseproto *n_proto = caseproto_create ();
+  for (size_t i = 0; i < 5; i++)
+    n_proto = caseproto_add_width (n_proto, 0);
 
   struct subcase up_ordering;
   struct subcase down_ordering;
-
-  struct casewriter *neg_wtr = NULL;
-
-  struct casereader *input = casereader_create_filter_missing (reader,
-                                                              roc->vars, roc->n_vars,
-                                                              roc->exclude,
-                                                              NULL,
-                                                              NULL);
-
-  input = casereader_create_filter_missing (input,
-                                           &roc->state_var, 1,
-                                           roc->exclude,
-                                           NULL,
-                                           NULL);
-
-  neg_wtr = autopaging_writer_create (casereader_get_proto (input));
-
-  prepare_cutpoints (roc, rs, input);
-
-
-  /* Separate the positive actual state cases from the negative ones */
-  positives =
-    casereader_create_filter_func (input,
-                                  match_positives,
-                                  NULL,
-                                  roc,
-                                  neg_wtr);
-
-  n_proto = caseproto_create ();
-
-  n_proto = caseproto_add_width (n_proto, 0);
-  n_proto = caseproto_add_width (n_proto, 0);
-  n_proto = caseproto_add_width (n_proto, 0);
-  n_proto = caseproto_add_width (n_proto, 0);
-  n_proto = caseproto_add_width (n_proto, 0);
-
   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
 
-  for (i = 0 ; i < roc->n_vars; ++i)
+  struct casereader *negatives = NULL;
+  for (size_t i = 0; i < roc->n_vars; ++i)
     {
-      struct casewriter *w = NULL;
-      struct casereader *r = NULL;
-
-      struct ccase *c;
-
-      struct ccase *cpos;
-      struct casereader *n_neg_reader ;
       const struct variable *var = roc->vars[i];
 
-      struct casereader *neg ;
       struct casereader *pos = casereader_clone (positives);
 
       struct casereader *n_pos_reader =
        process_positive_group (var, pos, dict, &rs[i]);
 
-      if (negatives == NULL)
-       {
-         negatives = casewriter_make_reader (neg_wtr);
-       }
+      if (!negatives)
+        negatives = casewriter_make_reader (neg_wtr);
 
-      neg = casereader_clone (negatives);
-
-      n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
+      struct casereader *neg = casereader_clone (negatives);
+      struct casereader *n_neg_reader
+        = process_negative_group (var, neg, dict, &rs[i]);
 
       /* Merge the n_pos and n_neg casereaders */
-      w = sort_create_writer (&up_ordering, n_proto);
+      struct casewriter *w = sort_create_writer (&up_ordering, n_proto);
+      struct ccase *cpos;
       for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
        {
          struct ccase *pos_case = case_create (n_proto);
-         struct ccase *cneg;
          const double jpos = case_num_idx (cpos, VALUE);
 
+         struct ccase *cneg;
          while ((cneg = casereader_read (n_neg_reader)))
            {
              struct ccase *nc = case_create (n_proto);
@@ -823,122 +733,112 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
       casereader_destroy (n_pos_reader);
       casereader_destroy (n_neg_reader);
 
-/* These aren't used anymore */
-#undef N_EQ
-#undef N_PRED
-
-      r = casewriter_make_reader (w);
+      struct casereader *r = casewriter_make_reader (w);
 
       /* Propagate the N_POS_GT values from the positive cases
         to the negative ones */
-      {
-       double prev_pos_gt = rs[i].n1;
-       w = sort_create_writer (&down_ordering, n_proto);
-
-       for (; (c = casereader_read (r)); case_unref (c))
-         {
-           double n_pos_gt = case_num_idx (c, N_POS_GT);
-           struct ccase *nc = case_clone (c);
-
-           if (n_pos_gt == SYSMIS)
-             {
-               n_pos_gt = prev_pos_gt;
-               *case_num_rw_idx (nc, N_POS_GT) = n_pos_gt;
-             }
-
-           casewriter_write (w, nc);
-           prev_pos_gt = n_pos_gt;
-         }
-
-       casereader_destroy (r);
-       r = casewriter_make_reader (w);
-      }
+      double prev_pos_gt = rs[i].n1;
+      w = sort_create_writer (&down_ordering, n_proto);
+
+      struct ccase *c;
+      for (; (c = casereader_read (r)); case_unref (c))
+        {
+          double n_pos_gt = case_num_idx (c, N_POS_GT);
+          struct ccase *nc = case_clone (c);
+
+          if (n_pos_gt == SYSMIS)
+            {
+              n_pos_gt = prev_pos_gt;
+              *case_num_rw_idx (nc, N_POS_GT) = n_pos_gt;
+            }
+
+          casewriter_write (w, nc);
+          prev_pos_gt = n_pos_gt;
+        }
+      casereader_destroy (r);
+      r = casewriter_make_reader (w);
 
       /* Propagate the N_NEG_LT values from the negative cases
         to the positive ones */
-      {
-       double prev_neg_lt = rs[i].n2;
-       w = sort_create_writer (&up_ordering, n_proto);
-
-       for (; (c = casereader_read (r)); case_unref (c))
-         {
-           double n_neg_lt = case_num_idx (c, N_NEG_LT);
-           struct ccase *nc = case_clone (c);
-
-           if (n_neg_lt == SYSMIS)
-             {
-               n_neg_lt = prev_neg_lt;
-               *case_num_rw_idx (nc, N_NEG_LT) = n_neg_lt;
-             }
-
-           casewriter_write (w, nc);
-           prev_neg_lt = n_neg_lt;
-         }
-
-       casereader_destroy (r);
-       r = casewriter_make_reader (w);
-      }
+      double prev_neg_lt = rs[i].n2;
+      w = sort_create_writer (&up_ordering, n_proto);
 
-      {
-       struct ccase *prev_case = NULL;
-       for (; (c = casereader_read (r)); case_unref (c))
-         {
-           struct ccase *next_case = casereader_peek (r, 0);
-
-           const double j = case_num_idx (c, VALUE);
-           double n_pos_eq = case_num_idx (c, N_POS_EQ);
-           double n_pos_gt = case_num_idx (c, N_POS_GT);
-           double n_neg_eq = case_num_idx (c, N_NEG_EQ);
-           double n_neg_lt = case_num_idx (c, N_NEG_LT);
-
-           if (prev_case && j == case_num_idx (prev_case, VALUE))
-             {
-               if (0 ==  case_num_idx (c, N_POS_EQ))
-                 {
-                   n_pos_eq = case_num_idx (prev_case, N_POS_EQ);
-                   n_pos_gt = case_num_idx (prev_case, N_POS_GT);
-                 }
-
-               if (0 ==  case_num_idx (c, N_NEG_EQ))
-                 {
-                   n_neg_eq = case_num_idx (prev_case, N_NEG_EQ);
-                   n_neg_lt = case_num_idx (prev_case, N_NEG_LT);
-                 }
-             }
-
-           if (NULL == next_case || j != case_num_idx (next_case, VALUE))
-             {
-               rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
-
-               rs[i].q1hat +=
-                 n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
-               rs[i].q2hat +=
-                 n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
-
-             }
-
-           case_unref (next_case);
-           case_unref (prev_case);
-           prev_case = case_clone (c);
-         }
-       casereader_destroy (r);
-       case_unref (prev_case);
-
-       rs[i].auc /=  rs[i].n1 * rs[i].n2;
-       if (roc->invert)
-         rs[i].auc = 1 - rs[i].auc;
-
-       if (roc->bi_neg_exp)
-         {
-           rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
-           rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
-         }
-       else
-         {
-           rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
-           rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
-         }
-      }
+      for (; (c = casereader_read (r)); case_unref (c))
+        {
+          double n_neg_lt = case_num_idx (c, N_NEG_LT);
+          struct ccase *nc = case_clone (c);
+
+          if (n_neg_lt == SYSMIS)
+            {
+              n_neg_lt = prev_neg_lt;
+              *case_num_rw_idx (nc, N_NEG_LT) = n_neg_lt;
+            }
+
+          casewriter_write (w, nc);
+          prev_neg_lt = n_neg_lt;
+        }
+
+      casereader_destroy (r);
+      r = casewriter_make_reader (w);
+
+      struct ccase *prev_case = NULL;
+      for (; (c = casereader_read (r)); case_unref (c))
+        {
+          struct ccase *next_case = casereader_peek (r, 0);
+
+          const double j = case_num_idx (c, VALUE);
+          double n_pos_eq = case_num_idx (c, N_POS_EQ);
+          double n_pos_gt = case_num_idx (c, N_POS_GT);
+          double n_neg_eq = case_num_idx (c, N_NEG_EQ);
+          double n_neg_lt = case_num_idx (c, N_NEG_LT);
+
+          if (prev_case && j == case_num_idx (prev_case, VALUE))
+            {
+              if (0 ==  case_num_idx (c, N_POS_EQ))
+                {
+                  n_pos_eq = case_num_idx (prev_case, N_POS_EQ);
+                  n_pos_gt = case_num_idx (prev_case, N_POS_GT);
+                }
+
+              if (0 ==  case_num_idx (c, N_NEG_EQ))
+                {
+                  n_neg_eq = case_num_idx (prev_case, N_NEG_EQ);
+                  n_neg_lt = case_num_idx (prev_case, N_NEG_LT);
+                }
+            }
+
+          if (NULL == next_case || j != case_num_idx (next_case, VALUE))
+            {
+              rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
+
+              rs[i].q1hat +=
+                n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
+              rs[i].q2hat +=
+                n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
+
+            }
+
+          case_unref (next_case);
+          case_unref (prev_case);
+          prev_case = case_clone (c);
+        }
+      casereader_destroy (r);
+      case_unref (prev_case);
+
+      rs[i].auc /= rs[i].n1 * rs[i].n2;
+      if (roc->invert)
+        rs[i].auc = 1 - rs[i].auc;
+
+      if (roc->bi_neg_exp)
+        {
+          rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
+          rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
+        }
+      else
+        {
+          rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
+          rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
+        }
     }
 
   casereader_destroy (positives);
@@ -950,14 +850,14 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
 
   output_roc (rs, roc);
 
-  for (i = 0 ; i < roc->n_vars; ++i)
+  for (size_t i = 0; i < roc->n_vars; ++i)
     casereader_destroy (rs[i].cutpoint_rdr);
 
   free (rs);
 }
 
 static void
-show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
+show_auc (struct roc_state *rs, const struct cmd_roc *roc)
 {
   struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
 
@@ -983,7 +883,7 @@ show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
     table, PIVOT_AXIS_ROW, N_("Variable under test"));
   variables->root->show_label = true;
 
-  for (size_t i = 0 ; i < roc->n_vars ; ++i)
+  for (size_t i = 0; i < roc->n_vars; ++i)
     {
       int var_idx = pivot_category_create_leaf (
         variables->root, pivot_value_new_variable (roc->vars[i]));
@@ -1015,7 +915,6 @@ show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
   pivot_table_submit (table);
 }
 
-
 static void
 show_summary (const struct cmd_roc *roc)
 {
@@ -1115,7 +1014,6 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc)
   pivot_table_submit (table);
 }
 
-
 static void
 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
 {
@@ -1123,11 +1021,8 @@ output_roc (struct roc_state *rs, const struct cmd_roc *roc)
 
   if (roc->curve)
     {
-      struct roc_chart *rc;
-      size_t i;
-
-      rc = roc_chart_create (roc->reference);
-      for (i = 0; i < roc->n_vars; i++)
+      struct roc_chart *rc = roc_chart_create (roc->reference);
+      for (size_t i = 0; i < roc->n_vars; i++)
         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
                            rs[i].cutpoint_rdr);
       roc_chart_submit (rc);
index 3ac61d01f951f97e5791812a13ebd808c4515f4d..fec6fbdd77204fd73318b5d7835e5c05f8ccb489 100644 (file)
@@ -217,3 +217,135 @@ roc x y by a (1)
 AT_CHECK([pspp -O format=csv roc.sps], [1], [ignore])
 
 AT_CLEANUP
+
+AT_SETUP([ROC syntax errors])
+AT_DATA([roc.sps], [dnl
+DATA LIST LIST NOTABLE/x y z.
+ROC **.
+ROC x **.
+ROC x BY **.
+ROC x BY y **.
+ROC x BY y(**).
+ROC x BY y(5 **).
+ROC x BY y(5)/MISSING=**.
+ROC x BY y(5)/PLOT=CURVE(**).
+ROC x BY y(5)/PLOT=CURVE(REFERENCE **).
+ROC x BY y(5)/PLOT=**.
+ROC x BY y(5)/PRINT=**.
+ROC x BY y(5)/CRITERIA=CUTOFF **.
+ROC x BY y(5)/CRITERIA=CUTOFF(**).
+ROC x BY y(5)/CRITERIA=CUTOFF(INCLUDE **).
+ROC x BY y(5)/CRITERIA=TESTPOS **.
+ROC x BY y(5)/CRITERIA=TESTPOS(**).
+ROC x BY y(5)/CRITERIA=TESTPOS(LARGE **).
+ROC x BY y(5)/CRITERIA=CI **.
+ROC x BY y(5)/CRITERIA=CI(**).
+ROC x BY y(5)/CRITERIA=CI(5 **).
+ROC x BY y(5)/CRITERIA=DISTRIBUTION **.
+ROC x BY y(5)/CRITERIA=DISTRIBUTION(**).
+ROC x BY y(5)/CRITERIA=DISTRIBUTION(FREE **).
+ROC x BY y(5)/CRITERIA=**.
+ROC x BY y(5)/ **.
+])
+AT_CHECK([pspp -O format=csv roc.sps], [1], [dnl
+"roc.sps:2.5-2.6: error: ROC: Syntax error expecting variable name.
+    2 | ROC **.
+      |     ^~"
+
+"roc.sps:3.7-3.8: error: ROC: Syntax error expecting `BY'.
+    3 | ROC x **.
+      |       ^~"
+
+"roc.sps:4.10-4.11: error: ROC: Syntax error expecting variable name.
+    4 | ROC x BY **.
+      |          ^~"
+
+"roc.sps:5.12-5.13: error: ROC: Syntax error expecting `('.
+    5 | ROC x BY y **.
+      |            ^~"
+
+"roc.sps:6.12-6.13: error: ROC: Syntax error expecting number.
+    6 | ROC x BY y(**).
+      |            ^~"
+
+"roc.sps:7.14-7.15: error: ROC: Syntax error expecting `)'.
+    7 | ROC x BY y(5 **).
+      |              ^~"
+
+"roc.sps:8.23-8.24: error: ROC: Syntax error expecting INCLUDE or EXCLUDE.
+    8 | ROC x BY y(5)/MISSING=**.
+      |                       ^~"
+
+"roc.sps:9.26-9.27: error: ROC: Syntax error expecting REFERENCE.
+    9 | ROC x BY y(5)/PLOT=CURVE(**).
+      |                          ^~"
+
+"roc.sps:10.36-10.37: error: ROC: Syntax error expecting `@:}@'.
+   10 | ROC x BY y(5)/PLOT=CURVE(REFERENCE **).
+      |                                    ^~"
+
+"roc.sps:11.20-11.21: error: ROC: Syntax error expecting CURVE or NONE.
+   11 | ROC x BY y(5)/PLOT=**.
+      |                    ^~"
+
+"roc.sps:12.21-12.22: error: ROC: Syntax error expecting SE or COORDINATES.
+   12 | ROC x BY y(5)/PRINT=**.
+      |                     ^~"
+
+"roc.sps:13.31-13.32: error: ROC: Syntax error expecting `('.
+   13 | ROC x BY y(5)/CRITERIA=CUTOFF **.
+      |                               ^~"
+
+"roc.sps:14.31-14.32: error: ROC: Syntax error expecting INCLUDE or EXCLUDE.
+   14 | ROC x BY y(5)/CRITERIA=CUTOFF(**).
+      |                               ^~"
+
+"roc.sps:15.39-15.40: error: ROC: Syntax error expecting `)'.
+   15 | ROC x BY y(5)/CRITERIA=CUTOFF(INCLUDE **).
+      |                                       ^~"
+
+"roc.sps:16.32-16.33: error: ROC: Syntax error expecting `('.
+   16 | ROC x BY y(5)/CRITERIA=TESTPOS **.
+      |                                ^~"
+
+"roc.sps:17.32-17.33: error: ROC: Syntax error expecting LARGE or SMALL.
+   17 | ROC x BY y(5)/CRITERIA=TESTPOS(**).
+      |                                ^~"
+
+"roc.sps:18.38-18.39: error: ROC: Syntax error expecting `)'.
+   18 | ROC x BY y(5)/CRITERIA=TESTPOS(LARGE **).
+      |                                      ^~"
+
+"roc.sps:19.27-19.28: error: ROC: Syntax error expecting `('.
+   19 | ROC x BY y(5)/CRITERIA=CI **.
+      |                           ^~"
+
+"roc.sps:20.27-20.28: error: ROC: Syntax error expecting number.
+   20 | ROC x BY y(5)/CRITERIA=CI(**).
+      |                           ^~"
+
+"roc.sps:21.29-21.30: error: ROC: Syntax error expecting `)'.
+   21 | ROC x BY y(5)/CRITERIA=CI(5 **).
+      |                             ^~"
+
+"roc.sps:22.37-22.38: error: ROC: Syntax error expecting `('.
+   22 | ROC x BY y(5)/CRITERIA=DISTRIBUTION **.
+      |                                     ^~"
+
+"roc.sps:23.37-23.38: error: ROC: Syntax error expecting FREE or NEGEXPO.
+   23 | ROC x BY y(5)/CRITERIA=DISTRIBUTION(**).
+      |                                     ^~"
+
+"roc.sps:24.42-24.43: error: ROC: Syntax error expecting `)'.
+   24 | ROC x BY y(5)/CRITERIA=DISTRIBUTION(FREE **).
+      |                                          ^~"
+
+"roc.sps:25.24-25.25: error: ROC: Syntax error expecting CUTOFF, TESTPOS, CI, or DISTRIBUTION.
+   25 | ROC x BY y(5)/CRITERIA=**.
+      |                        ^~"
+
+"roc.sps:26.16-26.17: error: ROC: Syntax error expecting MISSING, PLOT, PRINT, or CRITERIA.
+   26 | ROC x BY y(5)/ **.
+      |                ^~"
+])
+AT_CLEANUP
\ No newline at end of file