treewide: Replace <name>_cnt by n_<name>s and <name>_cap by allocated_<name>.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/charts/roc-chart.h"
37 #include "output/pivot-table.h"
38
39 #include "gettext.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) msgid
42
43 struct cmd_roc
44 {
45   size_t n_vars;
46   const struct variable **vars;
47   const struct dictionary *dict;
48
49   const struct variable *state_var;
50   union value state_value;
51   size_t state_var_width;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert ; /* True iff a smaller test result variable indicates
67                    a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
76
77 int
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
79 {
80   struct cmd_roc roc ;
81   const struct dictionary *dict = dataset_dict (ds);
82
83   roc.vars = NULL;
84   roc.n_vars = 0;
85   roc.print_se = false;
86   roc.print_coords = false;
87   roc.exclude = MV_ANY;
88   roc.curve = true;
89   roc.reference = false;
90   roc.ci = 95;
91   roc.bi_neg_exp = false;
92   roc.invert = false;
93   roc.pos = roc.pos_weighted = 0;
94   roc.neg = roc.neg_weighted = 0;
95   roc.dict = dataset_dict (ds);
96   roc.state_var = NULL;
97   roc.state_var_width = -1;
98
99   lex_match (lexer, T_SLASH);
100   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
101                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
102     goto error;
103
104   if (! lex_force_match (lexer, T_BY))
105     {
106       goto error;
107     }
108
109   roc.state_var = parse_variable (lexer, dict);
110   if (! roc.state_var)
111     {
112       goto error;
113     }
114
115   if (!lex_force_match (lexer, T_LPAREN))
116     {
117       goto error;
118     }
119
120   roc.state_var_width = var_get_width (roc.state_var);
121   value_init (&roc.state_value, roc.state_var_width);
122   parse_value (lexer, &roc.state_value, roc.state_var);
123
124
125   if (!lex_force_match (lexer, T_RPAREN))
126     {
127       goto error;
128     }
129
130   while (lex_token (lexer) != T_ENDCMD)
131     {
132       lex_match (lexer, T_SLASH);
133       if (lex_match_id (lexer, "MISSING"))
134         {
135           lex_match (lexer, T_EQUALS);
136           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
137             {
138               if (lex_match_id (lexer, "INCLUDE"))
139                 {
140                   roc.exclude = MV_SYSTEM;
141                 }
142               else if (lex_match_id (lexer, "EXCLUDE"))
143                 {
144                   roc.exclude = MV_ANY;
145                 }
146               else
147                 {
148                   lex_error (lexer, NULL);
149                   goto error;
150                 }
151             }
152         }
153       else if (lex_match_id (lexer, "PLOT"))
154         {
155           lex_match (lexer, T_EQUALS);
156           if (lex_match_id (lexer, "CURVE"))
157             {
158               roc.curve = true;
159               if (lex_match (lexer, T_LPAREN))
160                 {
161                   roc.reference = true;
162                   if (! lex_force_match_id (lexer, "REFERENCE"))
163                     goto error;
164                   if (! lex_force_match (lexer, T_RPAREN))
165                     goto error;
166                 }
167             }
168           else if (lex_match_id (lexer, "NONE"))
169             {
170               roc.curve = false;
171             }
172           else
173             {
174               lex_error (lexer, NULL);
175               goto error;
176             }
177         }
178       else if (lex_match_id (lexer, "PRINT"))
179         {
180           lex_match (lexer, T_EQUALS);
181           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
182             {
183               if (lex_match_id (lexer, "SE"))
184                 {
185                   roc.print_se = true;
186                 }
187               else if (lex_match_id (lexer, "COORDINATES"))
188                 {
189                   roc.print_coords = true;
190                 }
191               else
192                 {
193                   lex_error (lexer, NULL);
194                   goto error;
195                 }
196             }
197         }
198       else if (lex_match_id (lexer, "CRITERIA"))
199         {
200           lex_match (lexer, T_EQUALS);
201           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
202             {
203               if (lex_match_id (lexer, "CUTOFF"))
204                 {
205                   if (! lex_force_match (lexer, T_LPAREN))
206                     goto error;
207                   if (lex_match_id (lexer, "INCLUDE"))
208                     {
209                       roc.exclude = MV_SYSTEM;
210                     }
211                   else if (lex_match_id (lexer, "EXCLUDE"))
212                     {
213                       roc.exclude = MV_USER | MV_SYSTEM;
214                     }
215                   else
216                     {
217                       lex_error (lexer, NULL);
218                       goto error;
219                     }
220                   if (! lex_force_match (lexer, T_RPAREN))
221                     goto error;
222                 }
223               else if (lex_match_id (lexer, "TESTPOS"))
224                 {
225                   if (! lex_force_match (lexer, T_LPAREN))
226                     goto error;
227                   if (lex_match_id (lexer, "LARGE"))
228                     {
229                       roc.invert = false;
230                     }
231                   else if (lex_match_id (lexer, "SMALL"))
232                     {
233                       roc.invert = true;
234                     }
235                   else
236                     {
237                       lex_error (lexer, NULL);
238                       goto error;
239                     }
240                   if (! lex_force_match (lexer, T_RPAREN))
241                     goto error;
242                 }
243               else if (lex_match_id (lexer, "CI"))
244                 {
245                   if (!lex_force_match (lexer, T_LPAREN))
246                     goto error;
247                   if (! lex_force_num (lexer))
248                     goto error;
249                   roc.ci = lex_number (lexer);
250                   lex_get (lexer);
251                   if (!lex_force_match (lexer, T_RPAREN))
252                     goto error;
253                 }
254               else if (lex_match_id (lexer, "DISTRIBUTION"))
255                 {
256                   if (!lex_force_match (lexer, T_LPAREN))
257                     goto error;
258                   if (lex_match_id (lexer, "FREE"))
259                     {
260                       roc.bi_neg_exp = false;
261                     }
262                   else if (lex_match_id (lexer, "NEGEXPO"))
263                     {
264                       roc.bi_neg_exp = true;
265                     }
266                   else
267                     {
268                       lex_error (lexer, NULL);
269                       goto error;
270                     }
271                   if (!lex_force_match (lexer, T_RPAREN))
272                     goto error;
273                 }
274               else
275                 {
276                   lex_error (lexer, NULL);
277                   goto error;
278                 }
279             }
280         }
281       else
282         {
283           lex_error (lexer, NULL);
284           break;
285         }
286     }
287
288   if (! run_roc (ds, &roc))
289     goto error;
290
291   if (roc.state_var)
292     value_destroy (&roc.state_value, roc.state_var_width);
293   free (roc.vars);
294   return CMD_SUCCESS;
295
296  error:
297   if (roc.state_var)
298     value_destroy (&roc.state_value, roc.state_var_width);
299   free (roc.vars);
300   return CMD_FAILURE;
301 }
302
303
304
305
306 static void
307 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
308
309
310 static int
311 run_roc (struct dataset *ds, struct cmd_roc *roc)
312 {
313   struct dictionary *dict = dataset_dict (ds);
314   bool ok;
315   struct casereader *group;
316
317   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
318   while (casegrouper_get_next_group (grouper, &group))
319     {
320       do_roc (roc, group, dataset_dict (ds));
321     }
322   ok = casegrouper_destroy (grouper);
323   ok = proc_commit (ds) && ok;
324
325   return ok;
326 }
327
328 #if 0
329 static void
330 dump_casereader (struct casereader *reader)
331 {
332   struct ccase *c;
333   struct casereader *r = casereader_clone (reader);
334
335   for (; (c = casereader_read (r)); case_unref (c))
336     {
337       int i;
338       for (i = 0 ; i < case_get_n_values (c); ++i)
339         printf ("%g ", case_num_idx (c, i));
340       printf ("\n");
341     }
342
343   casereader_destroy (r);
344 }
345 #endif
346
347
348 /*
349    Return true iff the state variable indicates that C has positive actual state.
350
351    As a side effect, this function also accumulates the roc->{pos,neg} and
352    roc->{pos,neg}_weighted counts.
353  */
354 static bool
355 match_positives (const struct ccase *c, void *aux)
356 {
357   struct cmd_roc *roc = aux;
358   const struct variable *wv = dict_get_weight (roc->dict);
359   const double weight = wv ? case_num (c, wv) : 1.0;
360
361   const bool positive =
362   (0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
363     var_get_width (roc->state_var)));
364
365   if (positive)
366     {
367       roc->pos++;
368       roc->pos_weighted += weight;
369     }
370   else
371     {
372       roc->neg++;
373       roc->neg_weighted += weight;
374     }
375
376   return positive;
377 }
378
379
380 #define VALUE  0
381 #define N_EQ   1
382 #define N_PRED 2
383
384 /* Some intermediate state for calculating the cutpoints and the
385    standard error values */
386 struct roc_state
387 {
388   double auc;  /* Area under the curve */
389
390   double n1;  /* total weight of positives */
391   double n2;  /* total weight of negatives */
392
393   /* intermediates for standard error */
394   double q1hat;
395   double q2hat;
396
397   /* intermediates for cutpoints */
398   struct casewriter *cutpoint_wtr;
399   struct casereader *cutpoint_rdr;
400   double prev_result;
401   double min;
402   double max;
403 };
404
405 /*
406    Return a new casereader based upon CUTPOINT_RDR.
407    The number of "positive" cases are placed into
408    the position TRUE_INDEX, and the number of "negative" cases
409    into FALSE_INDEX.
410    POS_COND and RESULT determine the semantics of what is
411    "positive".
412    WEIGHT is the value of a single count.
413  */
414 static struct casereader *
415 accumulate_counts (struct casereader *input,
416                    double result, double weight,
417                    bool (*pos_cond) (double, double),
418                    int true_index, int false_index)
419 {
420   const struct caseproto *proto = casereader_get_proto (input);
421   struct casewriter *w =
422     autopaging_writer_create (proto);
423   struct ccase *cpc;
424   double prev_cp = SYSMIS;
425
426   for (; (cpc = casereader_read (input)); case_unref (cpc))
427     {
428       struct ccase *new_case;
429       const double cp = case_num_idx (cpc, ROC_CUTPOINT);
430
431       assert (cp != SYSMIS);
432
433       /* We don't want duplicates here */
434       if (cp == prev_cp)
435         continue;
436
437       new_case = case_clone (cpc);
438
439       int index = pos_cond (result, cp) ? true_index : false_index;
440       *case_num_rw_idx (new_case, index) += weight;
441
442       prev_cp = cp;
443
444       casewriter_write (w, new_case);
445     }
446   casereader_destroy (input);
447
448   return casewriter_make_reader (w);
449 }
450
451
452
453 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
454
455 /*
456   This function does 3 things:
457
458   1. Counts the number of cases which are equal to every other case in READER,
459   and those cases for which the relationship between it and every other case
460   satifies PRED (normally either > or <).  VAR is variable defining a case's value
461   for this purpose.
462
463   2. Counts the number of true and false cases in reader, and populates
464   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
465   which receive these values.  POS_COND is the condition defining true
466   and false.
467
468   3. CC is filled with the cumulative weight of all cases of READER.
469 */
470 static struct casereader *
471 process_group (const struct variable *var, struct casereader *reader,
472                bool (*pred) (double, double),
473                const struct dictionary *dict,
474                double *cc,
475                struct casereader **cutpoint_rdr,
476                bool (*pos_cond) (double, double),
477                int true_index,
478                int false_index)
479 {
480   const struct variable *w = dict_get_weight (dict);
481
482   struct casereader *r1 =
483     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
484
485   const int weight_idx  = w ? var_get_case_index (w) :
486     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
487
488   struct ccase *c1;
489
490   struct casereader *rclone = casereader_clone (r1);
491   struct casewriter *wtr;
492   struct caseproto *proto = caseproto_create ();
493
494   proto = caseproto_add_width (proto, 0);
495   proto = caseproto_add_width (proto, 0);
496   proto = caseproto_add_width (proto, 0);
497
498   wtr = autopaging_writer_create (proto);
499
500   *cc = 0;
501
502   for (; (c1 = casereader_read (r1)); case_unref (c1))
503     {
504       struct ccase *new_case = case_create (proto);
505       struct ccase *c2;
506       struct casereader *r2 = casereader_clone (rclone);
507
508       const double weight1 = case_num_idx (c1, weight_idx);
509       const double d1 = case_num (c1, var);
510       double n_eq = 0.0;
511       double n_pred = 0.0;
512
513       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
514                                          pos_cond,
515                                          true_index, false_index);
516
517       *cc += weight1;
518
519       for (; (c2 = casereader_read (r2)); case_unref (c2))
520         {
521           const double d2 = case_num (c2, var);
522           const double weight2 = case_num_idx (c2, weight_idx);
523
524           if (d1 == d2)
525             {
526               n_eq += weight2;
527               continue;
528             }
529           else  if (pred (d2, d1))
530             {
531               n_pred += weight2;
532             }
533         }
534
535       *case_num_rw_idx (new_case, VALUE) = d1;
536       *case_num_rw_idx (new_case, N_EQ) = n_eq;
537       *case_num_rw_idx (new_case, N_PRED) = n_pred;
538
539       casewriter_write (wtr, new_case);
540
541       casereader_destroy (r2);
542     }
543
544
545   casereader_destroy (r1);
546   casereader_destroy (rclone);
547
548   caseproto_unref (proto);
549
550   return casewriter_make_reader (wtr);
551 }
552
553 /* Some more indeces into case data */
554 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
555 #define N_POS_GT 2  /* number of positive cases with values greater than n */
556 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
557 #define N_NEG_LT 4  /* number of negative cases with values less than n */
558
559 static bool
560 gt (double d1, double d2)
561 {
562   return d1 > d2;
563 }
564
565
566 static bool
567 ge (double d1, double d2)
568 {
569   return d1 > d2;
570 }
571
572 static bool
573 lt (double d1, double d2)
574 {
575   return d1 < d2;
576 }
577
578
579 /*
580   Return a casereader with width 3,
581   populated with cases based upon READER.
582   The cases will have the values:
583   (N, number of cases equal to N, number of cases greater than N)
584   As a side effect, update RS->n1 with the number of positive cases.
585 */
586 static struct casereader *
587 process_positive_group (const struct variable *var, struct casereader *reader,
588                         const struct dictionary *dict,
589                         struct roc_state *rs)
590 {
591   return process_group (var, reader, gt, dict, &rs->n1,
592                         &rs->cutpoint_rdr,
593                         ge,
594                         ROC_TP, ROC_FN);
595 }
596
597 /*
598   Return a casereader with width 3,
599   populated with cases based upon READER.
600   The cases will have the values:
601   (N, number of cases equal to N, number of cases less than N)
602   As a side effect, update RS->n2 with the number of negative cases.
603 */
604 static struct casereader *
605 process_negative_group (const struct variable *var, struct casereader *reader,
606                         const struct dictionary *dict,
607                         struct roc_state *rs)
608 {
609   return process_group (var, reader, lt, dict, &rs->n2,
610                         &rs->cutpoint_rdr,
611                         lt,
612                         ROC_TN, ROC_FP);
613 }
614
615
616
617
618 static void
619 append_cutpoint (struct casewriter *writer, double cutpoint)
620 {
621   struct ccase *cc = case_create (casewriter_get_proto (writer));
622
623   *case_num_rw_idx (cc, ROC_CUTPOINT) = cutpoint;
624   *case_num_rw_idx (cc, ROC_TP) = 0;
625   *case_num_rw_idx (cc, ROC_FN) = 0;
626   *case_num_rw_idx (cc, ROC_TN) = 0;
627   *case_num_rw_idx (cc, ROC_FP) = 0;
628
629   casewriter_write (writer, cc);
630 }
631
632
633 /*
634    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
635    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
636    reader will be populated with its final number of cases.
637    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
638    value.  The other entries will be initialised to zero.
639 */
640 static void
641 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
642 {
643   int i;
644   struct casereader *r = casereader_clone (input);
645   struct ccase *c;
646
647   {
648     struct caseproto *proto = caseproto_create ();
649     struct subcase ordering;
650     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
651
652     proto = caseproto_add_width (proto, 0); /* cutpoint */
653     proto = caseproto_add_width (proto, 0); /* ROC_TP */
654     proto = caseproto_add_width (proto, 0); /* ROC_FN */
655     proto = caseproto_add_width (proto, 0); /* ROC_TN */
656     proto = caseproto_add_width (proto, 0); /* ROC_FP */
657
658     for (i = 0 ; i < roc->n_vars; ++i)
659       {
660         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
661         rs[i].prev_result = SYSMIS;
662         rs[i].max = -DBL_MAX;
663         rs[i].min = DBL_MAX;
664       }
665
666     caseproto_unref (proto);
667     subcase_destroy (&ordering);
668   }
669
670   for (; (c = casereader_read (r)) != NULL; case_unref (c))
671     {
672       for (i = 0 ; i < roc->n_vars; ++i)
673         {
674           const union value *v = case_data (c, roc->vars[i]);
675           const double result = v->f;
676
677           if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
678             continue;
679
680           minimize (&rs[i].min, result);
681           maximize (&rs[i].max, result);
682
683           if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
684             {
685               const double mean = (result + rs[i].prev_result) / 2.0;
686               append_cutpoint (rs[i].cutpoint_wtr, mean);
687             }
688
689           rs[i].prev_result = result;
690         }
691     }
692   casereader_destroy (r);
693
694
695   /* Append the min and max cutpoints */
696   for (i = 0 ; i < roc->n_vars; ++i)
697     {
698       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
699       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
700
701       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
702     }
703 }
704
705 static void
706 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
707 {
708   int i;
709
710   struct roc_state *rs = XCALLOC (roc->n_vars,  struct roc_state);
711
712   struct casereader *negatives = NULL;
713   struct casereader *positives = NULL;
714
715   struct caseproto *n_proto = NULL;
716
717   struct subcase up_ordering;
718   struct subcase down_ordering;
719
720   struct casewriter *neg_wtr = NULL;
721
722   struct casereader *input = casereader_create_filter_missing (reader,
723                                                                roc->vars, roc->n_vars,
724                                                                roc->exclude,
725                                                                NULL,
726                                                                NULL);
727
728   input = casereader_create_filter_missing (input,
729                                             &roc->state_var, 1,
730                                             roc->exclude,
731                                             NULL,
732                                             NULL);
733
734   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
735
736   prepare_cutpoints (roc, rs, input);
737
738
739   /* Separate the positive actual state cases from the negative ones */
740   positives =
741     casereader_create_filter_func (input,
742                                    match_positives,
743                                    NULL,
744                                    roc,
745                                    neg_wtr);
746
747   n_proto = caseproto_create ();
748
749   n_proto = caseproto_add_width (n_proto, 0);
750   n_proto = caseproto_add_width (n_proto, 0);
751   n_proto = caseproto_add_width (n_proto, 0);
752   n_proto = caseproto_add_width (n_proto, 0);
753   n_proto = caseproto_add_width (n_proto, 0);
754
755   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
756   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
757
758   for (i = 0 ; i < roc->n_vars; ++i)
759     {
760       struct casewriter *w = NULL;
761       struct casereader *r = NULL;
762
763       struct ccase *c;
764
765       struct ccase *cpos;
766       struct casereader *n_neg_reader ;
767       const struct variable *var = roc->vars[i];
768
769       struct casereader *neg ;
770       struct casereader *pos = casereader_clone (positives);
771
772       struct casereader *n_pos_reader =
773         process_positive_group (var, pos, dict, &rs[i]);
774
775       if (negatives == NULL)
776         {
777           negatives = casewriter_make_reader (neg_wtr);
778         }
779
780       neg = casereader_clone (negatives);
781
782       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
783
784       /* Merge the n_pos and n_neg casereaders */
785       w = sort_create_writer (&up_ordering, n_proto);
786       for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
787         {
788           struct ccase *pos_case = case_create (n_proto);
789           struct ccase *cneg;
790           const double jpos = case_num_idx (cpos, VALUE);
791
792           while ((cneg = casereader_read (n_neg_reader)))
793             {
794               struct ccase *nc = case_create (n_proto);
795
796               const double jneg = case_num_idx (cneg, VALUE);
797
798               *case_num_rw_idx (nc, VALUE) = jneg;
799               *case_num_rw_idx (nc, N_POS_EQ) = 0;
800
801               *case_num_rw_idx (nc, N_POS_GT) = SYSMIS;
802
803               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
804               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
805
806               casewriter_write (w, nc);
807
808               case_unref (cneg);
809               if (jneg > jpos)
810                 break;
811             }
812
813           *case_num_rw_idx (pos_case, VALUE) = jpos;
814           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
815           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
816           *case_num_rw_idx (pos_case, N_NEG_EQ) = 0;
817           *case_num_rw_idx (pos_case, N_NEG_LT) = SYSMIS;
818
819           casewriter_write (w, pos_case);
820         }
821
822       casereader_destroy (n_pos_reader);
823       casereader_destroy (n_neg_reader);
824
825 /* These aren't used anymore */
826 #undef N_EQ
827 #undef N_PRED
828
829       r = casewriter_make_reader (w);
830
831       /* Propagate the N_POS_GT values from the positive cases
832          to the negative ones */
833       {
834         double prev_pos_gt = rs[i].n1;
835         w = sort_create_writer (&down_ordering, n_proto);
836
837         for (; (c = casereader_read (r)); case_unref (c))
838           {
839             double n_pos_gt = case_num_idx (c, N_POS_GT);
840             struct ccase *nc = case_clone (c);
841
842             if (n_pos_gt == SYSMIS)
843               {
844                 n_pos_gt = prev_pos_gt;
845                 *case_num_rw_idx (nc, N_POS_GT) = n_pos_gt;
846               }
847
848             casewriter_write (w, nc);
849             prev_pos_gt = n_pos_gt;
850           }
851
852         casereader_destroy (r);
853         r = casewriter_make_reader (w);
854       }
855
856       /* Propagate the N_NEG_LT values from the negative cases
857          to the positive ones */
858       {
859         double prev_neg_lt = rs[i].n2;
860         w = sort_create_writer (&up_ordering, n_proto);
861
862         for (; (c = casereader_read (r)); case_unref (c))
863           {
864             double n_neg_lt = case_num_idx (c, N_NEG_LT);
865             struct ccase *nc = case_clone (c);
866
867             if (n_neg_lt == SYSMIS)
868               {
869                 n_neg_lt = prev_neg_lt;
870                 *case_num_rw_idx (nc, N_NEG_LT) = n_neg_lt;
871               }
872
873             casewriter_write (w, nc);
874             prev_neg_lt = n_neg_lt;
875           }
876
877         casereader_destroy (r);
878         r = casewriter_make_reader (w);
879       }
880
881       {
882         struct ccase *prev_case = NULL;
883         for (; (c = casereader_read (r)); case_unref (c))
884           {
885             struct ccase *next_case = casereader_peek (r, 0);
886
887             const double j = case_num_idx (c, VALUE);
888             double n_pos_eq = case_num_idx (c, N_POS_EQ);
889             double n_pos_gt = case_num_idx (c, N_POS_GT);
890             double n_neg_eq = case_num_idx (c, N_NEG_EQ);
891             double n_neg_lt = case_num_idx (c, N_NEG_LT);
892
893             if (prev_case && j == case_num_idx (prev_case, VALUE))
894               {
895                 if (0 ==  case_num_idx (c, N_POS_EQ))
896                   {
897                     n_pos_eq = case_num_idx (prev_case, N_POS_EQ);
898                     n_pos_gt = case_num_idx (prev_case, N_POS_GT);
899                   }
900
901                 if (0 ==  case_num_idx (c, N_NEG_EQ))
902                   {
903                     n_neg_eq = case_num_idx (prev_case, N_NEG_EQ);
904                     n_neg_lt = case_num_idx (prev_case, N_NEG_LT);
905                   }
906               }
907
908             if (NULL == next_case || j != case_num_idx (next_case, VALUE))
909               {
910                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
911
912                 rs[i].q1hat +=
913                   n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
914                 rs[i].q2hat +=
915                   n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
916
917               }
918
919             case_unref (next_case);
920             case_unref (prev_case);
921             prev_case = case_clone (c);
922           }
923         casereader_destroy (r);
924         case_unref (prev_case);
925
926         rs[i].auc /=  rs[i].n1 * rs[i].n2;
927         if (roc->invert)
928           rs[i].auc = 1 - rs[i].auc;
929
930         if (roc->bi_neg_exp)
931           {
932             rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
933             rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
934           }
935         else
936           {
937             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
938             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
939           }
940       }
941     }
942
943   casereader_destroy (positives);
944   casereader_destroy (negatives);
945
946   caseproto_unref (n_proto);
947   subcase_destroy (&up_ordering);
948   subcase_destroy (&down_ordering);
949
950   output_roc (rs, roc);
951
952   for (i = 0 ; i < roc->n_vars; ++i)
953     casereader_destroy (rs[i].cutpoint_rdr);
954
955   free (rs);
956 }
957
958 static void
959 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
960 {
961   struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
962
963   struct pivot_dimension *statistics = pivot_dimension_create (
964     table, PIVOT_AXIS_COLUMN, N_("Statistics"),
965     N_("Area"), PIVOT_RC_OTHER);
966   if (roc->print_se)
967     {
968       pivot_category_create_leaves (
969         statistics->root,
970         N_("Std. Error"), PIVOT_RC_OTHER,
971         N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
972       struct pivot_category *interval = pivot_category_create_group__ (
973         statistics->root,
974         pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
975                                      roc->ci));
976       pivot_category_create_leaves (interval,
977                                     N_("Lower Bound"), PIVOT_RC_OTHER,
978                                     N_("Upper Bound"), PIVOT_RC_OTHER);
979     }
980
981   struct pivot_dimension *variables = pivot_dimension_create (
982     table, PIVOT_AXIS_ROW, N_("Variable under test"));
983   variables->root->show_label = true;
984
985   for (size_t i = 0 ; i < roc->n_vars ; ++i)
986     {
987       int var_idx = pivot_category_create_leaf (
988         variables->root, pivot_value_new_variable (roc->vars[i]));
989
990       pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
991
992       if (roc->print_se)
993         {
994           double se = (rs[i].auc * (1 - rs[i].auc)
995                        + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
996                        + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
997           se /= rs[i].n1 * rs[i].n2;
998           se = sqrt (se);
999
1000           double ci = 1 - roc->ci / 100.0;
1001           double yy = gsl_cdf_gaussian_Qinv (ci, se);
1002
1003           double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1004                                 (12 * rs[i].n1 * rs[i].n2));
1005           double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
1006                                                         / sd_0_5));
1007           double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
1008           for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1009             pivot_table_put2 (table, i + 1, var_idx,
1010                               pivot_value_new_number (entries[i]));
1011         }
1012     }
1013
1014   pivot_table_submit (table);
1015 }
1016
1017
1018 static void
1019 show_summary (const struct cmd_roc *roc)
1020 {
1021   struct pivot_table *table = pivot_table_create (N_("Case Summary"));
1022
1023   struct pivot_dimension *statistics = pivot_dimension_create (
1024     table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
1025     N_("Unweighted"), PIVOT_RC_INTEGER,
1026     N_("Weighted"), PIVOT_RC_OTHER);
1027   statistics->root->show_label = true;
1028
1029   struct pivot_dimension *cases = pivot_dimension_create__ (
1030     table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
1031   cases->root->show_label = true;
1032   pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
1033
1034   struct entry
1035     {
1036       int stat_idx;
1037       int case_idx;
1038       double x;
1039     }
1040   entries[] = {
1041     { 0, 0, roc->pos },
1042     { 0, 1, roc->neg },
1043     { 1, 0, roc->pos_weighted },
1044     { 1, 1, roc->neg_weighted },
1045   };
1046   for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1047     {
1048       const struct entry *e = &entries[i];
1049       pivot_table_put2 (table, e->stat_idx, e->case_idx,
1050                         pivot_value_new_number (e->x));
1051     }
1052   pivot_table_submit (table);
1053 }
1054
1055 static void
1056 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1057 {
1058   struct pivot_table *table = pivot_table_create (
1059     N_("Coordinates of the Curve"));
1060
1061   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
1062                           N_("Positive if greater than or equal to"),
1063                           N_("Sensitivity"), N_("1 - Specificity"));
1064
1065   struct pivot_dimension *coordinates = pivot_dimension_create (
1066     table, PIVOT_AXIS_ROW, N_("Coordinates"));
1067   coordinates->hide_all_labels = true;
1068
1069   struct pivot_dimension *variables = pivot_dimension_create (
1070     table, PIVOT_AXIS_ROW, N_("Test variable"));
1071   variables->root->show_label = true;
1072
1073
1074   int n_coords = 0;
1075   for (size_t i = 0; i < roc->n_vars; ++i)
1076     {
1077       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1078
1079       int var_idx = pivot_category_create_leaf (
1080         variables->root, pivot_value_new_variable (roc->vars[i]));
1081
1082       struct ccase *cc;
1083       int coord_idx = 0;
1084       for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
1085         {
1086           const double se = case_num_idx (cc, ROC_TP) /
1087             (case_num_idx (cc, ROC_TP) + case_num_idx (cc, ROC_FN));
1088
1089           const double sp = case_num_idx (cc, ROC_TN) /
1090             (case_num_idx (cc, ROC_TN) + case_num_idx (cc, ROC_FP));
1091
1092           if (coord_idx >= n_coords)
1093             {
1094               assert (coord_idx == n_coords);
1095               pivot_category_create_leaf (
1096                 coordinates->root, pivot_value_new_integer (++n_coords));
1097             }
1098
1099           pivot_table_put3 (
1100             table, 0, coord_idx, var_idx,
1101             pivot_value_new_var_value (roc->vars[i],
1102                                        case_data_idx (cc, ROC_CUTPOINT)));
1103
1104           pivot_table_put3 (table, 1, coord_idx, var_idx,
1105                             pivot_value_new_number (se));
1106           pivot_table_put3 (table, 2, coord_idx, var_idx,
1107                             pivot_value_new_number (1 - sp));
1108           coord_idx++;
1109         }
1110
1111       casereader_destroy (r);
1112     }
1113
1114   pivot_table_submit (table);
1115 }
1116
1117
1118 static void
1119 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1120 {
1121   show_summary (roc);
1122
1123   if (roc->curve)
1124     {
1125       struct roc_chart *rc;
1126       size_t i;
1127
1128       rc = roc_chart_create (roc->reference);
1129       for (i = 0; i < roc->n_vars; i++)
1130         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1131                            rs[i].cutpoint_rdr);
1132       roc_chart_submit (rc);
1133     }
1134
1135   show_auc (rs, roc);
1136
1137   if (roc->print_coords)
1138     show_coords (rs, roc);
1139 }
1140