math: Coding style updates in some order-stat implementations.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/charts/roc-chart.h"
37 #include "output/pivot-table.h"
38
39 #include "gettext.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) msgid
42
43 struct cmd_roc
44 {
45   size_t n_vars;
46   const struct variable **vars;
47   const struct dictionary *dict;
48
49   const struct variable *state_var;
50   union value state_value;
51   size_t state_var_width;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert ; /* True iff a smaller test result variable indicates
67                    a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
76
77 int
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
79 {
80   struct cmd_roc roc ;
81   const struct dictionary *dict = dataset_dict (ds);
82
83   roc.vars = NULL;
84   roc.n_vars = 0;
85   roc.print_se = false;
86   roc.print_coords = false;
87   roc.exclude = MV_ANY;
88   roc.curve = true;
89   roc.reference = false;
90   roc.ci = 95;
91   roc.bi_neg_exp = false;
92   roc.invert = false;
93   roc.pos = roc.pos_weighted = 0;
94   roc.neg = roc.neg_weighted = 0;
95   roc.dict = dataset_dict (ds);
96   roc.state_var = NULL;
97   roc.state_var_width = -1;
98
99   lex_match (lexer, T_SLASH);
100   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
101                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
102     goto error;
103
104   if (! lex_force_match (lexer, T_BY))
105     {
106       goto error;
107     }
108
109   roc.state_var = parse_variable (lexer, dict);
110   if (! roc.state_var)
111     {
112       goto error;
113     }
114
115   if (!lex_force_match (lexer, T_LPAREN))
116     {
117       goto error;
118     }
119
120   roc.state_var_width = var_get_width (roc.state_var);
121   value_init (&roc.state_value, roc.state_var_width);
122   parse_value (lexer, &roc.state_value, roc.state_var);
123
124
125   if (!lex_force_match (lexer, T_RPAREN))
126     {
127       goto error;
128     }
129
130   while (lex_token (lexer) != T_ENDCMD)
131     {
132       lex_match (lexer, T_SLASH);
133       if (lex_match_id (lexer, "MISSING"))
134         {
135           lex_match (lexer, T_EQUALS);
136           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
137             {
138               if (lex_match_id (lexer, "INCLUDE"))
139                 {
140                   roc.exclude = MV_SYSTEM;
141                 }
142               else if (lex_match_id (lexer, "EXCLUDE"))
143                 {
144                   roc.exclude = MV_ANY;
145                 }
146               else
147                 {
148                   lex_error (lexer, NULL);
149                   goto error;
150                 }
151             }
152         }
153       else if (lex_match_id (lexer, "PLOT"))
154         {
155           lex_match (lexer, T_EQUALS);
156           if (lex_match_id (lexer, "CURVE"))
157             {
158               roc.curve = true;
159               if (lex_match (lexer, T_LPAREN))
160                 {
161                   roc.reference = true;
162                   if (! lex_force_match_id (lexer, "REFERENCE"))
163                     goto error;
164                   if (! lex_force_match (lexer, T_RPAREN))
165                     goto error;
166                 }
167             }
168           else if (lex_match_id (lexer, "NONE"))
169             {
170               roc.curve = false;
171             }
172           else
173             {
174               lex_error (lexer, NULL);
175               goto error;
176             }
177         }
178       else if (lex_match_id (lexer, "PRINT"))
179         {
180           lex_match (lexer, T_EQUALS);
181           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
182             {
183               if (lex_match_id (lexer, "SE"))
184                 {
185                   roc.print_se = true;
186                 }
187               else if (lex_match_id (lexer, "COORDINATES"))
188                 {
189                   roc.print_coords = true;
190                 }
191               else
192                 {
193                   lex_error (lexer, NULL);
194                   goto error;
195                 }
196             }
197         }
198       else if (lex_match_id (lexer, "CRITERIA"))
199         {
200           lex_match (lexer, T_EQUALS);
201           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
202             {
203               if (lex_match_id (lexer, "CUTOFF"))
204                 {
205                   if (! lex_force_match (lexer, T_LPAREN))
206                     goto error;
207                   if (lex_match_id (lexer, "INCLUDE"))
208                     {
209                       roc.exclude = MV_SYSTEM;
210                     }
211                   else if (lex_match_id (lexer, "EXCLUDE"))
212                     {
213                       roc.exclude = MV_USER | MV_SYSTEM;
214                     }
215                   else
216                     {
217                       lex_error (lexer, NULL);
218                       goto error;
219                     }
220                   if (! lex_force_match (lexer, T_RPAREN))
221                     goto error;
222                 }
223               else if (lex_match_id (lexer, "TESTPOS"))
224                 {
225                   if (! lex_force_match (lexer, T_LPAREN))
226                     goto error;
227                   if (lex_match_id (lexer, "LARGE"))
228                     {
229                       roc.invert = false;
230                     }
231                   else if (lex_match_id (lexer, "SMALL"))
232                     {
233                       roc.invert = true;
234                     }
235                   else
236                     {
237                       lex_error (lexer, NULL);
238                       goto error;
239                     }
240                   if (! lex_force_match (lexer, T_RPAREN))
241                     goto error;
242                 }
243               else if (lex_match_id (lexer, "CI"))
244                 {
245                   if (!lex_force_match (lexer, T_LPAREN))
246                     goto error;
247                   if (! lex_force_num (lexer))
248                     goto error;
249                   roc.ci = lex_number (lexer);
250                   lex_get (lexer);
251                   if (!lex_force_match (lexer, T_RPAREN))
252                     goto error;
253                 }
254               else if (lex_match_id (lexer, "DISTRIBUTION"))
255                 {
256                   if (!lex_force_match (lexer, T_LPAREN))
257                     goto error;
258                   if (lex_match_id (lexer, "FREE"))
259                     {
260                       roc.bi_neg_exp = false;
261                     }
262                   else if (lex_match_id (lexer, "NEGEXPO"))
263                     {
264                       roc.bi_neg_exp = true;
265                     }
266                   else
267                     {
268                       lex_error (lexer, NULL);
269                       goto error;
270                     }
271                   if (!lex_force_match (lexer, T_RPAREN))
272                     goto error;
273                 }
274               else
275                 {
276                   lex_error (lexer, NULL);
277                   goto error;
278                 }
279             }
280         }
281       else
282         {
283           lex_error (lexer, NULL);
284           break;
285         }
286     }
287
288   if (! run_roc (ds, &roc))
289     goto error;
290
291   if (roc.state_var)
292     value_destroy (&roc.state_value, roc.state_var_width);
293   free (roc.vars);
294   return CMD_SUCCESS;
295
296  error:
297   if (roc.state_var)
298     value_destroy (&roc.state_value, roc.state_var_width);
299   free (roc.vars);
300   return CMD_FAILURE;
301 }
302
303
304
305
306 static void
307 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
308
309
310 static int
311 run_roc (struct dataset *ds, struct cmd_roc *roc)
312 {
313   struct dictionary *dict = dataset_dict (ds);
314   bool ok;
315   struct casereader *group;
316
317   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
318   while (casegrouper_get_next_group (grouper, &group))
319     {
320       do_roc (roc, group, dataset_dict (ds));
321     }
322   ok = casegrouper_destroy (grouper);
323   ok = proc_commit (ds) && ok;
324
325   return ok;
326 }
327
328 #if 0
329 static void
330 dump_casereader (struct casereader *reader)
331 {
332   struct ccase *c;
333   struct casereader *r = casereader_clone (reader);
334
335   for (; (c = casereader_read (r)); case_unref (c))
336     {
337       int i;
338       for (i = 0 ; i < case_get_n_values (c); ++i)
339         printf ("%g ", case_num_idx (c, i));
340       printf ("\n");
341     }
342
343   casereader_destroy (r);
344 }
345 #endif
346
347
348 /*
349    Return true iff the state variable indicates that C has positive actual state.
350
351    As a side effect, this function also accumulates the roc->{pos,neg} and
352    roc->{pos,neg}_weighted counts.
353  */
354 static bool
355 match_positives (const struct ccase *c, void *aux)
356 {
357   struct cmd_roc *roc = aux;
358   const struct variable *wv = dict_get_weight (roc->dict);
359   const double weight = wv ? case_num (c, wv) : 1.0;
360
361   const bool positive =
362   (0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
363     var_get_width (roc->state_var)));
364
365   if (positive)
366     {
367       roc->pos++;
368       roc->pos_weighted += weight;
369     }
370   else
371     {
372       roc->neg++;
373       roc->neg_weighted += weight;
374     }
375
376   return positive;
377 }
378
379
380 #define VALUE  0
381 #define N_EQ   1
382 #define N_PRED 2
383
384 /* Some intermediate state for calculating the cutpoints and the
385    standard error values */
386 struct roc_state
387 {
388   double auc;  /* Area under the curve */
389
390   double n1;  /* total weight of positives */
391   double n2;  /* total weight of negatives */
392
393   /* intermediates for standard error */
394   double q1hat;
395   double q2hat;
396
397   /* intermediates for cutpoints */
398   struct casewriter *cutpoint_wtr;
399   struct casereader *cutpoint_rdr;
400   double prev_result;
401   double min;
402   double max;
403 };
404
405 /*
406    Return a new casereader based upon CUTPOINT_RDR.
407    The number of "positive" cases are placed into
408    the position TRUE_INDEX, and the number of "negative" cases
409    into FALSE_INDEX.
410    POS_COND and RESULT determine the semantics of what is
411    "positive".
412    WEIGHT is the value of a single count.
413  */
414 static struct casereader *
415 accumulate_counts (struct casereader *input,
416                    double result, double weight,
417                    bool (*pos_cond) (double, double),
418                    int true_index, int false_index)
419 {
420   const struct caseproto *proto = casereader_get_proto (input);
421   struct casewriter *w =
422     autopaging_writer_create (proto);
423   struct ccase *cpc;
424   double prev_cp = SYSMIS;
425
426   for (; (cpc = casereader_read (input)); case_unref (cpc))
427     {
428       struct ccase *new_case;
429       const double cp = case_num_idx (cpc, ROC_CUTPOINT);
430
431       assert (cp != SYSMIS);
432
433       /* We don't want duplicates here */
434       if (cp == prev_cp)
435         continue;
436
437       new_case = case_clone (cpc);
438
439       int index = pos_cond (result, cp) ? true_index : false_index;
440       *case_num_rw_idx (new_case, index) += weight;
441
442       prev_cp = cp;
443
444       casewriter_write (w, new_case);
445     }
446   casereader_destroy (input);
447
448   return casewriter_make_reader (w);
449 }
450
451
452
453 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
454
455 /*
456   This function does 3 things:
457
458   1. Counts the number of cases which are equal to every other case in READER,
459   and those cases for which the relationship between it and every other case
460   satifies PRED (normally either > or <).  VAR is variable defining a case's value
461   for this purpose.
462
463   2. Counts the number of true and false cases in reader, and populates
464   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
465   which receive these values.  POS_COND is the condition defining true
466   and false.
467
468   3. CC is filled with the cumulative weight of all cases of READER.
469 */
470 static struct casereader *
471 process_group (const struct variable *var, struct casereader *reader,
472                bool (*pred) (double, double),
473                const struct dictionary *dict,
474                double *cc,
475                struct casereader **cutpoint_rdr,
476                bool (*pos_cond) (double, double),
477                int true_index,
478                int false_index)
479 {
480   const struct variable *w = dict_get_weight (dict);
481
482   struct casereader *r1 =
483     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
484
485   const int weight_idx  = w ? var_get_case_index (w) :
486     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
487
488   struct ccase *c1;
489
490   struct casereader *rclone = casereader_clone (r1);
491   struct casewriter *wtr;
492   struct caseproto *proto = caseproto_create ();
493
494   proto = caseproto_add_width (proto, 0);
495   proto = caseproto_add_width (proto, 0);
496   proto = caseproto_add_width (proto, 0);
497
498   wtr = autopaging_writer_create (proto);
499
500   *cc = 0;
501
502   for (; (c1 = casereader_read (r1)); case_unref (c1))
503     {
504       struct ccase *new_case = case_create (proto);
505       struct ccase *c2;
506       struct casereader *r2 = casereader_clone (rclone);
507
508       const double weight1 = case_num_idx (c1, weight_idx);
509       const double d1 = case_num (c1, var);
510       double n_eq = 0.0;
511       double n_pred = 0.0;
512
513       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
514                                          pos_cond,
515                                          true_index, false_index);
516
517       *cc += weight1;
518
519       for (; (c2 = casereader_read (r2)); case_unref (c2))
520         {
521           const double d2 = case_num (c2, var);
522           const double weight2 = case_num_idx (c2, weight_idx);
523
524           if (d1 == d2)
525             {
526               n_eq += weight2;
527               continue;
528             }
529           else  if (pred (d2, d1))
530             {
531               n_pred += weight2;
532             }
533         }
534
535       *case_num_rw_idx (new_case, VALUE) = d1;
536       *case_num_rw_idx (new_case, N_EQ) = n_eq;
537       *case_num_rw_idx (new_case, N_PRED) = n_pred;
538
539       casewriter_write (wtr, new_case);
540
541       casereader_destroy (r2);
542     }
543
544
545   casereader_destroy (r1);
546   casereader_destroy (rclone);
547
548   caseproto_unref (proto);
549
550   return casewriter_make_reader (wtr);
551 }
552
553 /* Some more indeces into case data */
554 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
555 #define N_POS_GT 2  /* number of positive cases with values greater than n */
556 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
557 #define N_NEG_LT 4  /* number of negative cases with values less than n */
558
559 static bool
560 gt (double d1, double d2)
561 {
562   return d1 > d2;
563 }
564
565
566 static bool
567 ge (double d1, double d2)
568 {
569   return d1 > d2;
570 }
571
572 static bool
573 lt (double d1, double d2)
574 {
575   return d1 < d2;
576 }
577
578
579 /*
580   Return a casereader with width 3,
581   populated with cases based upon READER.
582   The cases will have the values:
583   (N, number of cases equal to N, number of cases greater than N)
584   As a side effect, update RS->n1 with the number of positive cases.
585 */
586 static struct casereader *
587 process_positive_group (const struct variable *var, struct casereader *reader,
588                         const struct dictionary *dict,
589                         struct roc_state *rs)
590 {
591   return process_group (var, reader, gt, dict, &rs->n1,
592                         &rs->cutpoint_rdr,
593                         ge,
594                         ROC_TP, ROC_FN);
595 }
596
597 /*
598   Return a casereader with width 3,
599   populated with cases based upon READER.
600   The cases will have the values:
601   (N, number of cases equal to N, number of cases less than N)
602   As a side effect, update RS->n2 with the number of negative cases.
603 */
604 static struct casereader *
605 process_negative_group (const struct variable *var, struct casereader *reader,
606                         const struct dictionary *dict,
607                         struct roc_state *rs)
608 {
609   return process_group (var, reader, lt, dict, &rs->n2,
610                         &rs->cutpoint_rdr,
611                         lt,
612                         ROC_TN, ROC_FP);
613 }
614
615
616
617
618 static void
619 append_cutpoint (struct casewriter *writer, double cutpoint)
620 {
621   struct ccase *cc = case_create (casewriter_get_proto (writer));
622
623   *case_num_rw_idx (cc, ROC_CUTPOINT) = cutpoint;
624   *case_num_rw_idx (cc, ROC_TP) = 0;
625   *case_num_rw_idx (cc, ROC_FN) = 0;
626   *case_num_rw_idx (cc, ROC_TN) = 0;
627   *case_num_rw_idx (cc, ROC_FP) = 0;
628
629   casewriter_write (writer, cc);
630 }
631
632
633 /*
634    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
635    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
636    reader will be populated with its final number of cases.
637    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
638    value.  The other entries will be initialised to zero.
639 */
640 static void
641 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
642 {
643   int i;
644   struct casereader *r = casereader_clone (input);
645   struct ccase *c;
646
647   {
648     struct caseproto *proto = caseproto_create ();
649     struct subcase ordering;
650     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
651
652     proto = caseproto_add_width (proto, 0); /* cutpoint */
653     proto = caseproto_add_width (proto, 0); /* ROC_TP */
654     proto = caseproto_add_width (proto, 0); /* ROC_FN */
655     proto = caseproto_add_width (proto, 0); /* ROC_TN */
656     proto = caseproto_add_width (proto, 0); /* ROC_FP */
657
658     for (i = 0 ; i < roc->n_vars; ++i)
659       {
660         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
661         rs[i].prev_result = SYSMIS;
662         rs[i].max = -DBL_MAX;
663         rs[i].min = DBL_MAX;
664       }
665
666     caseproto_unref (proto);
667     subcase_destroy (&ordering);
668   }
669
670   for (; (c = casereader_read (r)) != NULL; case_unref (c))
671     {
672       for (i = 0 ; i < roc->n_vars; ++i)
673         {
674           const union value *v = case_data (c, roc->vars[i]);
675           const double result = v->f;
676
677           if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v)
678               & roc->exclude)
679             continue;
680
681           minimize (&rs[i].min, result);
682           maximize (&rs[i].max, result);
683
684           if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
685             {
686               const double mean = (result + rs[i].prev_result) / 2.0;
687               append_cutpoint (rs[i].cutpoint_wtr, mean);
688             }
689
690           rs[i].prev_result = result;
691         }
692     }
693   casereader_destroy (r);
694
695
696   /* Append the min and max cutpoints */
697   for (i = 0 ; i < roc->n_vars; ++i)
698     {
699       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
700       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
701
702       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
703     }
704 }
705
706 static void
707 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
708 {
709   int i;
710
711   struct roc_state *rs = XCALLOC (roc->n_vars,  struct roc_state);
712
713   struct casereader *negatives = NULL;
714   struct casereader *positives = NULL;
715
716   struct caseproto *n_proto = NULL;
717
718   struct subcase up_ordering;
719   struct subcase down_ordering;
720
721   struct casewriter *neg_wtr = NULL;
722
723   struct casereader *input = casereader_create_filter_missing (reader,
724                                                                roc->vars, roc->n_vars,
725                                                                roc->exclude,
726                                                                NULL,
727                                                                NULL);
728
729   input = casereader_create_filter_missing (input,
730                                             &roc->state_var, 1,
731                                             roc->exclude,
732                                             NULL,
733                                             NULL);
734
735   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
736
737   prepare_cutpoints (roc, rs, input);
738
739
740   /* Separate the positive actual state cases from the negative ones */
741   positives =
742     casereader_create_filter_func (input,
743                                    match_positives,
744                                    NULL,
745                                    roc,
746                                    neg_wtr);
747
748   n_proto = caseproto_create ();
749
750   n_proto = caseproto_add_width (n_proto, 0);
751   n_proto = caseproto_add_width (n_proto, 0);
752   n_proto = caseproto_add_width (n_proto, 0);
753   n_proto = caseproto_add_width (n_proto, 0);
754   n_proto = caseproto_add_width (n_proto, 0);
755
756   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
757   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
758
759   for (i = 0 ; i < roc->n_vars; ++i)
760     {
761       struct casewriter *w = NULL;
762       struct casereader *r = NULL;
763
764       struct ccase *c;
765
766       struct ccase *cpos;
767       struct casereader *n_neg_reader ;
768       const struct variable *var = roc->vars[i];
769
770       struct casereader *neg ;
771       struct casereader *pos = casereader_clone (positives);
772
773       struct casereader *n_pos_reader =
774         process_positive_group (var, pos, dict, &rs[i]);
775
776       if (negatives == NULL)
777         {
778           negatives = casewriter_make_reader (neg_wtr);
779         }
780
781       neg = casereader_clone (negatives);
782
783       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
784
785       /* Merge the n_pos and n_neg casereaders */
786       w = sort_create_writer (&up_ordering, n_proto);
787       for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
788         {
789           struct ccase *pos_case = case_create (n_proto);
790           struct ccase *cneg;
791           const double jpos = case_num_idx (cpos, VALUE);
792
793           while ((cneg = casereader_read (n_neg_reader)))
794             {
795               struct ccase *nc = case_create (n_proto);
796
797               const double jneg = case_num_idx (cneg, VALUE);
798
799               *case_num_rw_idx (nc, VALUE) = jneg;
800               *case_num_rw_idx (nc, N_POS_EQ) = 0;
801
802               *case_num_rw_idx (nc, N_POS_GT) = SYSMIS;
803
804               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
805               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
806
807               casewriter_write (w, nc);
808
809               case_unref (cneg);
810               if (jneg > jpos)
811                 break;
812             }
813
814           *case_num_rw_idx (pos_case, VALUE) = jpos;
815           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
816           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
817           *case_num_rw_idx (pos_case, N_NEG_EQ) = 0;
818           *case_num_rw_idx (pos_case, N_NEG_LT) = SYSMIS;
819
820           casewriter_write (w, pos_case);
821         }
822
823       casereader_destroy (n_pos_reader);
824       casereader_destroy (n_neg_reader);
825
826 /* These aren't used anymore */
827 #undef N_EQ
828 #undef N_PRED
829
830       r = casewriter_make_reader (w);
831
832       /* Propagate the N_POS_GT values from the positive cases
833          to the negative ones */
834       {
835         double prev_pos_gt = rs[i].n1;
836         w = sort_create_writer (&down_ordering, n_proto);
837
838         for (; (c = casereader_read (r)); case_unref (c))
839           {
840             double n_pos_gt = case_num_idx (c, N_POS_GT);
841             struct ccase *nc = case_clone (c);
842
843             if (n_pos_gt == SYSMIS)
844               {
845                 n_pos_gt = prev_pos_gt;
846                 *case_num_rw_idx (nc, N_POS_GT) = n_pos_gt;
847               }
848
849             casewriter_write (w, nc);
850             prev_pos_gt = n_pos_gt;
851           }
852
853         casereader_destroy (r);
854         r = casewriter_make_reader (w);
855       }
856
857       /* Propagate the N_NEG_LT values from the negative cases
858          to the positive ones */
859       {
860         double prev_neg_lt = rs[i].n2;
861         w = sort_create_writer (&up_ordering, n_proto);
862
863         for (; (c = casereader_read (r)); case_unref (c))
864           {
865             double n_neg_lt = case_num_idx (c, N_NEG_LT);
866             struct ccase *nc = case_clone (c);
867
868             if (n_neg_lt == SYSMIS)
869               {
870                 n_neg_lt = prev_neg_lt;
871                 *case_num_rw_idx (nc, N_NEG_LT) = n_neg_lt;
872               }
873
874             casewriter_write (w, nc);
875             prev_neg_lt = n_neg_lt;
876           }
877
878         casereader_destroy (r);
879         r = casewriter_make_reader (w);
880       }
881
882       {
883         struct ccase *prev_case = NULL;
884         for (; (c = casereader_read (r)); case_unref (c))
885           {
886             struct ccase *next_case = casereader_peek (r, 0);
887
888             const double j = case_num_idx (c, VALUE);
889             double n_pos_eq = case_num_idx (c, N_POS_EQ);
890             double n_pos_gt = case_num_idx (c, N_POS_GT);
891             double n_neg_eq = case_num_idx (c, N_NEG_EQ);
892             double n_neg_lt = case_num_idx (c, N_NEG_LT);
893
894             if (prev_case && j == case_num_idx (prev_case, VALUE))
895               {
896                 if (0 ==  case_num_idx (c, N_POS_EQ))
897                   {
898                     n_pos_eq = case_num_idx (prev_case, N_POS_EQ);
899                     n_pos_gt = case_num_idx (prev_case, N_POS_GT);
900                   }
901
902                 if (0 ==  case_num_idx (c, N_NEG_EQ))
903                   {
904                     n_neg_eq = case_num_idx (prev_case, N_NEG_EQ);
905                     n_neg_lt = case_num_idx (prev_case, N_NEG_LT);
906                   }
907               }
908
909             if (NULL == next_case || j != case_num_idx (next_case, VALUE))
910               {
911                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
912
913                 rs[i].q1hat +=
914                   n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
915                 rs[i].q2hat +=
916                   n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
917
918               }
919
920             case_unref (next_case);
921             case_unref (prev_case);
922             prev_case = case_clone (c);
923           }
924         casereader_destroy (r);
925         case_unref (prev_case);
926
927         rs[i].auc /=  rs[i].n1 * rs[i].n2;
928         if (roc->invert)
929           rs[i].auc = 1 - rs[i].auc;
930
931         if (roc->bi_neg_exp)
932           {
933             rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
934             rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
935           }
936         else
937           {
938             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
939             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
940           }
941       }
942     }
943
944   casereader_destroy (positives);
945   casereader_destroy (negatives);
946
947   caseproto_unref (n_proto);
948   subcase_destroy (&up_ordering);
949   subcase_destroy (&down_ordering);
950
951   output_roc (rs, roc);
952
953   for (i = 0 ; i < roc->n_vars; ++i)
954     casereader_destroy (rs[i].cutpoint_rdr);
955
956   free (rs);
957 }
958
959 static void
960 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
961 {
962   struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
963
964   struct pivot_dimension *statistics = pivot_dimension_create (
965     table, PIVOT_AXIS_COLUMN, N_("Statistics"),
966     N_("Area"), PIVOT_RC_OTHER);
967   if (roc->print_se)
968     {
969       pivot_category_create_leaves (
970         statistics->root,
971         N_("Std. Error"), PIVOT_RC_OTHER,
972         N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
973       struct pivot_category *interval = pivot_category_create_group__ (
974         statistics->root,
975         pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
976                                      roc->ci));
977       pivot_category_create_leaves (interval,
978                                     N_("Lower Bound"), PIVOT_RC_OTHER,
979                                     N_("Upper Bound"), PIVOT_RC_OTHER);
980     }
981
982   struct pivot_dimension *variables = pivot_dimension_create (
983     table, PIVOT_AXIS_ROW, N_("Variable under test"));
984   variables->root->show_label = true;
985
986   for (size_t i = 0 ; i < roc->n_vars ; ++i)
987     {
988       int var_idx = pivot_category_create_leaf (
989         variables->root, pivot_value_new_variable (roc->vars[i]));
990
991       pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
992
993       if (roc->print_se)
994         {
995           double se = (rs[i].auc * (1 - rs[i].auc)
996                        + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
997                        + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
998           se /= rs[i].n1 * rs[i].n2;
999           se = sqrt (se);
1000
1001           double ci = 1 - roc->ci / 100.0;
1002           double yy = gsl_cdf_gaussian_Qinv (ci, se);
1003
1004           double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1005                                 (12 * rs[i].n1 * rs[i].n2));
1006           double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
1007                                                         / sd_0_5));
1008           double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
1009           for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1010             pivot_table_put2 (table, i + 1, var_idx,
1011                               pivot_value_new_number (entries[i]));
1012         }
1013     }
1014
1015   pivot_table_submit (table);
1016 }
1017
1018
1019 static void
1020 show_summary (const struct cmd_roc *roc)
1021 {
1022   struct pivot_table *table = pivot_table_create (N_("Case Summary"));
1023
1024   struct pivot_dimension *statistics = pivot_dimension_create (
1025     table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
1026     N_("Unweighted"), PIVOT_RC_INTEGER,
1027     N_("Weighted"), PIVOT_RC_OTHER);
1028   statistics->root->show_label = true;
1029
1030   struct pivot_dimension *cases = pivot_dimension_create__ (
1031     table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
1032   cases->root->show_label = true;
1033   pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
1034
1035   struct entry
1036     {
1037       int stat_idx;
1038       int case_idx;
1039       double x;
1040     }
1041   entries[] = {
1042     { 0, 0, roc->pos },
1043     { 0, 1, roc->neg },
1044     { 1, 0, roc->pos_weighted },
1045     { 1, 1, roc->neg_weighted },
1046   };
1047   for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1048     {
1049       const struct entry *e = &entries[i];
1050       pivot_table_put2 (table, e->stat_idx, e->case_idx,
1051                         pivot_value_new_number (e->x));
1052     }
1053   pivot_table_submit (table);
1054 }
1055
1056 static void
1057 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1058 {
1059   struct pivot_table *table = pivot_table_create (
1060     N_("Coordinates of the Curve"));
1061
1062   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
1063                           N_("Positive if greater than or equal to"),
1064                           N_("Sensitivity"), N_("1 - Specificity"));
1065
1066   struct pivot_dimension *coordinates = pivot_dimension_create (
1067     table, PIVOT_AXIS_ROW, N_("Coordinates"));
1068   coordinates->hide_all_labels = true;
1069
1070   struct pivot_dimension *variables = pivot_dimension_create (
1071     table, PIVOT_AXIS_ROW, N_("Test variable"));
1072   variables->root->show_label = true;
1073
1074
1075   int n_coords = 0;
1076   for (size_t i = 0; i < roc->n_vars; ++i)
1077     {
1078       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1079
1080       int var_idx = pivot_category_create_leaf (
1081         variables->root, pivot_value_new_variable (roc->vars[i]));
1082
1083       struct ccase *cc;
1084       int coord_idx = 0;
1085       for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
1086         {
1087           const double se = case_num_idx (cc, ROC_TP) /
1088             (case_num_idx (cc, ROC_TP) + case_num_idx (cc, ROC_FN));
1089
1090           const double sp = case_num_idx (cc, ROC_TN) /
1091             (case_num_idx (cc, ROC_TN) + case_num_idx (cc, ROC_FP));
1092
1093           if (coord_idx >= n_coords)
1094             {
1095               assert (coord_idx == n_coords);
1096               pivot_category_create_leaf (
1097                 coordinates->root, pivot_value_new_integer (++n_coords));
1098             }
1099
1100           pivot_table_put3 (
1101             table, 0, coord_idx, var_idx,
1102             pivot_value_new_var_value (roc->vars[i],
1103                                        case_data_idx (cc, ROC_CUTPOINT)));
1104
1105           pivot_table_put3 (table, 1, coord_idx, var_idx,
1106                             pivot_value_new_number (se));
1107           pivot_table_put3 (table, 2, coord_idx, var_idx,
1108                             pivot_value_new_number (1 - sp));
1109           coord_idx++;
1110         }
1111
1112       casereader_destroy (r);
1113     }
1114
1115   pivot_table_submit (table);
1116 }
1117
1118
1119 static void
1120 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1121 {
1122   show_summary (roc);
1123
1124   if (roc->curve)
1125     {
1126       struct roc_chart *rc;
1127       size_t i;
1128
1129       rc = roc_chart_create (roc->reference);
1130       for (i = 0; i < roc->n_vars; i++)
1131         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1132                            rs[i].cutpoint_rdr);
1133       roc_chart_submit (rc);
1134     }
1135
1136   show_auc (rs, roc);
1137
1138   if (roc->print_coords)
1139     show_coords (rs, roc);
1140 }
1141