Be more forgiving about ROC syntax
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <language/stats/roc.h>
20
21 #include <data/casegrouper.h>
22 #include <data/casereader.h>
23 #include <data/casewriter.h>
24 #include <data/dictionary.h>
25 #include <data/format.h>
26 #include <data/procedure.h>
27 #include <data/subcase.h>
28 #include <language/command.h>
29 #include <language/lexer/lexer.h>
30 #include <language/lexer/value-parser.h>
31 #include <language/lexer/variable-parser.h>
32 #include <libpspp/misc.h>
33 #include <math/sort.h>
34 #include <output/chart-item.h>
35 #include <output/charts/roc-chart.h>
36 #include <output/tab.h>
37
38 #include <gsl/gsl_cdf.h>
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert ; /* True iff a smaller test result variable indicates
67                    a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
76
77 int
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
79 {
80   struct cmd_roc roc ;
81   const struct dictionary *dict = dataset_dict (ds);
82
83   roc.vars = NULL;
84   roc.n_vars = 0;
85   roc.print_se = false;
86   roc.print_coords = false;
87   roc.exclude = MV_ANY;
88   roc.curve = true;
89   roc.reference = false;
90   roc.ci = 95;
91   roc.bi_neg_exp = false;
92   roc.invert = false;
93   roc.pos = roc.pos_weighted = 0;
94   roc.neg = roc.neg_weighted = 0;
95   roc.dict = dataset_dict (ds);
96   roc.state_var = NULL;
97
98   lex_match (lexer, '/');
99   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
101     goto error;
102
103   if ( ! lex_force_match (lexer, T_BY))
104     {
105       goto error;
106     }
107
108   roc.state_var = parse_variable (lexer, dict);
109
110   if ( !lex_force_match (lexer, '('))
111     {
112       goto error;
113     }
114
115   value_init (&roc.state_value, var_get_width (roc.state_var));
116   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
117
118
119   if ( !lex_force_match (lexer, ')'))
120     {
121       goto error;
122     }
123
124
125   while (lex_token (lexer) != '.')
126     {
127       lex_match (lexer, '/');
128       if (lex_match_id (lexer, "MISSING"))
129         {
130           lex_match (lexer, '=');
131           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
132             {
133               if (lex_match_id (lexer, "INCLUDE"))
134                 {
135                   roc.exclude = MV_SYSTEM;
136                 }
137               else if (lex_match_id (lexer, "EXCLUDE"))
138                 {
139                   roc.exclude = MV_ANY;
140                 }
141               else
142                 {
143                   lex_error (lexer, NULL);
144                   goto error;
145                 }
146             }
147         }
148       else if (lex_match_id (lexer, "PLOT"))
149         {
150           lex_match (lexer, '=');
151           if (lex_match_id (lexer, "CURVE"))
152             {
153               roc.curve = true;
154               if (lex_match (lexer, '('))
155                 {
156                   roc.reference = true;
157                   lex_force_match_id (lexer, "REFERENCE");
158                   lex_force_match (lexer, ')');
159                 }
160             }
161           else if (lex_match_id (lexer, "NONE"))
162             {
163               roc.curve = false;
164             }
165           else
166             {
167               lex_error (lexer, NULL);
168               goto error;
169             }
170         }
171       else if (lex_match_id (lexer, "PRINT"))
172         {
173           lex_match (lexer, '=');
174           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
175             {
176               if (lex_match_id (lexer, "SE"))
177                 {
178                   roc.print_se = true;
179                 }
180               else if (lex_match_id (lexer, "COORDINATES"))
181                 {
182                   roc.print_coords = true;
183                 }
184               else
185                 {
186                   lex_error (lexer, NULL);
187                   goto error;
188                 }
189             }
190         }
191       else if (lex_match_id (lexer, "CRITERIA"))
192         {
193           lex_match (lexer, '=');
194           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
195             {
196               if (lex_match_id (lexer, "CUTOFF"))
197                 {
198                   lex_force_match (lexer, '(');
199                   if (lex_match_id (lexer, "INCLUDE"))
200                     {
201                       roc.exclude = MV_SYSTEM;
202                     }
203                   else if (lex_match_id (lexer, "EXCLUDE"))
204                     {
205                       roc.exclude = MV_USER | MV_SYSTEM;
206                     }
207                   else
208                     {
209                       lex_error (lexer, NULL);
210                       goto error;
211                     }
212                   lex_force_match (lexer, ')');
213                 }
214               else if (lex_match_id (lexer, "TESTPOS"))
215                 {
216                   lex_force_match (lexer, '(');
217                   if (lex_match_id (lexer, "LARGE"))
218                     {
219                       roc.invert = false;
220                     }
221                   else if (lex_match_id (lexer, "SMALL"))
222                     {
223                       roc.invert = true;
224                     }
225                   else
226                     {
227                       lex_error (lexer, NULL);
228                       goto error;
229                     }
230                   lex_force_match (lexer, ')');
231                 }
232               else if (lex_match_id (lexer, "CI"))
233                 {
234                   lex_force_match (lexer, '(');
235                   lex_force_num (lexer);
236                   roc.ci = lex_number (lexer);
237                   lex_get (lexer);
238                   lex_force_match (lexer, ')');
239                 }
240               else if (lex_match_id (lexer, "DISTRIBUTION"))
241                 {
242                   lex_force_match (lexer, '(');
243                   if (lex_match_id (lexer, "FREE"))
244                     {
245                       roc.bi_neg_exp = false;
246                     }
247                   else if (lex_match_id (lexer, "NEGEXPO"))
248                     {
249                       roc.bi_neg_exp = true;
250                     }
251                   else
252                     {
253                       lex_error (lexer, NULL);
254                       goto error;
255                     }
256                   lex_force_match (lexer, ')');
257                 }
258               else
259                 {
260                   lex_error (lexer, NULL);
261                   goto error;
262                 }
263             }
264         }
265       else
266         {
267           lex_error (lexer, NULL);
268           break;
269         }
270     }
271
272   if ( ! run_roc (ds, &roc)) 
273     goto error;
274
275   value_destroy (&roc.state_value, var_get_width (roc.state_var));
276   free (roc.vars);
277   return CMD_SUCCESS;
278
279  error:
280   if ( roc.state_var)
281     value_destroy (&roc.state_value, var_get_width (roc.state_var));
282   free (roc.vars);
283   return CMD_FAILURE;
284 }
285
286
287
288
289 static void
290 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
291
292
293 static int
294 run_roc (struct dataset *ds, struct cmd_roc *roc)
295 {
296   struct dictionary *dict = dataset_dict (ds);
297   bool ok;
298   struct casereader *group;
299
300   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
301   while (casegrouper_get_next_group (grouper, &group))
302     {
303       do_roc (roc, group, dataset_dict (ds));
304     }
305   ok = casegrouper_destroy (grouper);
306   ok = proc_commit (ds) && ok;
307
308   return ok;
309 }
310
311 #if 0
312 static void
313 dump_casereader (struct casereader *reader)
314 {
315   struct ccase *c;
316   struct casereader *r = casereader_clone (reader);
317
318   for ( ; (c = casereader_read (r) ); case_unref (c))
319     {
320       int i;
321       for (i = 0 ; i < case_get_value_cnt (c); ++i)
322         {
323           printf ("%g ", case_data_idx (c, i)->f);
324         }
325       printf ("\n");
326     }
327
328   casereader_destroy (r);
329 }
330 #endif
331
332
333 /* 
334    Return true iff the state variable indicates that C has positive actual state.
335
336    As a side effect, this function also accumulates the roc->{pos,neg} and 
337    roc->{pos,neg}_weighted counts.
338  */
339 static bool
340 match_positives (const struct ccase *c, void *aux)
341 {
342   struct cmd_roc *roc = aux;
343   const struct variable *wv = dict_get_weight (roc->dict);
344   const double weight = wv ? case_data (c, wv)->f : 1.0;
345
346   const bool positive =
347   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
348     var_get_width (roc->state_var)));
349
350   if ( positive )
351     {
352       roc->pos++;
353       roc->pos_weighted += weight;
354     }
355   else
356     {
357       roc->neg++;
358       roc->neg_weighted += weight;
359     }
360
361   return positive;
362 }
363
364
365 #define VALUE  0
366 #define N_EQ   1
367 #define N_PRED 2
368
369 /* Some intermediate state for calculating the cutpoints and the 
370    standard error values */
371 struct roc_state
372 {
373   double auc;  /* Area under the curve */
374
375   double n1;  /* total weight of positives */
376   double n2;  /* total weight of negatives */
377
378   /* intermediates for standard error */
379   double q1hat; 
380   double q2hat;
381
382   /* intermediates for cutpoints */
383   struct casewriter *cutpoint_wtr;
384   struct casereader *cutpoint_rdr;
385   double prev_result;
386   double min;
387   double max;
388 };
389
390 /* 
391    Return a new casereader based upon CUTPOINT_RDR.
392    The number of "positive" cases are placed into
393    the position TRUE_INDEX, and the number of "negative" cases
394    into FALSE_INDEX.
395    POS_COND and RESULT determine the semantics of what is 
396    "positive".
397    WEIGHT is the value of a single count.
398  */
399 static struct casereader *
400 accumulate_counts (struct casereader *cutpoint_rdr, 
401                    double result, double weight, 
402                    bool (*pos_cond) (double, double),
403                    int true_index, int false_index)
404 {
405   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
406   struct casewriter *w =
407     autopaging_writer_create (proto);
408   struct casereader *r = casereader_clone (cutpoint_rdr);
409   struct ccase *cpc;
410   double prev_cp = SYSMIS;
411
412   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
413     {
414       struct ccase *new_case;
415       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
416
417       assert (cp != SYSMIS);
418
419       /* We don't want duplicates here */
420       if ( cp == prev_cp )
421         continue;
422
423       new_case = case_clone (cpc);
424
425       if ( pos_cond (result, cp))
426         case_data_rw_idx (new_case, true_index)->f += weight;
427       else
428         case_data_rw_idx (new_case, false_index)->f += weight;
429
430       prev_cp = cp;
431
432       casewriter_write (w, new_case);
433     }
434   casereader_destroy (r);
435
436   return casewriter_make_reader (w);
437 }
438
439
440
441 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
442
443 /*
444   This function does 3 things:
445
446   1. Counts the number of cases which are equal to every other case in READER,
447   and those cases for which the relationship between it and every other case
448   satifies PRED (normally either > or <).  VAR is variable defining a case's value
449   for this purpose.
450
451   2. Counts the number of true and false cases in reader, and populates
452   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
453   which receive these values.  POS_COND is the condition defining true
454   and false.
455   
456   3. CC is filled with the cumulative weight of all cases of READER.
457 */
458 static struct casereader *
459 process_group (const struct variable *var, struct casereader *reader,
460                bool (*pred) (double, double),
461                const struct dictionary *dict,
462                double *cc,
463                struct casereader **cutpoint_rdr, 
464                bool (*pos_cond) (double, double),
465                int true_index,
466                int false_index)
467 {
468   const struct variable *w = dict_get_weight (dict);
469
470   struct casereader *r1 =
471     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
472
473   const int weight_idx  = w ? var_get_case_index (w) :
474     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
475   
476   struct ccase *c1;
477
478   struct casereader *rclone = casereader_clone (r1);
479   struct casewriter *wtr;
480   struct caseproto *proto = caseproto_create ();
481
482   proto = caseproto_add_width (proto, 0);
483   proto = caseproto_add_width (proto, 0);
484   proto = caseproto_add_width (proto, 0);
485
486   wtr = autopaging_writer_create (proto);  
487
488   *cc = 0;
489
490   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
491     {
492       struct ccase *new_case = case_create (proto);
493       struct ccase *c2;
494       struct casereader *r2 = casereader_clone (rclone);
495
496       const double weight1 = case_data_idx (c1, weight_idx)->f;
497       const double d1 = case_data (c1, var)->f;
498       double n_eq = 0.0;
499       double n_pred = 0.0;
500
501       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
502                                          pos_cond,
503                                          true_index, false_index);
504
505       *cc += weight1;
506
507       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
508         {
509           const double d2 = case_data (c2, var)->f;
510           const double weight2 = case_data_idx (c2, weight_idx)->f;
511
512           if ( d1 == d2 )
513             {
514               n_eq += weight2;
515               continue;
516             }
517           else  if ( pred (d2, d1))
518             {
519               n_pred += weight2;
520             }
521         }
522
523       case_data_rw_idx (new_case, VALUE)->f = d1;
524       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
525       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
526
527       casewriter_write (wtr, new_case);
528
529       casereader_destroy (r2);
530     }
531
532   casereader_destroy (r1);
533   casereader_destroy (rclone);
534
535   return casewriter_make_reader (wtr);
536 }
537
538 /* Some more indeces into case data */
539 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
540 #define N_POS_GT 2  /* number of postive cases with values greater than n */
541 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
542 #define N_NEG_LT 4  /* number of negative cases with values less than n */
543
544 static bool
545 gt (double d1, double d2)
546 {
547   return d1 > d2;
548 }
549
550
551 static bool
552 ge (double d1, double d2)
553 {
554   return d1 > d2;
555 }
556
557 static bool
558 lt (double d1, double d2)
559 {
560   return d1 < d2;
561 }
562
563
564 /*
565   Return a casereader with width 3,
566   populated with cases based upon READER.
567   The cases will have the values:
568   (N, number of cases equal to N, number of cases greater than N)
569   As a side effect, update RS->n1 with the number of positive cases.
570 */
571 static struct casereader *
572 process_positive_group (const struct variable *var, struct casereader *reader,
573                         const struct dictionary *dict,
574                         struct roc_state *rs)
575 {
576   return process_group (var, reader, gt, dict, &rs->n1,
577                         &rs->cutpoint_rdr,
578                         ge,
579                         ROC_TP, ROC_FN);
580 }
581
582 /*
583   Return a casereader with width 3,
584   populated with cases based upon READER.
585   The cases will have the values:
586   (N, number of cases equal to N, number of cases less than N)
587   As a side effect, update RS->n2 with the number of negative cases.
588 */
589 static struct casereader *
590 process_negative_group (const struct variable *var, struct casereader *reader,
591                         const struct dictionary *dict,
592                         struct roc_state *rs)
593 {
594   return process_group (var, reader, lt, dict, &rs->n2,
595                         &rs->cutpoint_rdr,
596                         lt,
597                         ROC_TN, ROC_FP);
598 }
599
600
601
602
603 static void
604 append_cutpoint (struct casewriter *writer, double cutpoint)
605 {
606   struct ccase *cc = case_create (casewriter_get_proto (writer));
607
608   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
609   case_data_rw_idx (cc, ROC_TP)->f = 0;
610   case_data_rw_idx (cc, ROC_FN)->f = 0;
611   case_data_rw_idx (cc, ROC_TN)->f = 0;
612   case_data_rw_idx (cc, ROC_FP)->f = 0;
613
614   casewriter_write (writer, cc);
615 }
616
617
618 /* 
619    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
620    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
621    reader will be populated with its final number of cases.
622    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
623    value.  The other entries will be initialised to zero.
624 */
625 static void
626 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
627 {
628   int i;
629   struct casereader *r = casereader_clone (input);
630   struct ccase *c;
631   struct caseproto *proto = caseproto_create ();
632
633   struct subcase ordering;
634   subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
635
636   proto = caseproto_add_width (proto, 0); /* cutpoint */
637   proto = caseproto_add_width (proto, 0); /* ROC_TP */
638   proto = caseproto_add_width (proto, 0); /* ROC_FN */
639   proto = caseproto_add_width (proto, 0); /* ROC_TN */
640   proto = caseproto_add_width (proto, 0); /* ROC_FP */
641
642   for (i = 0 ; i < roc->n_vars; ++i)
643     {
644       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
645       rs[i].prev_result = SYSMIS;
646       rs[i].max = -DBL_MAX;
647       rs[i].min = DBL_MAX;
648     }
649
650   for (; (c = casereader_read (r)) != NULL; case_unref (c))
651     {
652       for (i = 0 ; i < roc->n_vars; ++i)
653         {
654           const union value *v = case_data (c, roc->vars[i]); 
655           const double result = v->f;
656
657           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
658             continue;
659
660           minimize (&rs[i].min, result);
661           maximize (&rs[i].max, result);
662
663           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
664             {
665               const double mean = (result + rs[i].prev_result ) / 2.0;
666               append_cutpoint (rs[i].cutpoint_wtr, mean);
667             }
668
669           rs[i].prev_result = result;
670         }
671     }
672   casereader_destroy (r);
673
674
675   /* Append the min and max cutpoints */
676   for (i = 0 ; i < roc->n_vars; ++i)
677     {
678       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
679       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
680
681       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
682     }
683 }
684
685 static void
686 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
687 {
688   int i;
689
690   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
691
692   struct casereader *negatives = NULL;
693   struct casereader *positives = NULL;
694
695   struct caseproto *n_proto = caseproto_create ();
696
697   struct subcase up_ordering;
698   struct subcase down_ordering;
699
700   struct casewriter *neg_wtr = NULL;
701
702   struct casereader *input = casereader_create_filter_missing (reader,
703                                                                roc->vars, roc->n_vars,
704                                                                roc->exclude,
705                                                                NULL,
706                                                                NULL);
707
708   input = casereader_create_filter_missing (input,
709                                             &roc->state_var, 1,
710                                             roc->exclude,
711                                             NULL,
712                                             NULL);
713
714   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
715
716   prepare_cutpoints (roc, rs, input);
717
718
719   /* Separate the positive actual state cases from the negative ones */
720   positives = 
721     casereader_create_filter_func (input,
722                                    match_positives,
723                                    NULL,
724                                    roc,
725                                    neg_wtr);
726
727   n_proto = caseproto_create ();
728       
729   n_proto = caseproto_add_width (n_proto, 0);
730   n_proto = caseproto_add_width (n_proto, 0);
731   n_proto = caseproto_add_width (n_proto, 0);
732   n_proto = caseproto_add_width (n_proto, 0);
733   n_proto = caseproto_add_width (n_proto, 0);
734
735   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
736   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
737
738   for (i = 0 ; i < roc->n_vars; ++i)
739     {
740       struct casewriter *w = NULL;
741       struct casereader *r = NULL;
742
743       struct ccase *c;
744
745       struct ccase *cpos;
746       struct casereader *n_neg ;
747       const struct variable *var = roc->vars[i];
748
749       struct casereader *neg ;
750       struct casereader *pos = casereader_clone (positives);
751
752
753       struct casereader *n_pos =
754         process_positive_group (var, pos, dict, &rs[i]);
755
756       if ( negatives == NULL)
757         {
758           negatives = casewriter_make_reader (neg_wtr);
759         }
760
761       neg = casereader_clone (negatives);
762
763       n_neg = process_negative_group (var, neg, dict, &rs[i]);
764
765
766       /* Merge the n_pos and n_neg casereaders */
767       w = sort_create_writer (&up_ordering, n_proto);
768       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
769         {
770           struct ccase *pos_case = case_create (n_proto);
771           struct ccase *cneg;
772           const double jpos = case_data_idx (cpos, VALUE)->f;
773
774           while ((cneg = casereader_read (n_neg)))
775             {
776               struct ccase *nc = case_create (n_proto);
777
778               const double jneg = case_data_idx (cneg, VALUE)->f;
779
780               case_data_rw_idx (nc, VALUE)->f = jneg;
781               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
782
783               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
784
785               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
786               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
787
788               casewriter_write (w, nc);
789
790               case_unref (cneg);
791               if ( jneg > jpos)
792                 break;
793             }
794
795           case_data_rw_idx (pos_case, VALUE)->f = jpos;
796           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
797           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
798           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
799           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
800
801           casewriter_write (w, pos_case);
802         }
803
804 /* These aren't used anymore */
805 #undef N_EQ
806 #undef N_PRED
807
808       r = casewriter_make_reader (w);
809
810       /* Propagate the N_POS_GT values from the positive cases
811          to the negative ones */
812       {
813         double prev_pos_gt = rs[i].n1;
814         w = sort_create_writer (&down_ordering, n_proto);
815
816         for ( ; (c = casereader_read (r) ); case_unref (c))
817           {
818             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
819             struct ccase *nc = case_clone (c);
820
821             if ( n_pos_gt == SYSMIS)
822               {
823                 n_pos_gt = prev_pos_gt;
824                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
825               }
826             
827             casewriter_write (w, nc);
828             prev_pos_gt = n_pos_gt;
829           }
830
831         r = casewriter_make_reader (w);
832       }
833
834       /* Propagate the N_NEG_LT values from the negative cases
835          to the positive ones */
836       {
837         double prev_neg_lt = rs[i].n2;
838         w = sort_create_writer (&up_ordering, n_proto);
839
840         for ( ; (c = casereader_read (r) ); case_unref (c))
841           {
842             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
843             struct ccase *nc = case_clone (c);
844
845             if ( n_neg_lt == SYSMIS)
846               {
847                 n_neg_lt = prev_neg_lt;
848                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
849               }
850             
851             casewriter_write (w, nc);
852             prev_neg_lt = n_neg_lt;
853           }
854
855         r = casewriter_make_reader (w);
856       }
857
858       {
859         struct ccase *prev_case = NULL;
860         for ( ; (c = casereader_read (r) ); case_unref (c))
861           {
862             const struct ccase *next_case = casereader_peek (r, 0);
863
864             const double j = case_data_idx (c, VALUE)->f;
865             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
866             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
867             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
868             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
869
870             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
871               {
872                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
873                   {
874                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
875                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
876                   }
877
878                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
879                   {
880                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
881                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
882                   }
883               }
884
885             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
886               {
887                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
888
889                 rs[i].q1hat +=
890                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
891                 rs[i].q2hat +=
892                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
893
894               }
895
896             case_unref (prev_case);
897             prev_case = case_clone (c);
898           }
899
900         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
901         if ( roc->invert ) 
902           rs[i].auc = 1 - rs[i].auc;
903
904         if ( roc->bi_neg_exp )
905           {
906             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
907             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
908           }
909         else
910           {
911             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
912             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
913           }
914       }
915     }
916
917   casereader_destroy (positives);
918   casereader_destroy (negatives);
919
920   output_roc (rs, roc);
921
922   free (rs);
923 }
924
925 static void
926 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
927 {
928   int i;
929   const int n_fields = roc->print_se ? 5 : 1;
930   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
931   const int n_rows = 2 + roc->n_vars;
932   struct tab_table *tbl = tab_create (n_cols, n_rows);
933
934   if ( roc->n_vars > 1)
935     tab_title (tbl, _("Area Under the Curve"));
936   else
937     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
938
939   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
940
941
942   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
943
944   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
945
946   tab_box (tbl,
947            TAL_2, TAL_2,
948            -1, TAL_1,
949            0, 0,
950            n_cols - 1,
951            n_rows - 1);
952
953   if ( roc->print_se )
954     {
955       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
956       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
957
958       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
959       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
960
961       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
962                              TAT_TITLE | TAB_CENTER,
963                              _("Asymp. %g%% Confidence Interval"), roc->ci);
964       tab_vline (tbl, 0, n_cols - 1, 0, 0);
965       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
966     }
967
968   if ( roc->n_vars > 1)
969     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
970
971   if ( roc->n_vars > 1)
972     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
973
974
975   for ( i = 0 ; i < roc->n_vars ; ++i )
976     {
977       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
978
979       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
980
981       if ( roc->print_se )
982         {
983           double se ;
984           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
985                                       (12 * rs[i].n1 * rs[i].n2));
986           double ci ;
987           double yy ;
988
989           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
990             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
991
992           se /= rs[i].n1 * rs[i].n2;
993
994           se = sqrt (se);
995
996           tab_double (tbl, n_cols - 4, 2 + i, 0,
997                       se,
998                       NULL);
999
1000           ci = 1 - roc->ci / 100.0;
1001           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1002
1003           tab_double (tbl, n_cols - 2, 2 + i, 0,
1004                       rs[i].auc - yy,
1005                       NULL);
1006
1007           tab_double (tbl, n_cols - 1, 2 + i, 0,
1008                       rs[i].auc + yy,
1009                       NULL);
1010
1011           tab_double (tbl, n_cols - 3, 2 + i, 0,
1012                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1013                       NULL);
1014         }
1015     }
1016
1017   tab_submit (tbl);
1018 }
1019
1020
1021 static void
1022 show_summary (const struct cmd_roc *roc)
1023 {
1024   const int n_cols = 3;
1025   const int n_rows = 4;
1026   struct tab_table *tbl = tab_create (n_cols, n_rows);
1027
1028   tab_title (tbl, _("Case Summary"));
1029
1030   tab_headers (tbl, 1, 0, 2, 0);
1031
1032   tab_box (tbl,
1033            TAL_2, TAL_2,
1034            -1, -1,
1035            0, 0,
1036            n_cols - 1,
1037            n_rows - 1);
1038
1039   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1040   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1041
1042
1043   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1044   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1045
1046
1047   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1048   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1049   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1050
1051   tab_joint_text (tbl, 1, 0, 2, 0,
1052                   TAT_TITLE | TAB_CENTER,
1053                   _("Valid N (listwise)"));
1054
1055
1056   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1057   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1058
1059
1060   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1061   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1062
1063   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1064   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1065
1066   tab_submit (tbl);
1067 }
1068
1069
1070 static void
1071 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1072 {
1073   int x = 1;
1074   int i;
1075   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1076   int n_rows = 1;
1077   struct tab_table *tbl ;
1078
1079   for (i = 0; i < roc->n_vars; ++i)
1080     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1081
1082   tbl = tab_create (n_cols, n_rows);
1083
1084   if ( roc->n_vars > 1)
1085     tab_title (tbl, _("Coordinates of the Curve"));
1086   else
1087     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1088
1089
1090   tab_headers (tbl, 1, 0, 1, 0);
1091
1092   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1093
1094   if ( roc->n_vars > 1)
1095     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1096
1097   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1098   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1099   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1100
1101   tab_box (tbl,
1102            TAL_2, TAL_2,
1103            -1, TAL_1,
1104            0, 0,
1105            n_cols - 1,
1106            n_rows - 1);
1107
1108   if ( roc->n_vars > 1)
1109     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1110
1111   for (i = 0; i < roc->n_vars; ++i)
1112     {
1113       struct ccase *cc;
1114       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1115
1116       if ( roc->n_vars > 1)
1117         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1118
1119       if ( i > 0)
1120         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1121
1122
1123       for (; (cc = casereader_read (r)) != NULL;
1124            case_unref (cc), x++)
1125         {
1126           const double se = case_data_idx (cc, ROC_TP)->f /
1127             (
1128              case_data_idx (cc, ROC_TP)->f
1129              +
1130              case_data_idx (cc, ROC_FN)->f
1131              );
1132
1133           const double sp = case_data_idx (cc, ROC_TN)->f /
1134             (
1135              case_data_idx (cc, ROC_TN)->f
1136              +
1137              case_data_idx (cc, ROC_FP)->f
1138              );
1139
1140           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1141                       var_get_print_format (roc->vars[i]));
1142
1143           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1144           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1145         }
1146
1147       casereader_destroy (r);
1148     }
1149
1150   tab_submit (tbl);
1151 }
1152
1153
1154 static void
1155 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1156 {
1157   show_summary (roc);
1158
1159   if ( roc->curve )
1160     {
1161       struct roc_chart *rc;
1162       size_t i;
1163
1164       rc = roc_chart_create (roc->reference);
1165       for (i = 0; i < roc->n_vars; i++)
1166         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1167                            rs[i].cutpoint_rdr);
1168       roc_chart_submit (rc);
1169     }
1170
1171   show_auc (rs, roc);
1172
1173   if ( roc->print_coords )
1174     show_coords (rs, roc);
1175 }
1176