Fix crash on erroneous ROC input
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <language/stats/roc.h>
20
21 #include <data/casegrouper.h>
22 #include <data/casereader.h>
23 #include <data/casewriter.h>
24 #include <data/dictionary.h>
25 #include <data/format.h>
26 #include <data/procedure.h>
27 #include <data/subcase.h>
28 #include <language/command.h>
29 #include <language/lexer/lexer.h>
30 #include <language/lexer/value-parser.h>
31 #include <language/lexer/variable-parser.h>
32 #include <libpspp/misc.h>
33 #include <math/sort.h>
34 #include <output/chart-item.h>
35 #include <output/charts/roc-chart.h>
36 #include <output/tab.h>
37
38 #include <gsl/gsl_cdf.h>
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert ; /* True iff a smaller test result variable indicates
67                    a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
76
77 int
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
79 {
80   struct cmd_roc roc ;
81   const struct dictionary *dict = dataset_dict (ds);
82
83   roc.vars = NULL;
84   roc.n_vars = 0;
85   roc.print_se = false;
86   roc.print_coords = false;
87   roc.exclude = MV_ANY;
88   roc.curve = true;
89   roc.reference = false;
90   roc.ci = 95;
91   roc.bi_neg_exp = false;
92   roc.invert = false;
93   roc.pos = roc.pos_weighted = 0;
94   roc.neg = roc.neg_weighted = 0;
95   roc.dict = dataset_dict (ds);
96   roc.state_var = NULL;
97
98   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
99                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
100     goto error;
101
102   if ( ! lex_force_match (lexer, T_BY))
103     {
104       goto error;
105     }
106
107   roc.state_var = parse_variable (lexer, dict);
108
109   if ( !lex_force_match (lexer, '('))
110     {
111       goto error;
112     }
113
114   value_init (&roc.state_value, var_get_width (roc.state_var));
115   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
116
117
118   if ( !lex_force_match (lexer, ')'))
119     {
120       goto error;
121     }
122
123
124   while (lex_token (lexer) != '.')
125     {
126       lex_match (lexer, '/');
127       if (lex_match_id (lexer, "MISSING"))
128         {
129           lex_match (lexer, '=');
130           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
131             {
132               if (lex_match_id (lexer, "INCLUDE"))
133                 {
134                   roc.exclude = MV_SYSTEM;
135                 }
136               else if (lex_match_id (lexer, "EXCLUDE"))
137                 {
138                   roc.exclude = MV_ANY;
139                 }
140               else
141                 {
142                   lex_error (lexer, NULL);
143                   goto error;
144                 }
145             }
146         }
147       else if (lex_match_id (lexer, "PLOT"))
148         {
149           lex_match (lexer, '=');
150           if (lex_match_id (lexer, "CURVE"))
151             {
152               roc.curve = true;
153               if (lex_match (lexer, '('))
154                 {
155                   roc.reference = true;
156                   lex_force_match_id (lexer, "REFERENCE");
157                   lex_force_match (lexer, ')');
158                 }
159             }
160           else if (lex_match_id (lexer, "NONE"))
161             {
162               roc.curve = false;
163             }
164           else
165             {
166               lex_error (lexer, NULL);
167               goto error;
168             }
169         }
170       else if (lex_match_id (lexer, "PRINT"))
171         {
172           lex_match (lexer, '=');
173           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
174             {
175               if (lex_match_id (lexer, "SE"))
176                 {
177                   roc.print_se = true;
178                 }
179               else if (lex_match_id (lexer, "COORDINATES"))
180                 {
181                   roc.print_coords = true;
182                 }
183               else
184                 {
185                   lex_error (lexer, NULL);
186                   goto error;
187                 }
188             }
189         }
190       else if (lex_match_id (lexer, "CRITERIA"))
191         {
192           lex_match (lexer, '=');
193           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
194             {
195               if (lex_match_id (lexer, "CUTOFF"))
196                 {
197                   lex_force_match (lexer, '(');
198                   if (lex_match_id (lexer, "INCLUDE"))
199                     {
200                       roc.exclude = MV_SYSTEM;
201                     }
202                   else if (lex_match_id (lexer, "EXCLUDE"))
203                     {
204                       roc.exclude = MV_USER | MV_SYSTEM;
205                     }
206                   else
207                     {
208                       lex_error (lexer, NULL);
209                       goto error;
210                     }
211                   lex_force_match (lexer, ')');
212                 }
213               else if (lex_match_id (lexer, "TESTPOS"))
214                 {
215                   lex_force_match (lexer, '(');
216                   if (lex_match_id (lexer, "LARGE"))
217                     {
218                       roc.invert = false;
219                     }
220                   else if (lex_match_id (lexer, "SMALL"))
221                     {
222                       roc.invert = true;
223                     }
224                   else
225                     {
226                       lex_error (lexer, NULL);
227                       goto error;
228                     }
229                   lex_force_match (lexer, ')');
230                 }
231               else if (lex_match_id (lexer, "CI"))
232                 {
233                   lex_force_match (lexer, '(');
234                   lex_force_num (lexer);
235                   roc.ci = lex_number (lexer);
236                   lex_get (lexer);
237                   lex_force_match (lexer, ')');
238                 }
239               else if (lex_match_id (lexer, "DISTRIBUTION"))
240                 {
241                   lex_force_match (lexer, '(');
242                   if (lex_match_id (lexer, "FREE"))
243                     {
244                       roc.bi_neg_exp = false;
245                     }
246                   else if (lex_match_id (lexer, "NEGEXPO"))
247                     {
248                       roc.bi_neg_exp = true;
249                     }
250                   else
251                     {
252                       lex_error (lexer, NULL);
253                       goto error;
254                     }
255                   lex_force_match (lexer, ')');
256                 }
257               else
258                 {
259                   lex_error (lexer, NULL);
260                   goto error;
261                 }
262             }
263         }
264       else
265         {
266           lex_error (lexer, NULL);
267           break;
268         }
269     }
270
271   if ( ! run_roc (ds, &roc)) 
272     goto error;
273
274   value_destroy (&roc.state_value, var_get_width (roc.state_var));
275   free (roc.vars);
276   return CMD_SUCCESS;
277
278  error:
279   if ( roc.state_var)
280     value_destroy (&roc.state_value, var_get_width (roc.state_var));
281   free (roc.vars);
282   return CMD_FAILURE;
283 }
284
285
286
287
288 static void
289 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
290
291
292 static int
293 run_roc (struct dataset *ds, struct cmd_roc *roc)
294 {
295   struct dictionary *dict = dataset_dict (ds);
296   bool ok;
297   struct casereader *group;
298
299   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
300   while (casegrouper_get_next_group (grouper, &group))
301     {
302       do_roc (roc, group, dataset_dict (ds));
303     }
304   ok = casegrouper_destroy (grouper);
305   ok = proc_commit (ds) && ok;
306
307   return ok;
308 }
309
310 #if 0
311 static void
312 dump_casereader (struct casereader *reader)
313 {
314   struct ccase *c;
315   struct casereader *r = casereader_clone (reader);
316
317   for ( ; (c = casereader_read (r) ); case_unref (c))
318     {
319       int i;
320       for (i = 0 ; i < case_get_value_cnt (c); ++i)
321         {
322           printf ("%g ", case_data_idx (c, i)->f);
323         }
324       printf ("\n");
325     }
326
327   casereader_destroy (r);
328 }
329 #endif
330
331
332 /* 
333    Return true iff the state variable indicates that C has positive actual state.
334
335    As a side effect, this function also accumulates the roc->{pos,neg} and 
336    roc->{pos,neg}_weighted counts.
337  */
338 static bool
339 match_positives (const struct ccase *c, void *aux)
340 {
341   struct cmd_roc *roc = aux;
342   const struct variable *wv = dict_get_weight (roc->dict);
343   const double weight = wv ? case_data (c, wv)->f : 1.0;
344
345   const bool positive =
346   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
347     var_get_width (roc->state_var)));
348
349   if ( positive )
350     {
351       roc->pos++;
352       roc->pos_weighted += weight;
353     }
354   else
355     {
356       roc->neg++;
357       roc->neg_weighted += weight;
358     }
359
360   return positive;
361 }
362
363
364 #define VALUE  0
365 #define N_EQ   1
366 #define N_PRED 2
367
368 /* Some intermediate state for calculating the cutpoints and the 
369    standard error values */
370 struct roc_state
371 {
372   double auc;  /* Area under the curve */
373
374   double n1;  /* total weight of positives */
375   double n2;  /* total weight of negatives */
376
377   /* intermediates for standard error */
378   double q1hat; 
379   double q2hat;
380
381   /* intermediates for cutpoints */
382   struct casewriter *cutpoint_wtr;
383   struct casereader *cutpoint_rdr;
384   double prev_result;
385   double min;
386   double max;
387 };
388
389 /* 
390    Return a new casereader based upon CUTPOINT_RDR.
391    The number of "positive" cases are placed into
392    the position TRUE_INDEX, and the number of "negative" cases
393    into FALSE_INDEX.
394    POS_COND and RESULT determine the semantics of what is 
395    "positive".
396    WEIGHT is the value of a single count.
397  */
398 static struct casereader *
399 accumulate_counts (struct casereader *cutpoint_rdr, 
400                    double result, double weight, 
401                    bool (*pos_cond) (double, double),
402                    int true_index, int false_index)
403 {
404   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
405   struct casewriter *w =
406     autopaging_writer_create (proto);
407   struct casereader *r = casereader_clone (cutpoint_rdr);
408   struct ccase *cpc;
409   double prev_cp = SYSMIS;
410
411   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
412     {
413       struct ccase *new_case;
414       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
415
416       assert (cp != SYSMIS);
417
418       /* We don't want duplicates here */
419       if ( cp == prev_cp )
420         continue;
421
422       new_case = case_clone (cpc);
423
424       if ( pos_cond (result, cp))
425         case_data_rw_idx (new_case, true_index)->f += weight;
426       else
427         case_data_rw_idx (new_case, false_index)->f += weight;
428
429       prev_cp = cp;
430
431       casewriter_write (w, new_case);
432     }
433   casereader_destroy (r);
434
435   return casewriter_make_reader (w);
436 }
437
438
439
440 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
441
442 /*
443   This function does 3 things:
444
445   1. Counts the number of cases which are equal to every other case in READER,
446   and those cases for which the relationship between it and every other case
447   satifies PRED (normally either > or <).  VAR is variable defining a case's value
448   for this purpose.
449
450   2. Counts the number of true and false cases in reader, and populates
451   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
452   which receive these values.  POS_COND is the condition defining true
453   and false.
454   
455   3. CC is filled with the cumulative weight of all cases of READER.
456 */
457 static struct casereader *
458 process_group (const struct variable *var, struct casereader *reader,
459                bool (*pred) (double, double),
460                const struct dictionary *dict,
461                double *cc,
462                struct casereader **cutpoint_rdr, 
463                bool (*pos_cond) (double, double),
464                int true_index,
465                int false_index)
466 {
467   const struct variable *w = dict_get_weight (dict);
468
469   struct casereader *r1 =
470     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
471
472   const int weight_idx  = w ? var_get_case_index (w) :
473     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
474   
475   struct ccase *c1;
476
477   struct casereader *rclone = casereader_clone (r1);
478   struct casewriter *wtr;
479   struct caseproto *proto = caseproto_create ();
480
481   proto = caseproto_add_width (proto, 0);
482   proto = caseproto_add_width (proto, 0);
483   proto = caseproto_add_width (proto, 0);
484
485   wtr = autopaging_writer_create (proto);  
486
487   *cc = 0;
488
489   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
490     {
491       struct ccase *new_case = case_create (proto);
492       struct ccase *c2;
493       struct casereader *r2 = casereader_clone (rclone);
494
495       const double weight1 = case_data_idx (c1, weight_idx)->f;
496       const double d1 = case_data (c1, var)->f;
497       double n_eq = 0.0;
498       double n_pred = 0.0;
499
500       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
501                                          pos_cond,
502                                          true_index, false_index);
503
504       *cc += weight1;
505
506       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
507         {
508           const double d2 = case_data (c2, var)->f;
509           const double weight2 = case_data_idx (c2, weight_idx)->f;
510
511           if ( d1 == d2 )
512             {
513               n_eq += weight2;
514               continue;
515             }
516           else  if ( pred (d2, d1))
517             {
518               n_pred += weight2;
519             }
520         }
521
522       case_data_rw_idx (new_case, VALUE)->f = d1;
523       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
524       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
525
526       casewriter_write (wtr, new_case);
527
528       casereader_destroy (r2);
529     }
530
531   casereader_destroy (r1);
532   casereader_destroy (rclone);
533
534   return casewriter_make_reader (wtr);
535 }
536
537 /* Some more indeces into case data */
538 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
539 #define N_POS_GT 2  /* number of postive cases with values greater than n */
540 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
541 #define N_NEG_LT 4  /* number of negative cases with values less than n */
542
543 static bool
544 gt (double d1, double d2)
545 {
546   return d1 > d2;
547 }
548
549
550 static bool
551 ge (double d1, double d2)
552 {
553   return d1 > d2;
554 }
555
556 static bool
557 lt (double d1, double d2)
558 {
559   return d1 < d2;
560 }
561
562
563 /*
564   Return a casereader with width 3,
565   populated with cases based upon READER.
566   The cases will have the values:
567   (N, number of cases equal to N, number of cases greater than N)
568   As a side effect, update RS->n1 with the number of positive cases.
569 */
570 static struct casereader *
571 process_positive_group (const struct variable *var, struct casereader *reader,
572                         const struct dictionary *dict,
573                         struct roc_state *rs)
574 {
575   return process_group (var, reader, gt, dict, &rs->n1,
576                         &rs->cutpoint_rdr,
577                         ge,
578                         ROC_TP, ROC_FN);
579 }
580
581 /*
582   Return a casereader with width 3,
583   populated with cases based upon READER.
584   The cases will have the values:
585   (N, number of cases equal to N, number of cases less than N)
586   As a side effect, update RS->n2 with the number of negative cases.
587 */
588 static struct casereader *
589 process_negative_group (const struct variable *var, struct casereader *reader,
590                         const struct dictionary *dict,
591                         struct roc_state *rs)
592 {
593   return process_group (var, reader, lt, dict, &rs->n2,
594                         &rs->cutpoint_rdr,
595                         lt,
596                         ROC_TN, ROC_FP);
597 }
598
599
600
601
602 static void
603 append_cutpoint (struct casewriter *writer, double cutpoint)
604 {
605   struct ccase *cc = case_create (casewriter_get_proto (writer));
606
607   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
608   case_data_rw_idx (cc, ROC_TP)->f = 0;
609   case_data_rw_idx (cc, ROC_FN)->f = 0;
610   case_data_rw_idx (cc, ROC_TN)->f = 0;
611   case_data_rw_idx (cc, ROC_FP)->f = 0;
612
613   casewriter_write (writer, cc);
614 }
615
616
617 /* 
618    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
619    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
620    reader will be populated with its final number of cases.
621    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
622    value.  The other entries will be initialised to zero.
623 */
624 static void
625 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
626 {
627   int i;
628   struct casereader *r = casereader_clone (input);
629   struct ccase *c;
630   struct caseproto *proto = caseproto_create ();
631
632   struct subcase ordering;
633   subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
634
635   proto = caseproto_add_width (proto, 0); /* cutpoint */
636   proto = caseproto_add_width (proto, 0); /* ROC_TP */
637   proto = caseproto_add_width (proto, 0); /* ROC_FN */
638   proto = caseproto_add_width (proto, 0); /* ROC_TN */
639   proto = caseproto_add_width (proto, 0); /* ROC_FP */
640
641   for (i = 0 ; i < roc->n_vars; ++i)
642     {
643       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
644       rs[i].prev_result = SYSMIS;
645       rs[i].max = -DBL_MAX;
646       rs[i].min = DBL_MAX;
647     }
648
649   for (; (c = casereader_read (r)) != NULL; case_unref (c))
650     {
651       for (i = 0 ; i < roc->n_vars; ++i)
652         {
653           const union value *v = case_data (c, roc->vars[i]); 
654           const double result = v->f;
655
656           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
657             continue;
658
659           minimize (&rs[i].min, result);
660           maximize (&rs[i].max, result);
661
662           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
663             {
664               const double mean = (result + rs[i].prev_result ) / 2.0;
665               append_cutpoint (rs[i].cutpoint_wtr, mean);
666             }
667
668           rs[i].prev_result = result;
669         }
670     }
671   casereader_destroy (r);
672
673
674   /* Append the min and max cutpoints */
675   for (i = 0 ; i < roc->n_vars; ++i)
676     {
677       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
678       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
679
680       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
681     }
682 }
683
684 static void
685 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
686 {
687   int i;
688
689   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
690
691   struct casereader *negatives = NULL;
692   struct casereader *positives = NULL;
693
694   struct caseproto *n_proto = caseproto_create ();
695
696   struct subcase up_ordering;
697   struct subcase down_ordering;
698
699   struct casewriter *neg_wtr = NULL;
700
701   struct casereader *input = casereader_create_filter_missing (reader,
702                                                                roc->vars, roc->n_vars,
703                                                                roc->exclude,
704                                                                NULL,
705                                                                NULL);
706
707   input = casereader_create_filter_missing (input,
708                                             &roc->state_var, 1,
709                                             roc->exclude,
710                                             NULL,
711                                             NULL);
712
713   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
714
715   prepare_cutpoints (roc, rs, input);
716
717
718   /* Separate the positive actual state cases from the negative ones */
719   positives = 
720     casereader_create_filter_func (input,
721                                    match_positives,
722                                    NULL,
723                                    roc,
724                                    neg_wtr);
725
726   n_proto = caseproto_create ();
727       
728   n_proto = caseproto_add_width (n_proto, 0);
729   n_proto = caseproto_add_width (n_proto, 0);
730   n_proto = caseproto_add_width (n_proto, 0);
731   n_proto = caseproto_add_width (n_proto, 0);
732   n_proto = caseproto_add_width (n_proto, 0);
733
734   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
735   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
736
737   for (i = 0 ; i < roc->n_vars; ++i)
738     {
739       struct casewriter *w = NULL;
740       struct casereader *r = NULL;
741
742       struct ccase *c;
743
744       struct ccase *cpos;
745       struct casereader *n_neg ;
746       const struct variable *var = roc->vars[i];
747
748       struct casereader *neg ;
749       struct casereader *pos = casereader_clone (positives);
750
751
752       struct casereader *n_pos =
753         process_positive_group (var, pos, dict, &rs[i]);
754
755       if ( negatives == NULL)
756         {
757           negatives = casewriter_make_reader (neg_wtr);
758         }
759
760       neg = casereader_clone (negatives);
761
762       n_neg = process_negative_group (var, neg, dict, &rs[i]);
763
764
765       /* Merge the n_pos and n_neg casereaders */
766       w = sort_create_writer (&up_ordering, n_proto);
767       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
768         {
769           struct ccase *pos_case = case_create (n_proto);
770           struct ccase *cneg;
771           const double jpos = case_data_idx (cpos, VALUE)->f;
772
773           while ((cneg = casereader_read (n_neg)))
774             {
775               struct ccase *nc = case_create (n_proto);
776
777               const double jneg = case_data_idx (cneg, VALUE)->f;
778
779               case_data_rw_idx (nc, VALUE)->f = jneg;
780               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
781
782               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
783
784               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
785               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
786
787               casewriter_write (w, nc);
788
789               case_unref (cneg);
790               if ( jneg > jpos)
791                 break;
792             }
793
794           case_data_rw_idx (pos_case, VALUE)->f = jpos;
795           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
796           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
797           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
798           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
799
800           casewriter_write (w, pos_case);
801         }
802
803 /* These aren't used anymore */
804 #undef N_EQ
805 #undef N_PRED
806
807       r = casewriter_make_reader (w);
808
809       /* Propagate the N_POS_GT values from the positive cases
810          to the negative ones */
811       {
812         double prev_pos_gt = rs[i].n1;
813         w = sort_create_writer (&down_ordering, n_proto);
814
815         for ( ; (c = casereader_read (r) ); case_unref (c))
816           {
817             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
818             struct ccase *nc = case_clone (c);
819
820             if ( n_pos_gt == SYSMIS)
821               {
822                 n_pos_gt = prev_pos_gt;
823                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
824               }
825             
826             casewriter_write (w, nc);
827             prev_pos_gt = n_pos_gt;
828           }
829
830         r = casewriter_make_reader (w);
831       }
832
833       /* Propagate the N_NEG_LT values from the negative cases
834          to the positive ones */
835       {
836         double prev_neg_lt = rs[i].n2;
837         w = sort_create_writer (&up_ordering, n_proto);
838
839         for ( ; (c = casereader_read (r) ); case_unref (c))
840           {
841             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
842             struct ccase *nc = case_clone (c);
843
844             if ( n_neg_lt == SYSMIS)
845               {
846                 n_neg_lt = prev_neg_lt;
847                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
848               }
849             
850             casewriter_write (w, nc);
851             prev_neg_lt = n_neg_lt;
852           }
853
854         r = casewriter_make_reader (w);
855       }
856
857       {
858         struct ccase *prev_case = NULL;
859         for ( ; (c = casereader_read (r) ); case_unref (c))
860           {
861             const struct ccase *next_case = casereader_peek (r, 0);
862
863             const double j = case_data_idx (c, VALUE)->f;
864             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
865             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
866             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
867             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
868
869             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
870               {
871                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
872                   {
873                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
874                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
875                   }
876
877                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
878                   {
879                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
880                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
881                   }
882               }
883
884             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
885               {
886                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
887
888                 rs[i].q1hat +=
889                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
890                 rs[i].q2hat +=
891                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
892
893               }
894
895             case_unref (prev_case);
896             prev_case = case_clone (c);
897           }
898
899         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
900         if ( roc->invert ) 
901           rs[i].auc = 1 - rs[i].auc;
902
903         if ( roc->bi_neg_exp )
904           {
905             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
906             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
907           }
908         else
909           {
910             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
911             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
912           }
913       }
914     }
915
916   casereader_destroy (positives);
917   casereader_destroy (negatives);
918
919   output_roc (rs, roc);
920
921   free (rs);
922 }
923
924 static void
925 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
926 {
927   int i;
928   const int n_fields = roc->print_se ? 5 : 1;
929   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
930   const int n_rows = 2 + roc->n_vars;
931   struct tab_table *tbl = tab_create (n_cols, n_rows);
932
933   if ( roc->n_vars > 1)
934     tab_title (tbl, _("Area Under the Curve"));
935   else
936     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
937
938   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
939
940
941   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
942
943   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
944
945   tab_box (tbl,
946            TAL_2, TAL_2,
947            -1, TAL_1,
948            0, 0,
949            n_cols - 1,
950            n_rows - 1);
951
952   if ( roc->print_se )
953     {
954       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
955       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
956
957       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
958       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
959
960       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
961                              TAT_TITLE | TAB_CENTER,
962                              _("Asymp. %g%% Confidence Interval"), roc->ci);
963       tab_vline (tbl, 0, n_cols - 1, 0, 0);
964       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
965     }
966
967   if ( roc->n_vars > 1)
968     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
969
970   if ( roc->n_vars > 1)
971     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
972
973
974   for ( i = 0 ; i < roc->n_vars ; ++i )
975     {
976       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
977
978       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
979
980       if ( roc->print_se )
981         {
982           double se ;
983           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
984                                       (12 * rs[i].n1 * rs[i].n2));
985           double ci ;
986           double yy ;
987
988           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
989             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
990
991           se /= rs[i].n1 * rs[i].n2;
992
993           se = sqrt (se);
994
995           tab_double (tbl, n_cols - 4, 2 + i, 0,
996                       se,
997                       NULL);
998
999           ci = 1 - roc->ci / 100.0;
1000           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1001
1002           tab_double (tbl, n_cols - 2, 2 + i, 0,
1003                       rs[i].auc - yy,
1004                       NULL);
1005
1006           tab_double (tbl, n_cols - 1, 2 + i, 0,
1007                       rs[i].auc + yy,
1008                       NULL);
1009
1010           tab_double (tbl, n_cols - 3, 2 + i, 0,
1011                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1012                       NULL);
1013         }
1014     }
1015
1016   tab_submit (tbl);
1017 }
1018
1019
1020 static void
1021 show_summary (const struct cmd_roc *roc)
1022 {
1023   const int n_cols = 3;
1024   const int n_rows = 4;
1025   struct tab_table *tbl = tab_create (n_cols, n_rows);
1026
1027   tab_title (tbl, _("Case Summary"));
1028
1029   tab_headers (tbl, 1, 0, 2, 0);
1030
1031   tab_box (tbl,
1032            TAL_2, TAL_2,
1033            -1, -1,
1034            0, 0,
1035            n_cols - 1,
1036            n_rows - 1);
1037
1038   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1039   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1040
1041
1042   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1043   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1044
1045
1046   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1047   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1048   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1049
1050   tab_joint_text (tbl, 1, 0, 2, 0,
1051                   TAT_TITLE | TAB_CENTER,
1052                   _("Valid N (listwise)"));
1053
1054
1055   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1056   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1057
1058
1059   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1060   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1061
1062   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1063   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1064
1065   tab_submit (tbl);
1066 }
1067
1068
1069 static void
1070 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1071 {
1072   int x = 1;
1073   int i;
1074   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1075   int n_rows = 1;
1076   struct tab_table *tbl ;
1077
1078   for (i = 0; i < roc->n_vars; ++i)
1079     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1080
1081   tbl = tab_create (n_cols, n_rows);
1082
1083   if ( roc->n_vars > 1)
1084     tab_title (tbl, _("Coordinates of the Curve"));
1085   else
1086     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1087
1088
1089   tab_headers (tbl, 1, 0, 1, 0);
1090
1091   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1092
1093   if ( roc->n_vars > 1)
1094     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1095
1096   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1097   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1098   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1099
1100   tab_box (tbl,
1101            TAL_2, TAL_2,
1102            -1, TAL_1,
1103            0, 0,
1104            n_cols - 1,
1105            n_rows - 1);
1106
1107   if ( roc->n_vars > 1)
1108     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1109
1110   for (i = 0; i < roc->n_vars; ++i)
1111     {
1112       struct ccase *cc;
1113       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1114
1115       if ( roc->n_vars > 1)
1116         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1117
1118       if ( i > 0)
1119         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1120
1121
1122       for (; (cc = casereader_read (r)) != NULL;
1123            case_unref (cc), x++)
1124         {
1125           const double se = case_data_idx (cc, ROC_TP)->f /
1126             (
1127              case_data_idx (cc, ROC_TP)->f
1128              +
1129              case_data_idx (cc, ROC_FN)->f
1130              );
1131
1132           const double sp = case_data_idx (cc, ROC_TN)->f /
1133             (
1134              case_data_idx (cc, ROC_TN)->f
1135              +
1136              case_data_idx (cc, ROC_FP)->f
1137              );
1138
1139           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1140                       var_get_print_format (roc->vars[i]));
1141
1142           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1143           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1144         }
1145
1146       casereader_destroy (r);
1147     }
1148
1149   tab_submit (tbl);
1150 }
1151
1152
1153 static void
1154 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1155 {
1156   show_summary (roc);
1157
1158   if ( roc->curve )
1159     {
1160       struct roc_chart *rc;
1161       size_t i;
1162
1163       rc = roc_chart_create (roc->reference);
1164       for (i = 0; i < roc->n_vars; i++)
1165         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1166                            rs[i].cutpoint_rdr);
1167       roc_chart_submit (rc);
1168     }
1169
1170   show_auc (rs, roc);
1171
1172   if ( roc->print_coords )
1173     show_coords (rs, roc);
1174 }
1175