Move all command implementations into a single 'commands' directory.
[pspp] / src / language / commands / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/commands/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/charts/roc-chart.h"
37 #include "output/pivot-table.h"
38
39 #include "gettext.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) msgid
42
43 struct cmd_roc
44 {
45   size_t n_vars;
46   const struct variable **vars;
47   const struct dictionary *dict;
48
49   const struct variable *state_var;
50   union value state_value;
51   size_t state_var_width;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert; /* True iff a smaller test result variable indicates
67                   a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *, struct cmd_roc *);
76 static void do_roc (struct cmd_roc *, struct casereader *, struct dictionary *);
77
78
79 int
80 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 {
82   const struct dictionary *dict = dataset_dict (ds);
83
84   struct cmd_roc roc = {
85     .exclude = MV_ANY,
86     .curve = true,
87     .ci = 95,
88     .dict = dict,
89     .state_var_width = -1,
90   };
91
92   lex_match (lexer, T_SLASH);
93   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
94                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
95     goto error;
96
97   if (!lex_force_match (lexer, T_BY))
98     goto error;
99
100   roc.state_var = parse_variable (lexer, dict);
101   if (!roc.state_var)
102     goto error;
103
104   if (!lex_force_match (lexer, T_LPAREN))
105     goto error;
106
107   roc.state_var_width = var_get_width (roc.state_var);
108   value_init (&roc.state_value, roc.state_var_width);
109   if (!parse_value (lexer, &roc.state_value, roc.state_var)
110       || !lex_force_match (lexer, T_RPAREN))
111     goto error;
112
113   while (lex_token (lexer) != T_ENDCMD)
114     {
115       lex_match (lexer, T_SLASH);
116       if (lex_match_id (lexer, "MISSING"))
117         {
118           lex_match (lexer, T_EQUALS);
119           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
120             {
121               if (lex_match_id (lexer, "INCLUDE"))
122                 roc.exclude = MV_SYSTEM;
123               else if (lex_match_id (lexer, "EXCLUDE"))
124                 roc.exclude = MV_ANY;
125               else
126                 {
127                   lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
128                   goto error;
129                 }
130             }
131         }
132       else if (lex_match_id (lexer, "PLOT"))
133         {
134           lex_match (lexer, T_EQUALS);
135           if (lex_match_id (lexer, "CURVE"))
136             {
137               roc.curve = true;
138               if (lex_match (lexer, T_LPAREN))
139                 {
140                   roc.reference = true;
141                   if (!lex_force_match_id (lexer, "REFERENCE")
142                       || !lex_force_match (lexer, T_RPAREN))
143                     goto error;
144                 }
145             }
146           else if (lex_match_id (lexer, "NONE"))
147             roc.curve = false;
148           else
149             {
150               lex_error_expecting (lexer, "CURVE", "NONE");
151               goto error;
152             }
153         }
154       else if (lex_match_id (lexer, "PRINT"))
155         {
156           lex_match (lexer, T_EQUALS);
157           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
158             {
159               if (lex_match_id (lexer, "SE"))
160                 roc.print_se = true;
161               else if (lex_match_id (lexer, "COORDINATES"))
162                 roc.print_coords = true;
163               else
164                 {
165                   lex_error_expecting (lexer, "SE", "COORDINATES");
166                   goto error;
167                 }
168             }
169         }
170       else if (lex_match_id (lexer, "CRITERIA"))
171         {
172           lex_match (lexer, T_EQUALS);
173           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
174             {
175               if (lex_match_id (lexer, "CUTOFF"))
176                 {
177                   if (!lex_force_match (lexer, T_LPAREN))
178                     goto error;
179                   if (lex_match_id (lexer, "INCLUDE"))
180                     roc.exclude = MV_SYSTEM;
181                   else if (lex_match_id (lexer, "EXCLUDE"))
182                     roc.exclude = MV_USER | MV_SYSTEM;
183                   else
184                     {
185                       lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
186                       goto error;
187                     }
188                   if (!lex_force_match (lexer, T_RPAREN))
189                     goto error;
190                 }
191               else if (lex_match_id (lexer, "TESTPOS"))
192                 {
193                   if (!lex_force_match (lexer, T_LPAREN))
194                     goto error;
195                   if (lex_match_id (lexer, "LARGE"))
196                     roc.invert = false;
197                   else if (lex_match_id (lexer, "SMALL"))
198                     roc.invert = true;
199                   else
200                     {
201                       lex_error_expecting (lexer, "LARGE", "SMALL");
202                       goto error;
203                     }
204                   if (!lex_force_match (lexer, T_RPAREN))
205                     goto error;
206                 }
207               else if (lex_match_id (lexer, "CI"))
208                 {
209                   if (!lex_force_match (lexer, T_LPAREN))
210                     goto error;
211                   if (!lex_force_num (lexer))
212                     goto error;
213                   roc.ci = lex_number (lexer);
214                   lex_get (lexer);
215                   if (!lex_force_match (lexer, T_RPAREN))
216                     goto error;
217                 }
218               else if (lex_match_id (lexer, "DISTRIBUTION"))
219                 {
220                   if (!lex_force_match (lexer, T_LPAREN))
221                     goto error;
222                   if (lex_match_id (lexer, "FREE"))
223                     roc.bi_neg_exp = false;
224                   else if (lex_match_id (lexer, "NEGEXPO"))
225                     roc.bi_neg_exp = true;
226                   else
227                     {
228                       lex_error_expecting (lexer, "FREE", "NEGEXPO");
229                       goto error;
230                     }
231                   if (!lex_force_match (lexer, T_RPAREN))
232                     goto error;
233                 }
234               else
235                 {
236                   lex_error_expecting (lexer, "CUTOFF", "TESTPOS", "CI",
237                                        "DISTRIBUTION");
238                   goto error;
239                 }
240             }
241         }
242       else
243         {
244           lex_error_expecting (lexer, "MISSING", "PLOT", "PRINT", "CRITERIA");
245           goto error;
246         }
247     }
248
249   if (!run_roc (ds, &roc))
250     goto error;
251
252   if (roc.state_var)
253     value_destroy (&roc.state_value, roc.state_var_width);
254   free (roc.vars);
255   return CMD_SUCCESS;
256
257  error:
258   if (roc.state_var)
259     value_destroy (&roc.state_value, roc.state_var_width);
260   free (roc.vars);
261   return CMD_FAILURE;
262 }
263
264 static int
265 run_roc (struct dataset *ds, struct cmd_roc *roc)
266 {
267   struct dictionary *dict = dataset_dict (ds);
268   struct casereader *group;
269
270   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
271   while (casegrouper_get_next_group (grouper, &group))
272     do_roc (roc, group, dataset_dict (ds));
273
274   bool ok = casegrouper_destroy (grouper);
275   ok = proc_commit (ds) && ok;
276   return ok;
277 }
278
279 #if 0
280 static void
281 dump_casereader (struct casereader *reader)
282 {
283   struct ccase *c;
284   struct casereader *r = casereader_clone (reader);
285
286   for (; (c = casereader_read (r)); case_unref (c))
287     {
288       for (size_t i = 0; i < case_get_n_values (c); ++i)
289         printf ("%g ", case_num_idx (c, i));
290       printf ("\n");
291     }
292
293   casereader_destroy (r);
294 }
295 #endif
296
297
298 /*
299    Return true iff the state variable indicates that C has positive actual state.
300
301    As a side effect, this function also accumulates the roc->{pos,neg} and
302    roc->{pos,neg}_weighted counts.
303  */
304 static bool
305 match_positives (const struct ccase *c, void *aux)
306 {
307   struct cmd_roc *roc = aux;
308   const struct variable *wv = dict_get_weight (roc->dict);
309   const double weight = wv ? case_num (c, wv) : 1.0;
310
311   const bool positive =
312   (0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
313     var_get_width (roc->state_var)));
314
315   if (positive)
316     {
317       roc->pos++;
318       roc->pos_weighted += weight;
319     }
320   else
321     {
322       roc->neg++;
323       roc->neg_weighted += weight;
324     }
325
326   return positive;
327 }
328
329
330 #define VALUE  0
331 #define N_EQ   1
332 #define N_PRED 2
333
334 /* Some intermediate state for calculating the cutpoints and the
335    standard error values */
336 struct roc_state
337 {
338   double auc;  /* Area under the curve */
339
340   double n1;  /* total weight of positives */
341   double n2;  /* total weight of negatives */
342
343   /* intermediates for standard error */
344   double q1hat;
345   double q2hat;
346
347   /* intermediates for cutpoints */
348   struct casewriter *cutpoint_wtr;
349   struct casereader *cutpoint_rdr;
350   double prev_result;
351   double min;
352   double max;
353 };
354
355 /*
356    Return a new casereader based upon CUTPOINT_RDR.
357    The number of "positive" cases are placed into
358    the position TRUE_INDEX, and the number of "negative" cases
359    into FALSE_INDEX.
360    POS_COND and RESULT determine the semantics of what is
361    "positive".
362    WEIGHT is the value of a single count.
363  */
364 static struct casereader *
365 accumulate_counts (struct casereader *input,
366                    double result, double weight,
367                    bool (*pos_cond) (double, double),
368                    int true_index, int false_index)
369 {
370   const struct caseproto *proto = casereader_get_proto (input);
371   struct casewriter *w =
372     autopaging_writer_create (proto);
373   struct ccase *cpc;
374   double prev_cp = SYSMIS;
375
376   for (; (cpc = casereader_read (input)); case_unref (cpc))
377     {
378       struct ccase *new_case;
379       const double cp = case_num_idx (cpc, ROC_CUTPOINT);
380
381       assert (cp != SYSMIS);
382
383       /* We don't want duplicates here */
384       if (cp == prev_cp)
385         continue;
386
387       new_case = case_clone (cpc);
388
389       int index = pos_cond (result, cp) ? true_index : false_index;
390       *case_num_rw_idx (new_case, index) += weight;
391
392       prev_cp = cp;
393
394       casewriter_write (w, new_case);
395     }
396   casereader_destroy (input);
397
398   return casewriter_make_reader (w);
399 }
400
401
402
403 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
404
405 /*
406   This function does 3 things:
407
408   1. Counts the number of cases which are equal to every other case in READER,
409   and those cases for which the relationship between it and every other case
410   satifies PRED (normally either > or <).  VAR is variable defining a case's value
411   for this purpose.
412
413   2. Counts the number of true and false cases in reader, and populates
414   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
415   which receive these values.  POS_COND is the condition defining true
416   and false.
417
418   3. CC is filled with the cumulative weight of all cases of READER.
419 */
420 static struct casereader *
421 process_group (const struct variable *var, struct casereader *reader,
422                bool (*pred) (double, double),
423                const struct dictionary *dict,
424                double *cc,
425                struct casereader **cutpoint_rdr,
426                bool (*pos_cond) (double, double),
427                int true_index,
428                int false_index)
429 {
430   const struct variable *w = dict_get_weight (dict);
431
432   struct casereader *r1 =
433     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
434
435   const int weight_idx  = w ? var_get_case_index (w) :
436     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
437
438   struct ccase *c1;
439
440   struct casereader *rclone = casereader_clone (r1);
441   struct casewriter *wtr;
442   struct caseproto *proto = caseproto_create ();
443
444   proto = caseproto_add_width (proto, 0);
445   proto = caseproto_add_width (proto, 0);
446   proto = caseproto_add_width (proto, 0);
447
448   wtr = autopaging_writer_create (proto);
449
450   *cc = 0;
451
452   for (; (c1 = casereader_read (r1)); case_unref (c1))
453     {
454       struct ccase *new_case = case_create (proto);
455       struct ccase *c2;
456       struct casereader *r2 = casereader_clone (rclone);
457
458       const double weight1 = case_num_idx (c1, weight_idx);
459       const double d1 = case_num (c1, var);
460       double n_eq = 0.0;
461       double n_pred = 0.0;
462
463       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
464                                          pos_cond,
465                                          true_index, false_index);
466
467       *cc += weight1;
468
469       for (; (c2 = casereader_read (r2)); case_unref (c2))
470         {
471           const double d2 = case_num (c2, var);
472           const double weight2 = case_num_idx (c2, weight_idx);
473
474           if (d1 == d2)
475             {
476               n_eq += weight2;
477               continue;
478             }
479           else  if (pred (d2, d1))
480             {
481               n_pred += weight2;
482             }
483         }
484
485       *case_num_rw_idx (new_case, VALUE) = d1;
486       *case_num_rw_idx (new_case, N_EQ) = n_eq;
487       *case_num_rw_idx (new_case, N_PRED) = n_pred;
488
489       casewriter_write (wtr, new_case);
490
491       casereader_destroy (r2);
492     }
493
494   casereader_destroy (r1);
495   casereader_destroy (rclone);
496
497   caseproto_unref (proto);
498
499   return casewriter_make_reader (wtr);
500 }
501
502 /* Some more indeces into case data */
503 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
504 #define N_POS_GT 2  /* number of positive cases with values greater than n */
505 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
506 #define N_NEG_LT 4  /* number of negative cases with values less than n */
507
508 static bool
509 gt (double d1, double d2)
510 {
511   return d1 > d2;
512 }
513
514
515 static bool
516 ge (double d1, double d2)
517 {
518   return d1 > d2;
519 }
520
521 static bool
522 lt (double d1, double d2)
523 {
524   return d1 < d2;
525 }
526
527
528 /*
529   Return a casereader with width 3,
530   populated with cases based upon READER.
531   The cases will have the values:
532   (N, number of cases equal to N, number of cases greater than N)
533   As a side effect, update RS->n1 with the number of positive cases.
534 */
535 static struct casereader *
536 process_positive_group (const struct variable *var, struct casereader *reader,
537                         const struct dictionary *dict,
538                         struct roc_state *rs)
539 {
540   return process_group (var, reader, gt, dict, &rs->n1,
541                         &rs->cutpoint_rdr,
542                         ge,
543                         ROC_TP, ROC_FN);
544 }
545
546 /*
547   Return a casereader with width 3,
548   populated with cases based upon READER.
549   The cases will have the values:
550   (N, number of cases equal to N, number of cases less than N)
551   As a side effect, update RS->n2 with the number of negative cases.
552 */
553 static struct casereader *
554 process_negative_group (const struct variable *var, struct casereader *reader,
555                         const struct dictionary *dict,
556                         struct roc_state *rs)
557 {
558   return process_group (var, reader, lt, dict, &rs->n2,
559                         &rs->cutpoint_rdr,
560                         lt,
561                         ROC_TN, ROC_FP);
562 }
563
564
565
566
567 static void
568 append_cutpoint (struct casewriter *writer, double cutpoint)
569 {
570   struct ccase *cc = case_create (casewriter_get_proto (writer));
571
572   *case_num_rw_idx (cc, ROC_CUTPOINT) = cutpoint;
573   *case_num_rw_idx (cc, ROC_TP) = 0;
574   *case_num_rw_idx (cc, ROC_FN) = 0;
575   *case_num_rw_idx (cc, ROC_TN) = 0;
576   *case_num_rw_idx (cc, ROC_FP) = 0;
577
578   casewriter_write (writer, cc);
579 }
580
581 /*
582    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the
583    readers will be created with width 5, ready to take the values (cutpoint,
584    ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the reader will be populated with its
585    final number of cases.  However on exit from this function, only
586    ROC_CUTPOINT entries will be set to their final value.  The other entries
587    will be initialised to zero.
588 */
589 static struct roc_state *
590 prepare_cutpoints (struct cmd_roc *roc, struct casereader *input)
591 {
592   struct casereader *r = casereader_clone (input);
593   struct ccase *c;
594
595   struct subcase ordering;
596   subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
597
598   struct caseproto *proto = caseproto_create ();
599   proto = caseproto_add_width (proto, 0); /* cutpoint */
600   proto = caseproto_add_width (proto, 0); /* ROC_TP */
601   proto = caseproto_add_width (proto, 0); /* ROC_FN */
602   proto = caseproto_add_width (proto, 0); /* ROC_TN */
603   proto = caseproto_add_width (proto, 0); /* ROC_FP */
604
605   struct roc_state *rs = xnmalloc (roc->n_vars, sizeof *rs);
606   for (size_t i = 0; i < roc->n_vars; ++i)
607     rs[i] = (struct roc_state) {
608       .cutpoint_wtr = sort_create_writer (&ordering, proto),
609       .prev_result = SYSMIS,
610       .max = -DBL_MAX,
611       .min = DBL_MAX,
612     };
613
614   caseproto_unref (proto);
615   subcase_uninit (&ordering);
616
617   for (; (c = casereader_read (r)) != NULL; case_unref (c))
618     for (size_t i = 0; i < roc->n_vars; ++i)
619       {
620         const union value *v = case_data (c, roc->vars[i]);
621         const double result = v->f;
622
623         if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v)
624             & roc->exclude)
625           continue;
626
627         minimize (&rs[i].min, result);
628         maximize (&rs[i].max, result);
629
630         if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
631           {
632             const double mean = (result + rs[i].prev_result) / 2.0;
633             append_cutpoint (rs[i].cutpoint_wtr, mean);
634           }
635
636         rs[i].prev_result = result;
637       }
638   casereader_destroy (r);
639
640   /* Append the min and max cutpoints */
641   for (size_t i = 0; i < roc->n_vars; ++i)
642     {
643       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
644       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
645
646       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
647     }
648
649   return rs;
650 }
651
652 static void
653 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
654 {
655   struct casereader *input = casereader_create_filter_missing (
656     reader, roc->vars, roc->n_vars, roc->exclude, NULL, NULL);
657   input = casereader_create_filter_missing (
658     input, &roc->state_var, 1, roc->exclude, NULL, NULL);
659
660   struct roc_state *rs = prepare_cutpoints (roc, input);
661
662   /* Separate the positive actual state cases from the negative ones */
663   struct casewriter *neg_wtr
664     = autopaging_writer_create (casereader_get_proto (input));
665   struct casereader *positives = casereader_create_filter_func (
666     input, match_positives, NULL, roc, neg_wtr);
667
668   struct caseproto *n_proto = caseproto_create ();
669   for (size_t i = 0; i < 5; i++)
670     n_proto = caseproto_add_width (n_proto, 0);
671
672   struct subcase up_ordering;
673   struct subcase down_ordering;
674   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
675   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
676
677   struct casereader *negatives = NULL;
678   for (size_t i = 0; i < roc->n_vars; ++i)
679     {
680       const struct variable *var = roc->vars[i];
681
682       struct casereader *pos = casereader_clone (positives);
683
684       struct casereader *n_pos_reader =
685         process_positive_group (var, pos, dict, &rs[i]);
686
687       if (!negatives)
688         negatives = casewriter_make_reader (neg_wtr);
689
690       struct casereader *neg = casereader_clone (negatives);
691       struct casereader *n_neg_reader
692         = process_negative_group (var, neg, dict, &rs[i]);
693
694       /* Merge the n_pos and n_neg casereaders */
695       struct casewriter *w = sort_create_writer (&up_ordering, n_proto);
696       struct ccase *cpos;
697       for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
698         {
699           struct ccase *pos_case = case_create (n_proto);
700           const double jpos = case_num_idx (cpos, VALUE);
701
702           struct ccase *cneg;
703           while ((cneg = casereader_read (n_neg_reader)))
704             {
705               struct ccase *nc = case_create (n_proto);
706
707               const double jneg = case_num_idx (cneg, VALUE);
708
709               *case_num_rw_idx (nc, VALUE) = jneg;
710               *case_num_rw_idx (nc, N_POS_EQ) = 0;
711
712               *case_num_rw_idx (nc, N_POS_GT) = SYSMIS;
713
714               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
715               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
716
717               casewriter_write (w, nc);
718
719               case_unref (cneg);
720               if (jneg > jpos)
721                 break;
722             }
723
724           *case_num_rw_idx (pos_case, VALUE) = jpos;
725           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
726           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
727           *case_num_rw_idx (pos_case, N_NEG_EQ) = 0;
728           *case_num_rw_idx (pos_case, N_NEG_LT) = SYSMIS;
729
730           casewriter_write (w, pos_case);
731         }
732
733       casereader_destroy (n_pos_reader);
734       casereader_destroy (n_neg_reader);
735
736       struct casereader *r = casewriter_make_reader (w);
737
738       /* Propagate the N_POS_GT values from the positive cases
739          to the negative ones */
740       double prev_pos_gt = rs[i].n1;
741       w = sort_create_writer (&down_ordering, n_proto);
742
743       struct ccase *c;
744       for (; (c = casereader_read (r)); case_unref (c))
745         {
746           double n_pos_gt = case_num_idx (c, N_POS_GT);
747           struct ccase *nc = case_clone (c);
748
749           if (n_pos_gt == SYSMIS)
750             {
751               n_pos_gt = prev_pos_gt;
752               *case_num_rw_idx (nc, N_POS_GT) = n_pos_gt;
753             }
754
755           casewriter_write (w, nc);
756           prev_pos_gt = n_pos_gt;
757         }
758       casereader_destroy (r);
759       r = casewriter_make_reader (w);
760
761       /* Propagate the N_NEG_LT values from the negative cases
762          to the positive ones */
763       double prev_neg_lt = rs[i].n2;
764       w = sort_create_writer (&up_ordering, n_proto);
765
766       for (; (c = casereader_read (r)); case_unref (c))
767         {
768           double n_neg_lt = case_num_idx (c, N_NEG_LT);
769           struct ccase *nc = case_clone (c);
770
771           if (n_neg_lt == SYSMIS)
772             {
773               n_neg_lt = prev_neg_lt;
774               *case_num_rw_idx (nc, N_NEG_LT) = n_neg_lt;
775             }
776
777           casewriter_write (w, nc);
778           prev_neg_lt = n_neg_lt;
779         }
780
781       casereader_destroy (r);
782       r = casewriter_make_reader (w);
783
784       struct ccase *prev_case = NULL;
785       for (; (c = casereader_read (r)); case_unref (c))
786         {
787           struct ccase *next_case = casereader_peek (r, 0);
788
789           const double j = case_num_idx (c, VALUE);
790           double n_pos_eq = case_num_idx (c, N_POS_EQ);
791           double n_pos_gt = case_num_idx (c, N_POS_GT);
792           double n_neg_eq = case_num_idx (c, N_NEG_EQ);
793           double n_neg_lt = case_num_idx (c, N_NEG_LT);
794
795           if (prev_case && j == case_num_idx (prev_case, VALUE))
796             {
797               if (0 ==  case_num_idx (c, N_POS_EQ))
798                 {
799                   n_pos_eq = case_num_idx (prev_case, N_POS_EQ);
800                   n_pos_gt = case_num_idx (prev_case, N_POS_GT);
801                 }
802
803               if (0 ==  case_num_idx (c, N_NEG_EQ))
804                 {
805                   n_neg_eq = case_num_idx (prev_case, N_NEG_EQ);
806                   n_neg_lt = case_num_idx (prev_case, N_NEG_LT);
807                 }
808             }
809
810           if (NULL == next_case || j != case_num_idx (next_case, VALUE))
811             {
812               rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
813
814               rs[i].q1hat +=
815                 n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
816               rs[i].q2hat +=
817                 n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
818
819             }
820
821           case_unref (next_case);
822           case_unref (prev_case);
823           prev_case = case_clone (c);
824         }
825       casereader_destroy (r);
826       case_unref (prev_case);
827
828       rs[i].auc /= rs[i].n1 * rs[i].n2;
829       if (roc->invert)
830         rs[i].auc = 1 - rs[i].auc;
831
832       if (roc->bi_neg_exp)
833         {
834           rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
835           rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
836         }
837       else
838         {
839           rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
840           rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
841         }
842     }
843
844   casereader_destroy (positives);
845   casereader_destroy (negatives);
846
847   caseproto_unref (n_proto);
848   subcase_uninit (&up_ordering);
849   subcase_uninit (&down_ordering);
850
851   output_roc (rs, roc);
852
853   for (size_t i = 0; i < roc->n_vars; ++i)
854     casereader_destroy (rs[i].cutpoint_rdr);
855
856   free (rs);
857 }
858
859 static void
860 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
861 {
862   struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
863
864   struct pivot_dimension *statistics = pivot_dimension_create (
865     table, PIVOT_AXIS_COLUMN, N_("Statistics"),
866     N_("Area"), PIVOT_RC_OTHER);
867   if (roc->print_se)
868     {
869       pivot_category_create_leaves (
870         statistics->root,
871         N_("Std. Error"), PIVOT_RC_OTHER,
872         N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
873       struct pivot_category *interval = pivot_category_create_group__ (
874         statistics->root,
875         pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
876                                      roc->ci));
877       pivot_category_create_leaves (interval,
878                                     N_("Lower Bound"), PIVOT_RC_OTHER,
879                                     N_("Upper Bound"), PIVOT_RC_OTHER);
880     }
881
882   struct pivot_dimension *variables = pivot_dimension_create (
883     table, PIVOT_AXIS_ROW, N_("Variable under test"));
884   variables->root->show_label = true;
885
886   for (size_t i = 0; i < roc->n_vars; ++i)
887     {
888       int var_idx = pivot_category_create_leaf (
889         variables->root, pivot_value_new_variable (roc->vars[i]));
890
891       pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
892
893       if (roc->print_se)
894         {
895           double se = (rs[i].auc * (1 - rs[i].auc)
896                        + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
897                        + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
898           se /= rs[i].n1 * rs[i].n2;
899           se = sqrt (se);
900
901           double ci = 1 - roc->ci / 100.0;
902           double yy = gsl_cdf_gaussian_Qinv (ci, se);
903
904           double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
905                                 (12 * rs[i].n1 * rs[i].n2));
906           double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
907                                                         / sd_0_5));
908           double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
909           for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
910             pivot_table_put2 (table, i + 1, var_idx,
911                               pivot_value_new_number (entries[i]));
912         }
913     }
914
915   pivot_table_submit (table);
916 }
917
918 static void
919 show_summary (const struct cmd_roc *roc)
920 {
921   struct pivot_table *table = pivot_table_create (N_("Case Summary"));
922
923   struct pivot_dimension *statistics = pivot_dimension_create (
924     table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
925     N_("Unweighted"), PIVOT_RC_INTEGER,
926     N_("Weighted"), PIVOT_RC_OTHER);
927   statistics->root->show_label = true;
928
929   struct pivot_dimension *cases = pivot_dimension_create__ (
930     table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
931   cases->root->show_label = true;
932   pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
933
934   struct entry
935     {
936       int stat_idx;
937       int case_idx;
938       double x;
939     }
940   entries[] = {
941     { 0, 0, roc->pos },
942     { 0, 1, roc->neg },
943     { 1, 0, roc->pos_weighted },
944     { 1, 1, roc->neg_weighted },
945   };
946   for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
947     {
948       const struct entry *e = &entries[i];
949       pivot_table_put2 (table, e->stat_idx, e->case_idx,
950                         pivot_value_new_number (e->x));
951     }
952   pivot_table_submit (table);
953 }
954
955 static void
956 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
957 {
958   struct pivot_table *table = pivot_table_create (
959     N_("Coordinates of the Curve"));
960
961   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
962                           N_("Positive if greater than or equal to"),
963                           N_("Sensitivity"), N_("1 - Specificity"));
964
965   struct pivot_dimension *coordinates = pivot_dimension_create (
966     table, PIVOT_AXIS_ROW, N_("Coordinates"));
967   coordinates->hide_all_labels = true;
968
969   struct pivot_dimension *variables = pivot_dimension_create (
970     table, PIVOT_AXIS_ROW, N_("Test variable"));
971   variables->root->show_label = true;
972
973
974   int n_coords = 0;
975   for (size_t i = 0; i < roc->n_vars; ++i)
976     {
977       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
978
979       int var_idx = pivot_category_create_leaf (
980         variables->root, pivot_value_new_variable (roc->vars[i]));
981
982       struct ccase *cc;
983       int coord_idx = 0;
984       for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
985         {
986           const double se = case_num_idx (cc, ROC_TP) /
987             (case_num_idx (cc, ROC_TP) + case_num_idx (cc, ROC_FN));
988
989           const double sp = case_num_idx (cc, ROC_TN) /
990             (case_num_idx (cc, ROC_TN) + case_num_idx (cc, ROC_FP));
991
992           if (coord_idx >= n_coords)
993             {
994               assert (coord_idx == n_coords);
995               pivot_category_create_leaf (
996                 coordinates->root, pivot_value_new_integer (++n_coords));
997             }
998
999           pivot_table_put3 (
1000             table, 0, coord_idx, var_idx,
1001             pivot_value_new_var_value (roc->vars[i],
1002                                        case_data_idx (cc, ROC_CUTPOINT)));
1003
1004           pivot_table_put3 (table, 1, coord_idx, var_idx,
1005                             pivot_value_new_number (se));
1006           pivot_table_put3 (table, 2, coord_idx, var_idx,
1007                             pivot_value_new_number (1 - sp));
1008           coord_idx++;
1009         }
1010
1011       casereader_destroy (r);
1012     }
1013
1014   pivot_table_submit (table);
1015 }
1016
1017 static void
1018 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1019 {
1020   show_summary (roc);
1021
1022   if (roc->curve)
1023     {
1024       struct roc_chart *rc = roc_chart_create (roc->reference);
1025       for (size_t i = 0; i < roc->n_vars; i++)
1026         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1027                            rs[i].cutpoint_rdr);
1028       roc_chart_submit (rc);
1029     }
1030
1031   show_auc (rs, roc);
1032
1033   if (roc->print_coords)
1034     show_coords (rs, roc);
1035 }
1036