Fix bug when positive and negative groups are of different lengths
[pspp-builds.git] / src / language / stats / roc.c
index 46d2da9ed3a645b914fdfa14300941425e60ae41..b3841634378f30503c5e614f3387faeb3c3d144d 100644 (file)
@@ -26,6 +26,7 @@
 #include <data/casereader.h>
 #include <data/casewriter.h>
 #include <data/dictionary.h>
+#include <data/format.h>
 #include <math/sort.h>
 
 #include <libpspp/misc.h>
@@ -324,6 +325,9 @@ struct roc_state
 
   double n1;
   double n2;
+
+  double q1hat;
+  double q2hat;
 };
 
 
@@ -440,9 +444,6 @@ do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict)
 
   for (i = 0 ; i < roc->n_vars; ++i)
     {
-      double q1hat = 0;
-      double q2hat = 0;
-
       struct ccase *cpos;
       struct casereader *n_neg ;
       const struct variable *var = roc->vars[i];
@@ -473,6 +474,8 @@ do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict)
                case_unref (cneg);
 
              cneg = casereader_read (n_neg);
+             if ( ! cneg )
+               break;
              dneg = case_data_idx (cneg, VALUE)->f;
            }
        
@@ -484,8 +487,10 @@ do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict)
              double n_neg_lt = case_data_idx (cneg, N_PRED)->f;
 
              rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
-             q1hat += n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
-             q2hat += n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
+             rs[i].q1hat +=
+               n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
+             rs[i].q2hat +=
+               n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
            }
 
          if ( cneg )
@@ -496,7 +501,16 @@ do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict)
       if ( roc->invert ) 
        rs[i].auc = 1 - rs[i].auc;
 
-
+      if ( roc->bi_neg_exp )
+       {
+         rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
+         rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
+       }
+      else
+       {
+         rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
+         rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
+       }
     }
 
   casereader_destroy (positives);
@@ -576,11 +590,8 @@ show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
          double ci ;
          double yy ;
 
-         double q1 = rs[i].auc / ( 2 - rs[i].auc);
-         double q2 = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
-
-         se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (q1 - pow2 (rs[i].auc)) +
-           (rs[i].n2 - 1) * (q2 - pow2 (rs[i].auc));
+         se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
+           (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
 
          se /= rs[i].n1 * rs[i].n2;
 
@@ -611,14 +622,66 @@ show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
 }
 
 
+static void
+show_summary (const struct cmd_roc *roc)
+{
+  const int n_cols = 3;
+  const int n_rows = 4;
+  struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
+
+  tab_title (tbl, _("Case Summary"));
+
+  tab_headers (tbl, 1, 0, 2, 0);
+
+  tab_dim (tbl, tab_natural_dimensions, NULL);
+
+  tab_box (tbl,
+          TAL_2, TAL_2,
+          -1, -1,
+          0, 0,
+          n_cols - 1,
+          n_rows - 1);
+
+  tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
+  tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
+
+
+  tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
+  tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
+
+
+  tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
+  tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
+  tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
+
+  tab_joint_text (tbl, 1, 0, 2, 0,
+                 TAT_TITLE | TAB_CENTER,
+                 _("Valid N (listwise)"));
+
+
+  tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
+  tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
+
+
+#if 0
+  tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
+  tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
+
+  tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
+  tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
+#endif
+
+  tab_submit (tbl);
+}
 
 
 static void
 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
 {
-#if 0
   show_summary (roc);
 
+#if 0
+
   if ( roc->curve )
     draw_roc (rs, roc);
 #endif