output-item: Collapse the inheritance hierarchy into a single struct.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/charts/roc-chart.h"
37 #include "output/pivot-table.h"
38
39 #include "gettext.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) msgid
42
43 struct cmd_roc
44 {
45   size_t n_vars;
46   const struct variable **vars;
47   const struct dictionary *dict;
48
49   const struct variable *state_var;
50   union value state_value;
51   size_t state_var_width;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert ; /* True iff a smaller test result variable indicates
67                    a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
76
77 int
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
79 {
80   struct cmd_roc roc ;
81   const struct dictionary *dict = dataset_dict (ds);
82
83   roc.vars = NULL;
84   roc.n_vars = 0;
85   roc.print_se = false;
86   roc.print_coords = false;
87   roc.exclude = MV_ANY;
88   roc.curve = true;
89   roc.reference = false;
90   roc.ci = 95;
91   roc.bi_neg_exp = false;
92   roc.invert = false;
93   roc.pos = roc.pos_weighted = 0;
94   roc.neg = roc.neg_weighted = 0;
95   roc.dict = dataset_dict (ds);
96   roc.state_var = NULL;
97   roc.state_var_width = -1;
98
99   lex_match (lexer, T_SLASH);
100   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
101                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
102     goto error;
103
104   if (! lex_force_match (lexer, T_BY))
105     {
106       goto error;
107     }
108
109   roc.state_var = parse_variable (lexer, dict);
110   if (! roc.state_var)
111     {
112       goto error;
113     }
114
115   if (!lex_force_match (lexer, T_LPAREN))
116     {
117       goto error;
118     }
119
120   roc.state_var_width = var_get_width (roc.state_var);
121   value_init (&roc.state_value, roc.state_var_width);
122   parse_value (lexer, &roc.state_value, roc.state_var);
123
124
125   if (!lex_force_match (lexer, T_RPAREN))
126     {
127       goto error;
128     }
129
130   while (lex_token (lexer) != T_ENDCMD)
131     {
132       lex_match (lexer, T_SLASH);
133       if (lex_match_id (lexer, "MISSING"))
134         {
135           lex_match (lexer, T_EQUALS);
136           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
137             {
138               if (lex_match_id (lexer, "INCLUDE"))
139                 {
140                   roc.exclude = MV_SYSTEM;
141                 }
142               else if (lex_match_id (lexer, "EXCLUDE"))
143                 {
144                   roc.exclude = MV_ANY;
145                 }
146               else
147                 {
148                   lex_error (lexer, NULL);
149                   goto error;
150                 }
151             }
152         }
153       else if (lex_match_id (lexer, "PLOT"))
154         {
155           lex_match (lexer, T_EQUALS);
156           if (lex_match_id (lexer, "CURVE"))
157             {
158               roc.curve = true;
159               if (lex_match (lexer, T_LPAREN))
160                 {
161                   roc.reference = true;
162                   if (! lex_force_match_id (lexer, "REFERENCE"))
163                     goto error;
164                   if (! lex_force_match (lexer, T_RPAREN))
165                     goto error;
166                 }
167             }
168           else if (lex_match_id (lexer, "NONE"))
169             {
170               roc.curve = false;
171             }
172           else
173             {
174               lex_error (lexer, NULL);
175               goto error;
176             }
177         }
178       else if (lex_match_id (lexer, "PRINT"))
179         {
180           lex_match (lexer, T_EQUALS);
181           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
182             {
183               if (lex_match_id (lexer, "SE"))
184                 {
185                   roc.print_se = true;
186                 }
187               else if (lex_match_id (lexer, "COORDINATES"))
188                 {
189                   roc.print_coords = true;
190                 }
191               else
192                 {
193                   lex_error (lexer, NULL);
194                   goto error;
195                 }
196             }
197         }
198       else if (lex_match_id (lexer, "CRITERIA"))
199         {
200           lex_match (lexer, T_EQUALS);
201           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
202             {
203               if (lex_match_id (lexer, "CUTOFF"))
204                 {
205                   if (! lex_force_match (lexer, T_LPAREN))
206                     goto error;
207                   if (lex_match_id (lexer, "INCLUDE"))
208                     {
209                       roc.exclude = MV_SYSTEM;
210                     }
211                   else if (lex_match_id (lexer, "EXCLUDE"))
212                     {
213                       roc.exclude = MV_USER | MV_SYSTEM;
214                     }
215                   else
216                     {
217                       lex_error (lexer, NULL);
218                       goto error;
219                     }
220                   if (! lex_force_match (lexer, T_RPAREN))
221                     goto error;
222                 }
223               else if (lex_match_id (lexer, "TESTPOS"))
224                 {
225                   if (! lex_force_match (lexer, T_LPAREN))
226                     goto error;
227                   if (lex_match_id (lexer, "LARGE"))
228                     {
229                       roc.invert = false;
230                     }
231                   else if (lex_match_id (lexer, "SMALL"))
232                     {
233                       roc.invert = true;
234                     }
235                   else
236                     {
237                       lex_error (lexer, NULL);
238                       goto error;
239                     }
240                   if (! lex_force_match (lexer, T_RPAREN))
241                     goto error;
242                 }
243               else if (lex_match_id (lexer, "CI"))
244                 {
245                   if (!lex_force_match (lexer, T_LPAREN))
246                     goto error;
247                   if (! lex_force_num (lexer))
248                     goto error;
249                   roc.ci = lex_number (lexer);
250                   lex_get (lexer);
251                   if (!lex_force_match (lexer, T_RPAREN))
252                     goto error;
253                 }
254               else if (lex_match_id (lexer, "DISTRIBUTION"))
255                 {
256                   if (!lex_force_match (lexer, T_LPAREN))
257                     goto error;
258                   if (lex_match_id (lexer, "FREE"))
259                     {
260                       roc.bi_neg_exp = false;
261                     }
262                   else if (lex_match_id (lexer, "NEGEXPO"))
263                     {
264                       roc.bi_neg_exp = true;
265                     }
266                   else
267                     {
268                       lex_error (lexer, NULL);
269                       goto error;
270                     }
271                   if (!lex_force_match (lexer, T_RPAREN))
272                     goto error;
273                 }
274               else
275                 {
276                   lex_error (lexer, NULL);
277                   goto error;
278                 }
279             }
280         }
281       else
282         {
283           lex_error (lexer, NULL);
284           break;
285         }
286     }
287
288   if (! run_roc (ds, &roc))
289     goto error;
290
291   if (roc.state_var)
292     value_destroy (&roc.state_value, roc.state_var_width);
293   free (roc.vars);
294   return CMD_SUCCESS;
295
296  error:
297   if (roc.state_var)
298     value_destroy (&roc.state_value, roc.state_var_width);
299   free (roc.vars);
300   return CMD_FAILURE;
301 }
302
303
304
305
306 static void
307 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
308
309
310 static int
311 run_roc (struct dataset *ds, struct cmd_roc *roc)
312 {
313   struct dictionary *dict = dataset_dict (ds);
314   bool ok;
315   struct casereader *group;
316
317   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
318   while (casegrouper_get_next_group (grouper, &group))
319     {
320       do_roc (roc, group, dataset_dict (ds));
321     }
322   ok = casegrouper_destroy (grouper);
323   ok = proc_commit (ds) && ok;
324
325   return ok;
326 }
327
328 #if 0
329 static void
330 dump_casereader (struct casereader *reader)
331 {
332   struct ccase *c;
333   struct casereader *r = casereader_clone (reader);
334
335   for (; (c = casereader_read (r)); case_unref (c))
336     {
337       int i;
338       for (i = 0 ; i < case_get_value_cnt (c); ++i)
339         {
340           printf ("%g ", case_data_idx (c, i)->f);
341         }
342       printf ("\n");
343     }
344
345   casereader_destroy (r);
346 }
347 #endif
348
349
350 /*
351    Return true iff the state variable indicates that C has positive actual state.
352
353    As a side effect, this function also accumulates the roc->{pos,neg} and
354    roc->{pos,neg}_weighted counts.
355  */
356 static bool
357 match_positives (const struct ccase *c, void *aux)
358 {
359   struct cmd_roc *roc = aux;
360   const struct variable *wv = dict_get_weight (roc->dict);
361   const double weight = wv ? case_data (c, wv)->f : 1.0;
362
363   const bool positive =
364   (0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
365     var_get_width (roc->state_var)));
366
367   if (positive)
368     {
369       roc->pos++;
370       roc->pos_weighted += weight;
371     }
372   else
373     {
374       roc->neg++;
375       roc->neg_weighted += weight;
376     }
377
378   return positive;
379 }
380
381
382 #define VALUE  0
383 #define N_EQ   1
384 #define N_PRED 2
385
386 /* Some intermediate state for calculating the cutpoints and the
387    standard error values */
388 struct roc_state
389 {
390   double auc;  /* Area under the curve */
391
392   double n1;  /* total weight of positives */
393   double n2;  /* total weight of negatives */
394
395   /* intermediates for standard error */
396   double q1hat;
397   double q2hat;
398
399   /* intermediates for cutpoints */
400   struct casewriter *cutpoint_wtr;
401   struct casereader *cutpoint_rdr;
402   double prev_result;
403   double min;
404   double max;
405 };
406
407 /*
408    Return a new casereader based upon CUTPOINT_RDR.
409    The number of "positive" cases are placed into
410    the position TRUE_INDEX, and the number of "negative" cases
411    into FALSE_INDEX.
412    POS_COND and RESULT determine the semantics of what is
413    "positive".
414    WEIGHT is the value of a single count.
415  */
416 static struct casereader *
417 accumulate_counts (struct casereader *input,
418                    double result, double weight,
419                    bool (*pos_cond) (double, double),
420                    int true_index, int false_index)
421 {
422   const struct caseproto *proto = casereader_get_proto (input);
423   struct casewriter *w =
424     autopaging_writer_create (proto);
425   struct ccase *cpc;
426   double prev_cp = SYSMIS;
427
428   for (; (cpc = casereader_read (input)); case_unref (cpc))
429     {
430       struct ccase *new_case;
431       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
432
433       assert (cp != SYSMIS);
434
435       /* We don't want duplicates here */
436       if (cp == prev_cp)
437         continue;
438
439       new_case = case_clone (cpc);
440
441       if (pos_cond (result, cp))
442         case_data_rw_idx (new_case, true_index)->f += weight;
443       else
444         case_data_rw_idx (new_case, false_index)->f += weight;
445
446       prev_cp = cp;
447
448       casewriter_write (w, new_case);
449     }
450   casereader_destroy (input);
451
452   return casewriter_make_reader (w);
453 }
454
455
456
457 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
458
459 /*
460   This function does 3 things:
461
462   1. Counts the number of cases which are equal to every other case in READER,
463   and those cases for which the relationship between it and every other case
464   satifies PRED (normally either > or <).  VAR is variable defining a case's value
465   for this purpose.
466
467   2. Counts the number of true and false cases in reader, and populates
468   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
469   which receive these values.  POS_COND is the condition defining true
470   and false.
471
472   3. CC is filled with the cumulative weight of all cases of READER.
473 */
474 static struct casereader *
475 process_group (const struct variable *var, struct casereader *reader,
476                bool (*pred) (double, double),
477                const struct dictionary *dict,
478                double *cc,
479                struct casereader **cutpoint_rdr,
480                bool (*pos_cond) (double, double),
481                int true_index,
482                int false_index)
483 {
484   const struct variable *w = dict_get_weight (dict);
485
486   struct casereader *r1 =
487     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
488
489   const int weight_idx  = w ? var_get_case_index (w) :
490     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
491
492   struct ccase *c1;
493
494   struct casereader *rclone = casereader_clone (r1);
495   struct casewriter *wtr;
496   struct caseproto *proto = caseproto_create ();
497
498   proto = caseproto_add_width (proto, 0);
499   proto = caseproto_add_width (proto, 0);
500   proto = caseproto_add_width (proto, 0);
501
502   wtr = autopaging_writer_create (proto);
503
504   *cc = 0;
505
506   for (; (c1 = casereader_read (r1)); case_unref (c1))
507     {
508       struct ccase *new_case = case_create (proto);
509       struct ccase *c2;
510       struct casereader *r2 = casereader_clone (rclone);
511
512       const double weight1 = case_data_idx (c1, weight_idx)->f;
513       const double d1 = case_data (c1, var)->f;
514       double n_eq = 0.0;
515       double n_pred = 0.0;
516
517       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
518                                          pos_cond,
519                                          true_index, false_index);
520
521       *cc += weight1;
522
523       for (; (c2 = casereader_read (r2)); case_unref (c2))
524         {
525           const double d2 = case_data (c2, var)->f;
526           const double weight2 = case_data_idx (c2, weight_idx)->f;
527
528           if (d1 == d2)
529             {
530               n_eq += weight2;
531               continue;
532             }
533           else  if (pred (d2, d1))
534             {
535               n_pred += weight2;
536             }
537         }
538
539       case_data_rw_idx (new_case, VALUE)->f = d1;
540       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
541       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
542
543       casewriter_write (wtr, new_case);
544
545       casereader_destroy (r2);
546     }
547
548
549   casereader_destroy (r1);
550   casereader_destroy (rclone);
551
552   caseproto_unref (proto);
553
554   return casewriter_make_reader (wtr);
555 }
556
557 /* Some more indeces into case data */
558 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
559 #define N_POS_GT 2  /* number of positive cases with values greater than n */
560 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
561 #define N_NEG_LT 4  /* number of negative cases with values less than n */
562
563 static bool
564 gt (double d1, double d2)
565 {
566   return d1 > d2;
567 }
568
569
570 static bool
571 ge (double d1, double d2)
572 {
573   return d1 > d2;
574 }
575
576 static bool
577 lt (double d1, double d2)
578 {
579   return d1 < d2;
580 }
581
582
583 /*
584   Return a casereader with width 3,
585   populated with cases based upon READER.
586   The cases will have the values:
587   (N, number of cases equal to N, number of cases greater than N)
588   As a side effect, update RS->n1 with the number of positive cases.
589 */
590 static struct casereader *
591 process_positive_group (const struct variable *var, struct casereader *reader,
592                         const struct dictionary *dict,
593                         struct roc_state *rs)
594 {
595   return process_group (var, reader, gt, dict, &rs->n1,
596                         &rs->cutpoint_rdr,
597                         ge,
598                         ROC_TP, ROC_FN);
599 }
600
601 /*
602   Return a casereader with width 3,
603   populated with cases based upon READER.
604   The cases will have the values:
605   (N, number of cases equal to N, number of cases less than N)
606   As a side effect, update RS->n2 with the number of negative cases.
607 */
608 static struct casereader *
609 process_negative_group (const struct variable *var, struct casereader *reader,
610                         const struct dictionary *dict,
611                         struct roc_state *rs)
612 {
613   return process_group (var, reader, lt, dict, &rs->n2,
614                         &rs->cutpoint_rdr,
615                         lt,
616                         ROC_TN, ROC_FP);
617 }
618
619
620
621
622 static void
623 append_cutpoint (struct casewriter *writer, double cutpoint)
624 {
625   struct ccase *cc = case_create (casewriter_get_proto (writer));
626
627   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
628   case_data_rw_idx (cc, ROC_TP)->f = 0;
629   case_data_rw_idx (cc, ROC_FN)->f = 0;
630   case_data_rw_idx (cc, ROC_TN)->f = 0;
631   case_data_rw_idx (cc, ROC_FP)->f = 0;
632
633   casewriter_write (writer, cc);
634 }
635
636
637 /*
638    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
639    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
640    reader will be populated with its final number of cases.
641    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
642    value.  The other entries will be initialised to zero.
643 */
644 static void
645 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
646 {
647   int i;
648   struct casereader *r = casereader_clone (input);
649   struct ccase *c;
650
651   {
652     struct caseproto *proto = caseproto_create ();
653     struct subcase ordering;
654     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
655
656     proto = caseproto_add_width (proto, 0); /* cutpoint */
657     proto = caseproto_add_width (proto, 0); /* ROC_TP */
658     proto = caseproto_add_width (proto, 0); /* ROC_FN */
659     proto = caseproto_add_width (proto, 0); /* ROC_TN */
660     proto = caseproto_add_width (proto, 0); /* ROC_FP */
661
662     for (i = 0 ; i < roc->n_vars; ++i)
663       {
664         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
665         rs[i].prev_result = SYSMIS;
666         rs[i].max = -DBL_MAX;
667         rs[i].min = DBL_MAX;
668       }
669
670     caseproto_unref (proto);
671     subcase_destroy (&ordering);
672   }
673
674   for (; (c = casereader_read (r)) != NULL; case_unref (c))
675     {
676       for (i = 0 ; i < roc->n_vars; ++i)
677         {
678           const union value *v = case_data (c, roc->vars[i]);
679           const double result = v->f;
680
681           if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
682             continue;
683
684           minimize (&rs[i].min, result);
685           maximize (&rs[i].max, result);
686
687           if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
688             {
689               const double mean = (result + rs[i].prev_result) / 2.0;
690               append_cutpoint (rs[i].cutpoint_wtr, mean);
691             }
692
693           rs[i].prev_result = result;
694         }
695     }
696   casereader_destroy (r);
697
698
699   /* Append the min and max cutpoints */
700   for (i = 0 ; i < roc->n_vars; ++i)
701     {
702       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
703       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
704
705       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
706     }
707 }
708
709 static void
710 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
711 {
712   int i;
713
714   struct roc_state *rs = XCALLOC (roc->n_vars,  struct roc_state);
715
716   struct casereader *negatives = NULL;
717   struct casereader *positives = NULL;
718
719   struct caseproto *n_proto = NULL;
720
721   struct subcase up_ordering;
722   struct subcase down_ordering;
723
724   struct casewriter *neg_wtr = NULL;
725
726   struct casereader *input = casereader_create_filter_missing (reader,
727                                                                roc->vars, roc->n_vars,
728                                                                roc->exclude,
729                                                                NULL,
730                                                                NULL);
731
732   input = casereader_create_filter_missing (input,
733                                             &roc->state_var, 1,
734                                             roc->exclude,
735                                             NULL,
736                                             NULL);
737
738   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
739
740   prepare_cutpoints (roc, rs, input);
741
742
743   /* Separate the positive actual state cases from the negative ones */
744   positives =
745     casereader_create_filter_func (input,
746                                    match_positives,
747                                    NULL,
748                                    roc,
749                                    neg_wtr);
750
751   n_proto = caseproto_create ();
752
753   n_proto = caseproto_add_width (n_proto, 0);
754   n_proto = caseproto_add_width (n_proto, 0);
755   n_proto = caseproto_add_width (n_proto, 0);
756   n_proto = caseproto_add_width (n_proto, 0);
757   n_proto = caseproto_add_width (n_proto, 0);
758
759   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
760   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
761
762   for (i = 0 ; i < roc->n_vars; ++i)
763     {
764       struct casewriter *w = NULL;
765       struct casereader *r = NULL;
766
767       struct ccase *c;
768
769       struct ccase *cpos;
770       struct casereader *n_neg_reader ;
771       const struct variable *var = roc->vars[i];
772
773       struct casereader *neg ;
774       struct casereader *pos = casereader_clone (positives);
775
776       struct casereader *n_pos_reader =
777         process_positive_group (var, pos, dict, &rs[i]);
778
779       if (negatives == NULL)
780         {
781           negatives = casewriter_make_reader (neg_wtr);
782         }
783
784       neg = casereader_clone (negatives);
785
786       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
787
788       /* Merge the n_pos and n_neg casereaders */
789       w = sort_create_writer (&up_ordering, n_proto);
790       for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
791         {
792           struct ccase *pos_case = case_create (n_proto);
793           struct ccase *cneg;
794           const double jpos = case_data_idx (cpos, VALUE)->f;
795
796           while ((cneg = casereader_read (n_neg_reader)))
797             {
798               struct ccase *nc = case_create (n_proto);
799
800               const double jneg = case_data_idx (cneg, VALUE)->f;
801
802               case_data_rw_idx (nc, VALUE)->f = jneg;
803               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
804
805               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
806
807               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
808               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
809
810               casewriter_write (w, nc);
811
812               case_unref (cneg);
813               if (jneg > jpos)
814                 break;
815             }
816
817           case_data_rw_idx (pos_case, VALUE)->f = jpos;
818           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
819           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
820           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
821           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
822
823           casewriter_write (w, pos_case);
824         }
825
826       casereader_destroy (n_pos_reader);
827       casereader_destroy (n_neg_reader);
828
829 /* These aren't used anymore */
830 #undef N_EQ
831 #undef N_PRED
832
833       r = casewriter_make_reader (w);
834
835       /* Propagate the N_POS_GT values from the positive cases
836          to the negative ones */
837       {
838         double prev_pos_gt = rs[i].n1;
839         w = sort_create_writer (&down_ordering, n_proto);
840
841         for (; (c = casereader_read (r)); case_unref (c))
842           {
843             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
844             struct ccase *nc = case_clone (c);
845
846             if (n_pos_gt == SYSMIS)
847               {
848                 n_pos_gt = prev_pos_gt;
849                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
850               }
851
852             casewriter_write (w, nc);
853             prev_pos_gt = n_pos_gt;
854           }
855
856         casereader_destroy (r);
857         r = casewriter_make_reader (w);
858       }
859
860       /* Propagate the N_NEG_LT values from the negative cases
861          to the positive ones */
862       {
863         double prev_neg_lt = rs[i].n2;
864         w = sort_create_writer (&up_ordering, n_proto);
865
866         for (; (c = casereader_read (r)); case_unref (c))
867           {
868             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
869             struct ccase *nc = case_clone (c);
870
871             if (n_neg_lt == SYSMIS)
872               {
873                 n_neg_lt = prev_neg_lt;
874                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
875               }
876
877             casewriter_write (w, nc);
878             prev_neg_lt = n_neg_lt;
879           }
880
881         casereader_destroy (r);
882         r = casewriter_make_reader (w);
883       }
884
885       {
886         struct ccase *prev_case = NULL;
887         for (; (c = casereader_read (r)); case_unref (c))
888           {
889             struct ccase *next_case = casereader_peek (r, 0);
890
891             const double j = case_data_idx (c, VALUE)->f;
892             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
893             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
894             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
895             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
896
897             if (prev_case && j == case_data_idx (prev_case, VALUE)->f)
898               {
899                 if (0 ==  case_data_idx (c, N_POS_EQ)->f)
900                   {
901                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
902                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
903                   }
904
905                 if (0 ==  case_data_idx (c, N_NEG_EQ)->f)
906                   {
907                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
908                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
909                   }
910               }
911
912             if (NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
913               {
914                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
915
916                 rs[i].q1hat +=
917                   n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
918                 rs[i].q2hat +=
919                   n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
920
921               }
922
923             case_unref (next_case);
924             case_unref (prev_case);
925             prev_case = case_clone (c);
926           }
927         casereader_destroy (r);
928         case_unref (prev_case);
929
930         rs[i].auc /=  rs[i].n1 * rs[i].n2;
931         if (roc->invert)
932           rs[i].auc = 1 - rs[i].auc;
933
934         if (roc->bi_neg_exp)
935           {
936             rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
937             rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
938           }
939         else
940           {
941             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
942             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
943           }
944       }
945     }
946
947   casereader_destroy (positives);
948   casereader_destroy (negatives);
949
950   caseproto_unref (n_proto);
951   subcase_destroy (&up_ordering);
952   subcase_destroy (&down_ordering);
953
954   output_roc (rs, roc);
955
956   for (i = 0 ; i < roc->n_vars; ++i)
957     casereader_destroy (rs[i].cutpoint_rdr);
958
959   free (rs);
960 }
961
962 static void
963 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
964 {
965   struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
966
967   struct pivot_dimension *statistics = pivot_dimension_create (
968     table, PIVOT_AXIS_COLUMN, N_("Statistics"),
969     N_("Area"), PIVOT_RC_OTHER);
970   if (roc->print_se)
971     {
972       pivot_category_create_leaves (
973         statistics->root,
974         N_("Std. Error"), PIVOT_RC_OTHER,
975         N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
976       struct pivot_category *interval = pivot_category_create_group__ (
977         statistics->root,
978         pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
979                                      roc->ci));
980       pivot_category_create_leaves (interval,
981                                     N_("Lower Bound"), PIVOT_RC_OTHER,
982                                     N_("Upper Bound"), PIVOT_RC_OTHER);
983     }
984
985   struct pivot_dimension *variables = pivot_dimension_create (
986     table, PIVOT_AXIS_ROW, N_("Variable under test"));
987   variables->root->show_label = true;
988
989   for (size_t i = 0 ; i < roc->n_vars ; ++i)
990     {
991       int var_idx = pivot_category_create_leaf (
992         variables->root, pivot_value_new_variable (roc->vars[i]));
993
994       pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
995
996       if (roc->print_se)
997         {
998           double se = (rs[i].auc * (1 - rs[i].auc)
999                        + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
1000                        + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
1001           se /= rs[i].n1 * rs[i].n2;
1002           se = sqrt (se);
1003
1004           double ci = 1 - roc->ci / 100.0;
1005           double yy = gsl_cdf_gaussian_Qinv (ci, se);
1006
1007           double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1008                                 (12 * rs[i].n1 * rs[i].n2));
1009           double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
1010                                                         / sd_0_5));
1011           double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
1012           for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1013             pivot_table_put2 (table, i + 1, var_idx,
1014                               pivot_value_new_number (entries[i]));
1015         }
1016     }
1017
1018   pivot_table_submit (table);
1019 }
1020
1021
1022 static void
1023 show_summary (const struct cmd_roc *roc)
1024 {
1025   struct pivot_table *table = pivot_table_create (N_("Case Summary"));
1026
1027   struct pivot_dimension *statistics = pivot_dimension_create (
1028     table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
1029     N_("Unweighted"), PIVOT_RC_INTEGER,
1030     N_("Weighted"), PIVOT_RC_OTHER);
1031   statistics->root->show_label = true;
1032
1033   struct pivot_dimension *cases = pivot_dimension_create__ (
1034     table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
1035   cases->root->show_label = true;
1036   pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
1037
1038   struct entry
1039     {
1040       int stat_idx;
1041       int case_idx;
1042       double x;
1043     }
1044   entries[] = {
1045     { 0, 0, roc->pos },
1046     { 0, 1, roc->neg },
1047     { 1, 0, roc->pos_weighted },
1048     { 1, 1, roc->neg_weighted },
1049   };
1050   for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1051     {
1052       const struct entry *e = &entries[i];
1053       pivot_table_put2 (table, e->stat_idx, e->case_idx,
1054                         pivot_value_new_number (e->x));
1055     }
1056   pivot_table_submit (table);
1057 }
1058
1059 static void
1060 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1061 {
1062   struct pivot_table *table = pivot_table_create (
1063     N_("Coordinates of the Curve"));
1064
1065   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
1066                           N_("Positive if greater than or equal to"),
1067                           N_("Sensitivity"), N_("1 - Specificity"));
1068
1069   struct pivot_dimension *coordinates = pivot_dimension_create (
1070     table, PIVOT_AXIS_ROW, N_("Coordinates"));
1071   coordinates->hide_all_labels = true;
1072
1073   struct pivot_dimension *variables = pivot_dimension_create (
1074     table, PIVOT_AXIS_ROW, N_("Test variable"));
1075   variables->root->show_label = true;
1076
1077
1078   int n_coords = 0;
1079   for (size_t i = 0; i < roc->n_vars; ++i)
1080     {
1081       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1082
1083       int var_idx = pivot_category_create_leaf (
1084         variables->root, pivot_value_new_variable (roc->vars[i]));
1085
1086       struct ccase *cc;
1087       int coord_idx = 0;
1088       for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
1089         {
1090           const double se = case_data_idx (cc, ROC_TP)->f /
1091             (case_data_idx (cc, ROC_TP)->f + case_data_idx (cc, ROC_FN)->f);
1092
1093           const double sp = case_data_idx (cc, ROC_TN)->f /
1094             (case_data_idx (cc, ROC_TN)->f + case_data_idx (cc, ROC_FP)->f);
1095
1096           if (coord_idx >= n_coords)
1097             {
1098               assert (coord_idx == n_coords);
1099               pivot_category_create_leaf (
1100                 coordinates->root, pivot_value_new_integer (++n_coords));
1101             }
1102
1103           pivot_table_put3 (
1104             table, 0, coord_idx, var_idx,
1105             pivot_value_new_var_value (roc->vars[i],
1106                                        case_data_idx (cc, ROC_CUTPOINT)));
1107
1108           pivot_table_put3 (table, 1, coord_idx, var_idx,
1109                             pivot_value_new_number (se));
1110           pivot_table_put3 (table, 2, coord_idx, var_idx,
1111                             pivot_value_new_number (1 - sp));
1112           coord_idx++;
1113         }
1114
1115       casereader_destroy (r);
1116     }
1117
1118   pivot_table_submit (table);
1119 }
1120
1121
1122 static void
1123 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1124 {
1125   show_summary (roc);
1126
1127   if (roc->curve)
1128     {
1129       struct roc_chart *rc;
1130       size_t i;
1131
1132       rc = roc_chart_create (roc->reference);
1133       for (i = 0; i < roc->n_vars; i++)
1134         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1135                            rs[i].cutpoint_rdr);
1136       roc_chart_submit (rc);
1137     }
1138
1139   show_auc (rs, roc);
1140
1141   if (roc->print_coords)
1142     show_coords (rs, roc);
1143 }
1144