output: Introduce pivot tables.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/chart-item.h"
37 #include "output/charts/roc-chart.h"
38 #include "output/pivot-table.h"
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52   size_t state_var_width;
53
54   /* Plot the roc curve */
55   bool curve;
56   /* Plot the reference line */
57   bool reference;
58
59   double ci;
60
61   bool print_coords;
62   bool print_se;
63   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
64                       should be used */
65   enum mv_class exclude;
66
67   bool invert ; /* True iff a smaller test result variable indicates
68                    a positive result */
69
70   double pos;
71   double neg;
72   double pos_weighted;
73   double neg_weighted;
74 };
75
76 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
77
78 int
79 cmd_roc (struct lexer *lexer, struct dataset *ds)
80 {
81   struct cmd_roc roc ;
82   const struct dictionary *dict = dataset_dict (ds);
83
84   roc.vars = NULL;
85   roc.n_vars = 0;
86   roc.print_se = false;
87   roc.print_coords = false;
88   roc.exclude = MV_ANY;
89   roc.curve = true;
90   roc.reference = false;
91   roc.ci = 95;
92   roc.bi_neg_exp = false;
93   roc.invert = false;
94   roc.pos = roc.pos_weighted = 0;
95   roc.neg = roc.neg_weighted = 0;
96   roc.dict = dataset_dict (ds);
97   roc.state_var = NULL;
98   roc.state_var_width = -1;
99
100   lex_match (lexer, T_SLASH);
101   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103     goto error;
104
105   if ( ! lex_force_match (lexer, T_BY))
106     {
107       goto error;
108     }
109
110   roc.state_var = parse_variable (lexer, dict);
111   if (! roc.state_var)
112     {
113       goto error;
114     }
115
116   if ( !lex_force_match (lexer, T_LPAREN))
117     {
118       goto error;
119     }
120
121   roc.state_var_width = var_get_width (roc.state_var);
122   value_init (&roc.state_value, roc.state_var_width);
123   parse_value (lexer, &roc.state_value, roc.state_var);
124
125
126   if ( !lex_force_match (lexer, T_RPAREN))
127     {
128       goto error;
129     }
130
131   while (lex_token (lexer) != T_ENDCMD)
132     {
133       lex_match (lexer, T_SLASH);
134       if (lex_match_id (lexer, "MISSING"))
135         {
136           lex_match (lexer, T_EQUALS);
137           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
138             {
139               if (lex_match_id (lexer, "INCLUDE"))
140                 {
141                   roc.exclude = MV_SYSTEM;
142                 }
143               else if (lex_match_id (lexer, "EXCLUDE"))
144                 {
145                   roc.exclude = MV_ANY;
146                 }
147               else
148                 {
149                   lex_error (lexer, NULL);
150                   goto error;
151                 }
152             }
153         }
154       else if (lex_match_id (lexer, "PLOT"))
155         {
156           lex_match (lexer, T_EQUALS);
157           if (lex_match_id (lexer, "CURVE"))
158             {
159               roc.curve = true;
160               if (lex_match (lexer, T_LPAREN))
161                 {
162                   roc.reference = true;
163                   if (! lex_force_match_id (lexer, "REFERENCE"))
164                     goto error;
165                   if (! lex_force_match (lexer, T_RPAREN))
166                     goto error;
167                 }
168             }
169           else if (lex_match_id (lexer, "NONE"))
170             {
171               roc.curve = false;
172             }
173           else
174             {
175               lex_error (lexer, NULL);
176               goto error;
177             }
178         }
179       else if (lex_match_id (lexer, "PRINT"))
180         {
181           lex_match (lexer, T_EQUALS);
182           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
183             {
184               if (lex_match_id (lexer, "SE"))
185                 {
186                   roc.print_se = true;
187                 }
188               else if (lex_match_id (lexer, "COORDINATES"))
189                 {
190                   roc.print_coords = true;
191                 }
192               else
193                 {
194                   lex_error (lexer, NULL);
195                   goto error;
196                 }
197             }
198         }
199       else if (lex_match_id (lexer, "CRITERIA"))
200         {
201           lex_match (lexer, T_EQUALS);
202           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
203             {
204               if (lex_match_id (lexer, "CUTOFF"))
205                 {
206                   if (! lex_force_match (lexer, T_LPAREN))
207                     goto error;
208                   if (lex_match_id (lexer, "INCLUDE"))
209                     {
210                       roc.exclude = MV_SYSTEM;
211                     }
212                   else if (lex_match_id (lexer, "EXCLUDE"))
213                     {
214                       roc.exclude = MV_USER | MV_SYSTEM;
215                     }
216                   else
217                     {
218                       lex_error (lexer, NULL);
219                       goto error;
220                     }
221                   if (! lex_force_match (lexer, T_RPAREN))
222                     goto error;
223                 }
224               else if (lex_match_id (lexer, "TESTPOS"))
225                 {
226                   if (! lex_force_match (lexer, T_LPAREN))
227                     goto error;
228                   if (lex_match_id (lexer, "LARGE"))
229                     {
230                       roc.invert = false;
231                     }
232                   else if (lex_match_id (lexer, "SMALL"))
233                     {
234                       roc.invert = true;
235                     }
236                   else
237                     {
238                       lex_error (lexer, NULL);
239                       goto error;
240                     }
241                   if (! lex_force_match (lexer, T_RPAREN))
242                     goto error;
243                 }
244               else if (lex_match_id (lexer, "CI"))
245                 {
246                   if (!lex_force_match (lexer, T_LPAREN))
247                     goto error;
248                   if (! lex_force_num (lexer))
249                     goto error;
250                   roc.ci = lex_number (lexer);
251                   lex_get (lexer);
252                   if (!lex_force_match (lexer, T_RPAREN))
253                     goto error;
254                 }
255               else if (lex_match_id (lexer, "DISTRIBUTION"))
256                 {
257                   if (!lex_force_match (lexer, T_LPAREN))
258                     goto error;
259                   if (lex_match_id (lexer, "FREE"))
260                     {
261                       roc.bi_neg_exp = false;
262                     }
263                   else if (lex_match_id (lexer, "NEGEXPO"))
264                     {
265                       roc.bi_neg_exp = true;
266                     }
267                   else
268                     {
269                       lex_error (lexer, NULL);
270                       goto error;
271                     }
272                   if (!lex_force_match (lexer, T_RPAREN))
273                     goto error;
274                 }
275               else
276                 {
277                   lex_error (lexer, NULL);
278                   goto error;
279                 }
280             }
281         }
282       else
283         {
284           lex_error (lexer, NULL);
285           break;
286         }
287     }
288
289   if ( ! run_roc (ds, &roc))
290     goto error;
291
292   if ( roc.state_var)
293     value_destroy (&roc.state_value, roc.state_var_width);
294   free (roc.vars);
295   return CMD_SUCCESS;
296
297  error:
298   if ( roc.state_var)
299     value_destroy (&roc.state_value, roc.state_var_width);
300   free (roc.vars);
301   return CMD_FAILURE;
302 }
303
304
305
306
307 static void
308 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
309
310
311 static int
312 run_roc (struct dataset *ds, struct cmd_roc *roc)
313 {
314   struct dictionary *dict = dataset_dict (ds);
315   bool ok;
316   struct casereader *group;
317
318   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
319   while (casegrouper_get_next_group (grouper, &group))
320     {
321       do_roc (roc, group, dataset_dict (ds));
322     }
323   ok = casegrouper_destroy (grouper);
324   ok = proc_commit (ds) && ok;
325
326   return ok;
327 }
328
329 #if 0
330 static void
331 dump_casereader (struct casereader *reader)
332 {
333   struct ccase *c;
334   struct casereader *r = casereader_clone (reader);
335
336   for ( ; (c = casereader_read (r) ); case_unref (c))
337     {
338       int i;
339       for (i = 0 ; i < case_get_value_cnt (c); ++i)
340         {
341           printf ("%g ", case_data_idx (c, i)->f);
342         }
343       printf ("\n");
344     }
345
346   casereader_destroy (r);
347 }
348 #endif
349
350
351 /*
352    Return true iff the state variable indicates that C has positive actual state.
353
354    As a side effect, this function also accumulates the roc->{pos,neg} and
355    roc->{pos,neg}_weighted counts.
356  */
357 static bool
358 match_positives (const struct ccase *c, void *aux)
359 {
360   struct cmd_roc *roc = aux;
361   const struct variable *wv = dict_get_weight (roc->dict);
362   const double weight = wv ? case_data (c, wv)->f : 1.0;
363
364   const bool positive =
365   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
366     var_get_width (roc->state_var)));
367
368   if ( positive )
369     {
370       roc->pos++;
371       roc->pos_weighted += weight;
372     }
373   else
374     {
375       roc->neg++;
376       roc->neg_weighted += weight;
377     }
378
379   return positive;
380 }
381
382
383 #define VALUE  0
384 #define N_EQ   1
385 #define N_PRED 2
386
387 /* Some intermediate state for calculating the cutpoints and the
388    standard error values */
389 struct roc_state
390 {
391   double auc;  /* Area under the curve */
392
393   double n1;  /* total weight of positives */
394   double n2;  /* total weight of negatives */
395
396   /* intermediates for standard error */
397   double q1hat;
398   double q2hat;
399
400   /* intermediates for cutpoints */
401   struct casewriter *cutpoint_wtr;
402   struct casereader *cutpoint_rdr;
403   double prev_result;
404   double min;
405   double max;
406 };
407
408 /*
409    Return a new casereader based upon CUTPOINT_RDR.
410    The number of "positive" cases are placed into
411    the position TRUE_INDEX, and the number of "negative" cases
412    into FALSE_INDEX.
413    POS_COND and RESULT determine the semantics of what is
414    "positive".
415    WEIGHT is the value of a single count.
416  */
417 static struct casereader *
418 accumulate_counts (struct casereader *input,
419                    double result, double weight,
420                    bool (*pos_cond) (double, double),
421                    int true_index, int false_index)
422 {
423   const struct caseproto *proto = casereader_get_proto (input);
424   struct casewriter *w =
425     autopaging_writer_create (proto);
426   struct ccase *cpc;
427   double prev_cp = SYSMIS;
428
429   for ( ; (cpc = casereader_read (input) ); case_unref (cpc))
430     {
431       struct ccase *new_case;
432       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
433
434       assert (cp != SYSMIS);
435
436       /* We don't want duplicates here */
437       if ( cp == prev_cp )
438         continue;
439
440       new_case = case_clone (cpc);
441
442       if ( pos_cond (result, cp))
443         case_data_rw_idx (new_case, true_index)->f += weight;
444       else
445         case_data_rw_idx (new_case, false_index)->f += weight;
446
447       prev_cp = cp;
448
449       casewriter_write (w, new_case);
450     }
451   casereader_destroy (input);
452
453   return casewriter_make_reader (w);
454 }
455
456
457
458 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
459
460 /*
461   This function does 3 things:
462
463   1. Counts the number of cases which are equal to every other case in READER,
464   and those cases for which the relationship between it and every other case
465   satifies PRED (normally either > or <).  VAR is variable defining a case's value
466   for this purpose.
467
468   2. Counts the number of true and false cases in reader, and populates
469   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
470   which receive these values.  POS_COND is the condition defining true
471   and false.
472
473   3. CC is filled with the cumulative weight of all cases of READER.
474 */
475 static struct casereader *
476 process_group (const struct variable *var, struct casereader *reader,
477                bool (*pred) (double, double),
478                const struct dictionary *dict,
479                double *cc,
480                struct casereader **cutpoint_rdr,
481                bool (*pos_cond) (double, double),
482                int true_index,
483                int false_index)
484 {
485   const struct variable *w = dict_get_weight (dict);
486
487   struct casereader *r1 =
488     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
489
490   const int weight_idx  = w ? var_get_case_index (w) :
491     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
492
493   struct ccase *c1;
494
495   struct casereader *rclone = casereader_clone (r1);
496   struct casewriter *wtr;
497   struct caseproto *proto = caseproto_create ();
498
499   proto = caseproto_add_width (proto, 0);
500   proto = caseproto_add_width (proto, 0);
501   proto = caseproto_add_width (proto, 0);
502
503   wtr = autopaging_writer_create (proto);
504
505   *cc = 0;
506
507   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
508     {
509       struct ccase *new_case = case_create (proto);
510       struct ccase *c2;
511       struct casereader *r2 = casereader_clone (rclone);
512
513       const double weight1 = case_data_idx (c1, weight_idx)->f;
514       const double d1 = case_data (c1, var)->f;
515       double n_eq = 0.0;
516       double n_pred = 0.0;
517
518       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
519                                          pos_cond,
520                                          true_index, false_index);
521
522       *cc += weight1;
523
524       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
525         {
526           const double d2 = case_data (c2, var)->f;
527           const double weight2 = case_data_idx (c2, weight_idx)->f;
528
529           if ( d1 == d2 )
530             {
531               n_eq += weight2;
532               continue;
533             }
534           else  if ( pred (d2, d1))
535             {
536               n_pred += weight2;
537             }
538         }
539
540       case_data_rw_idx (new_case, VALUE)->f = d1;
541       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
542       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
543
544       casewriter_write (wtr, new_case);
545
546       casereader_destroy (r2);
547     }
548
549
550   casereader_destroy (r1);
551   casereader_destroy (rclone);
552
553   caseproto_unref (proto);
554
555   return casewriter_make_reader (wtr);
556 }
557
558 /* Some more indeces into case data */
559 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
560 #define N_POS_GT 2  /* number of positive cases with values greater than n */
561 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
562 #define N_NEG_LT 4  /* number of negative cases with values less than n */
563
564 static bool
565 gt (double d1, double d2)
566 {
567   return d1 > d2;
568 }
569
570
571 static bool
572 ge (double d1, double d2)
573 {
574   return d1 > d2;
575 }
576
577 static bool
578 lt (double d1, double d2)
579 {
580   return d1 < d2;
581 }
582
583
584 /*
585   Return a casereader with width 3,
586   populated with cases based upon READER.
587   The cases will have the values:
588   (N, number of cases equal to N, number of cases greater than N)
589   As a side effect, update RS->n1 with the number of positive cases.
590 */
591 static struct casereader *
592 process_positive_group (const struct variable *var, struct casereader *reader,
593                         const struct dictionary *dict,
594                         struct roc_state *rs)
595 {
596   return process_group (var, reader, gt, dict, &rs->n1,
597                         &rs->cutpoint_rdr,
598                         ge,
599                         ROC_TP, ROC_FN);
600 }
601
602 /*
603   Return a casereader with width 3,
604   populated with cases based upon READER.
605   The cases will have the values:
606   (N, number of cases equal to N, number of cases less than N)
607   As a side effect, update RS->n2 with the number of negative cases.
608 */
609 static struct casereader *
610 process_negative_group (const struct variable *var, struct casereader *reader,
611                         const struct dictionary *dict,
612                         struct roc_state *rs)
613 {
614   return process_group (var, reader, lt, dict, &rs->n2,
615                         &rs->cutpoint_rdr,
616                         lt,
617                         ROC_TN, ROC_FP);
618 }
619
620
621
622
623 static void
624 append_cutpoint (struct casewriter *writer, double cutpoint)
625 {
626   struct ccase *cc = case_create (casewriter_get_proto (writer));
627
628   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
629   case_data_rw_idx (cc, ROC_TP)->f = 0;
630   case_data_rw_idx (cc, ROC_FN)->f = 0;
631   case_data_rw_idx (cc, ROC_TN)->f = 0;
632   case_data_rw_idx (cc, ROC_FP)->f = 0;
633
634   casewriter_write (writer, cc);
635 }
636
637
638 /*
639    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
640    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
641    reader will be populated with its final number of cases.
642    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
643    value.  The other entries will be initialised to zero.
644 */
645 static void
646 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
647 {
648   int i;
649   struct casereader *r = casereader_clone (input);
650   struct ccase *c;
651
652   {
653     struct caseproto *proto = caseproto_create ();
654     struct subcase ordering;
655     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
656
657     proto = caseproto_add_width (proto, 0); /* cutpoint */
658     proto = caseproto_add_width (proto, 0); /* ROC_TP */
659     proto = caseproto_add_width (proto, 0); /* ROC_FN */
660     proto = caseproto_add_width (proto, 0); /* ROC_TN */
661     proto = caseproto_add_width (proto, 0); /* ROC_FP */
662
663     for (i = 0 ; i < roc->n_vars; ++i)
664       {
665         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
666         rs[i].prev_result = SYSMIS;
667         rs[i].max = -DBL_MAX;
668         rs[i].min = DBL_MAX;
669       }
670
671     caseproto_unref (proto);
672     subcase_destroy (&ordering);
673   }
674
675   for (; (c = casereader_read (r)) != NULL; case_unref (c))
676     {
677       for (i = 0 ; i < roc->n_vars; ++i)
678         {
679           const union value *v = case_data (c, roc->vars[i]);
680           const double result = v->f;
681
682           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
683             continue;
684
685           minimize (&rs[i].min, result);
686           maximize (&rs[i].max, result);
687
688           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
689             {
690               const double mean = (result + rs[i].prev_result ) / 2.0;
691               append_cutpoint (rs[i].cutpoint_wtr, mean);
692             }
693
694           rs[i].prev_result = result;
695         }
696     }
697   casereader_destroy (r);
698
699
700   /* Append the min and max cutpoints */
701   for (i = 0 ; i < roc->n_vars; ++i)
702     {
703       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
704       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
705
706       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
707     }
708 }
709
710 static void
711 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
712 {
713   int i;
714
715   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
716
717   struct casereader *negatives = NULL;
718   struct casereader *positives = NULL;
719
720   struct caseproto *n_proto = NULL;
721
722   struct subcase up_ordering;
723   struct subcase down_ordering;
724
725   struct casewriter *neg_wtr = NULL;
726
727   struct casereader *input = casereader_create_filter_missing (reader,
728                                                                roc->vars, roc->n_vars,
729                                                                roc->exclude,
730                                                                NULL,
731                                                                NULL);
732
733   input = casereader_create_filter_missing (input,
734                                             &roc->state_var, 1,
735                                             roc->exclude,
736                                             NULL,
737                                             NULL);
738
739   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
740
741   prepare_cutpoints (roc, rs, input);
742
743
744   /* Separate the positive actual state cases from the negative ones */
745   positives =
746     casereader_create_filter_func (input,
747                                    match_positives,
748                                    NULL,
749                                    roc,
750                                    neg_wtr);
751
752   n_proto = caseproto_create ();
753
754   n_proto = caseproto_add_width (n_proto, 0);
755   n_proto = caseproto_add_width (n_proto, 0);
756   n_proto = caseproto_add_width (n_proto, 0);
757   n_proto = caseproto_add_width (n_proto, 0);
758   n_proto = caseproto_add_width (n_proto, 0);
759
760   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
761   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
762
763   for (i = 0 ; i < roc->n_vars; ++i)
764     {
765       struct casewriter *w = NULL;
766       struct casereader *r = NULL;
767
768       struct ccase *c;
769
770       struct ccase *cpos;
771       struct casereader *n_neg_reader ;
772       const struct variable *var = roc->vars[i];
773
774       struct casereader *neg ;
775       struct casereader *pos = casereader_clone (positives);
776
777       struct casereader *n_pos_reader =
778         process_positive_group (var, pos, dict, &rs[i]);
779
780       if ( negatives == NULL)
781         {
782           negatives = casewriter_make_reader (neg_wtr);
783         }
784
785       neg = casereader_clone (negatives);
786
787       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
788
789       /* Merge the n_pos and n_neg casereaders */
790       w = sort_create_writer (&up_ordering, n_proto);
791       for ( ; (cpos = casereader_read (n_pos_reader) ); case_unref (cpos))
792         {
793           struct ccase *pos_case = case_create (n_proto);
794           struct ccase *cneg;
795           const double jpos = case_data_idx (cpos, VALUE)->f;
796
797           while ((cneg = casereader_read (n_neg_reader)))
798             {
799               struct ccase *nc = case_create (n_proto);
800
801               const double jneg = case_data_idx (cneg, VALUE)->f;
802
803               case_data_rw_idx (nc, VALUE)->f = jneg;
804               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
805
806               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
807
808               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
809               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
810
811               casewriter_write (w, nc);
812
813               case_unref (cneg);
814               if ( jneg > jpos)
815                 break;
816             }
817
818           case_data_rw_idx (pos_case, VALUE)->f = jpos;
819           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
820           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
821           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
822           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
823
824           casewriter_write (w, pos_case);
825         }
826
827       casereader_destroy (n_pos_reader);
828       casereader_destroy (n_neg_reader);
829
830 /* These aren't used anymore */
831 #undef N_EQ
832 #undef N_PRED
833
834       r = casewriter_make_reader (w);
835
836       /* Propagate the N_POS_GT values from the positive cases
837          to the negative ones */
838       {
839         double prev_pos_gt = rs[i].n1;
840         w = sort_create_writer (&down_ordering, n_proto);
841
842         for ( ; (c = casereader_read (r) ); case_unref (c))
843           {
844             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
845             struct ccase *nc = case_clone (c);
846
847             if ( n_pos_gt == SYSMIS)
848               {
849                 n_pos_gt = prev_pos_gt;
850                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
851               }
852
853             casewriter_write (w, nc);
854             prev_pos_gt = n_pos_gt;
855           }
856
857         casereader_destroy (r);
858         r = casewriter_make_reader (w);
859       }
860
861       /* Propagate the N_NEG_LT values from the negative cases
862          to the positive ones */
863       {
864         double prev_neg_lt = rs[i].n2;
865         w = sort_create_writer (&up_ordering, n_proto);
866
867         for ( ; (c = casereader_read (r) ); case_unref (c))
868           {
869             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
870             struct ccase *nc = case_clone (c);
871
872             if ( n_neg_lt == SYSMIS)
873               {
874                 n_neg_lt = prev_neg_lt;
875                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
876               }
877
878             casewriter_write (w, nc);
879             prev_neg_lt = n_neg_lt;
880           }
881
882         casereader_destroy (r);
883         r = casewriter_make_reader (w);
884       }
885
886       {
887         struct ccase *prev_case = NULL;
888         for ( ; (c = casereader_read (r) ); case_unref (c))
889           {
890             struct ccase *next_case = casereader_peek (r, 0);
891
892             const double j = case_data_idx (c, VALUE)->f;
893             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
894             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
895             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
896             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
897
898             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
899               {
900                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
901                   {
902                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
903                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
904                   }
905
906                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
907                   {
908                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
909                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
910                   }
911               }
912
913             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
914               {
915                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
916
917                 rs[i].q1hat +=
918                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
919                 rs[i].q2hat +=
920                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
921
922               }
923
924             case_unref (next_case);
925             case_unref (prev_case);
926             prev_case = case_clone (c);
927           }
928         casereader_destroy (r);
929         case_unref (prev_case);
930
931         rs[i].auc /=  rs[i].n1 * rs[i].n2;
932         if ( roc->invert )
933           rs[i].auc = 1 - rs[i].auc;
934
935         if ( roc->bi_neg_exp )
936           {
937             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
938             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
939           }
940         else
941           {
942             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
943             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
944           }
945       }
946     }
947
948   casereader_destroy (positives);
949   casereader_destroy (negatives);
950
951   caseproto_unref (n_proto);
952   subcase_destroy (&up_ordering);
953   subcase_destroy (&down_ordering);
954
955   output_roc (rs, roc);
956
957   for (i = 0 ; i < roc->n_vars; ++i)
958     casereader_destroy (rs[i].cutpoint_rdr);
959
960   free (rs);
961 }
962
963 static void
964 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
965 {
966   struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
967
968   struct pivot_dimension *statistics = pivot_dimension_create (
969     table, PIVOT_AXIS_COLUMN, N_("Statistics"),
970     N_("Area"), PIVOT_RC_OTHER);
971   if (roc->print_se)
972     {
973       pivot_category_create_leaves (
974         statistics->root,
975         N_("Std. Error"), PIVOT_RC_OTHER,
976         N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
977       struct pivot_category *interval = pivot_category_create_group__ (
978         statistics->root,
979         pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
980                                      roc->ci));
981       pivot_category_create_leaves (interval,
982                                     N_("Lower Bound"), PIVOT_RC_OTHER,
983                                     N_("Upper Bound"), PIVOT_RC_OTHER);
984     }
985
986   struct pivot_dimension *variables = pivot_dimension_create (
987     table, PIVOT_AXIS_ROW, N_("Variable under test"));
988   variables->root->show_label = true;
989
990   for (size_t i = 0 ; i < roc->n_vars ; ++i )
991     {
992       int var_idx = pivot_category_create_leaf (
993         variables->root, pivot_value_new_variable (roc->vars[i]));
994
995       pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
996
997       if ( roc->print_se )
998         {
999           double se = (rs[i].auc * (1 - rs[i].auc)
1000                        + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
1001                        + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
1002           se /= rs[i].n1 * rs[i].n2;
1003           se = sqrt (se);
1004
1005           double ci = 1 - roc->ci / 100.0;
1006           double yy = gsl_cdf_gaussian_Qinv (ci, se);
1007
1008           double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1009                                 (12 * rs[i].n1 * rs[i].n2));
1010           double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
1011                                                         / sd_0_5));
1012           double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
1013           for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1014             pivot_table_put2 (table, i + 1, var_idx,
1015                               pivot_value_new_number (entries[i]));
1016         }
1017     }
1018
1019   pivot_table_submit (table);
1020 }
1021
1022
1023 static void
1024 show_summary (const struct cmd_roc *roc)
1025 {
1026   struct pivot_table *table = pivot_table_create (N_("Case Summary"));
1027
1028   struct pivot_dimension *statistics = pivot_dimension_create (
1029     table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
1030     N_("Unweighted"), PIVOT_RC_INTEGER,
1031     N_("Weighted"), PIVOT_RC_OTHER);
1032   statistics->root->show_label = true;
1033
1034   struct pivot_dimension *cases = pivot_dimension_create__ (
1035     table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
1036   cases->root->show_label = true;
1037   pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
1038
1039   struct entry
1040     {
1041       int stat_idx;
1042       int case_idx;
1043       double x;
1044     }
1045   entries[] = {
1046     { 0, 0, roc->pos },
1047     { 0, 1, roc->neg },
1048     { 1, 0, roc->pos_weighted },
1049     { 1, 1, roc->neg_weighted },
1050   };
1051   for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1052     {
1053       const struct entry *e = &entries[i];
1054       pivot_table_put2 (table, e->stat_idx, e->case_idx,
1055                         pivot_value_new_number (e->x));
1056     }
1057   pivot_table_submit (table);
1058 }
1059
1060 static void
1061 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1062 {
1063   struct pivot_table *table = pivot_table_create (
1064     N_("Coordinates of the Curve"));
1065   table->omit_empty = true;
1066
1067   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
1068                           N_("Positive if greater than or equal to"),
1069                           N_("Sensitivity"), N_("1 - Specificity"));
1070
1071   struct pivot_dimension *coordinates = pivot_dimension_create (
1072     table, PIVOT_AXIS_ROW, N_("Coordinates"));
1073   coordinates->hide_all_labels = true;
1074
1075   struct pivot_dimension *variables = pivot_dimension_create (
1076     table, PIVOT_AXIS_ROW, N_("Test variable"));
1077   variables->root->show_label = true;
1078
1079
1080   int n_coords = 0;
1081   for (size_t i = 0; i < roc->n_vars; ++i)
1082     {
1083       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1084
1085       int var_idx = pivot_category_create_leaf (
1086         variables->root, pivot_value_new_variable (roc->vars[i]));
1087
1088       struct ccase *cc;
1089       int coord_idx = 0;
1090       for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
1091         {
1092           const double se = case_data_idx (cc, ROC_TP)->f /
1093             (case_data_idx (cc, ROC_TP)->f + case_data_idx (cc, ROC_FN)->f);
1094
1095           const double sp = case_data_idx (cc, ROC_TN)->f /
1096             (case_data_idx (cc, ROC_TN)->f + case_data_idx (cc, ROC_FP)->f);
1097
1098           pivot_table_put3 (
1099             table, 0, coord_idx, var_idx,
1100             pivot_value_new_var_value (roc->vars[i],
1101                                        case_data_idx (cc, ROC_CUTPOINT)));
1102
1103           pivot_table_put3 (table, 1, coord_idx, var_idx,
1104                             pivot_value_new_number (se));
1105           pivot_table_put3 (table, 2, coord_idx, var_idx,
1106                             pivot_value_new_number (1 - sp));
1107           coord_idx++;
1108         }
1109
1110       if (coord_idx > n_coords)
1111         n_coords = coord_idx;
1112
1113       casereader_destroy (r);
1114     }
1115
1116   for (size_t i = 0; i < n_coords; i++)
1117     pivot_category_create_leaf (coordinates->root,
1118                                 pivot_value_new_integer (i + 1));
1119
1120   pivot_table_submit (table);
1121 }
1122
1123
1124 static void
1125 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1126 {
1127   show_summary (roc);
1128
1129   if ( roc->curve )
1130     {
1131       struct roc_chart *rc;
1132       size_t i;
1133
1134       rc = roc_chart_create (roc->reference);
1135       for (i = 0; i < roc->n_vars; i++)
1136         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1137                            rs[i].cutpoint_rdr);
1138       roc_chart_submit (rc);
1139     }
1140
1141   show_auc (rs, roc);
1142
1143   if ( roc->print_coords )
1144     show_coords (rs, roc);
1145 }
1146