Merge master into output branch.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <language/stats/roc.h>
20
21 #include <data/procedure.h>
22 #include <language/lexer/variable-parser.h>
23 #include <language/lexer/value-parser.h>
24 #include <language/command.h>
25 #include <language/lexer/lexer.h>
26
27 #include <data/casegrouper.h>
28 #include <data/casereader.h>
29 #include <data/casewriter.h>
30 #include <data/dictionary.h>
31 #include <data/format.h>
32 #include <math/sort.h>
33 #include <data/subcase.h>
34
35
36 #include <libpspp/misc.h>
37
38 #include <gsl/gsl_cdf.h>
39 #include <output/table.h>
40
41 #include <output/chart.h>
42 #include <output/charts/roc-chart.h>
43
44 #include "gettext.h"
45 #define _(msgid) gettext (msgid)
46 #define N_(msgid) msgid
47
48 struct cmd_roc
49 {
50   size_t n_vars;
51   const struct variable **vars;
52   const struct dictionary *dict;
53
54   const struct variable *state_var ;
55   union value state_value;
56
57   /* Plot the roc curve */
58   bool curve;
59   /* Plot the reference line */
60   bool reference;
61
62   double ci;
63
64   bool print_coords;
65   bool print_se;
66   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
67                       should be used */
68   enum mv_class exclude;
69
70   bool invert ; /* True iff a smaller test result variable indicates
71                    a positive result */
72
73   double pos;
74   double neg;
75   double pos_weighted;
76   double neg_weighted;
77 };
78
79 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
80
81 int
82 cmd_roc (struct lexer *lexer, struct dataset *ds)
83 {
84   struct cmd_roc roc ;
85   const struct dictionary *dict = dataset_dict (ds);
86
87   roc.vars = NULL;
88   roc.n_vars = 0;
89   roc.print_se = false;
90   roc.print_coords = false;
91   roc.exclude = MV_ANY;
92   roc.curve = true;
93   roc.reference = false;
94   roc.ci = 95;
95   roc.bi_neg_exp = false;
96   roc.invert = false;
97   roc.pos = roc.pos_weighted = 0;
98   roc.neg = roc.neg_weighted = 0;
99   roc.dict = dataset_dict (ds);
100
101   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103     goto error;
104
105   if ( ! lex_force_match (lexer, T_BY))
106     {
107       goto error;
108     }
109
110   roc.state_var = parse_variable (lexer, dict);
111
112   if ( !lex_force_match (lexer, '('))
113     {
114       goto error;
115     }
116
117   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
118
119
120   if ( !lex_force_match (lexer, ')'))
121     {
122       goto error;
123     }
124
125
126   while (lex_token (lexer) != '.')
127     {
128       lex_match (lexer, '/');
129       if (lex_match_id (lexer, "MISSING"))
130         {
131           lex_match (lexer, '=');
132           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
133             {
134               if (lex_match_id (lexer, "INCLUDE"))
135                 {
136                   roc.exclude = MV_SYSTEM;
137                 }
138               else if (lex_match_id (lexer, "EXCLUDE"))
139                 {
140                   roc.exclude = MV_ANY;
141                 }
142               else
143                 {
144                   lex_error (lexer, NULL);
145                   goto error;
146                 }
147             }
148         }
149       else if (lex_match_id (lexer, "PLOT"))
150         {
151           lex_match (lexer, '=');
152           if (lex_match_id (lexer, "CURVE"))
153             {
154               roc.curve = true;
155               if (lex_match (lexer, '('))
156                 {
157                   roc.reference = true;
158                   lex_force_match_id (lexer, "REFERENCE");
159                   lex_force_match (lexer, ')');
160                 }
161             }
162           else if (lex_match_id (lexer, "NONE"))
163             {
164               roc.curve = false;
165             }
166           else
167             {
168               lex_error (lexer, NULL);
169               goto error;
170             }
171         }
172       else if (lex_match_id (lexer, "PRINT"))
173         {
174           lex_match (lexer, '=');
175           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
176             {
177               if (lex_match_id (lexer, "SE"))
178                 {
179                   roc.print_se = true;
180                 }
181               else if (lex_match_id (lexer, "COORDINATES"))
182                 {
183                   roc.print_coords = true;
184                 }
185               else
186                 {
187                   lex_error (lexer, NULL);
188                   goto error;
189                 }
190             }
191         }
192       else if (lex_match_id (lexer, "CRITERIA"))
193         {
194           lex_match (lexer, '=');
195           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
196             {
197               if (lex_match_id (lexer, "CUTOFF"))
198                 {
199                   lex_force_match (lexer, '(');
200                   if (lex_match_id (lexer, "INCLUDE"))
201                     {
202                       roc.exclude = MV_SYSTEM;
203                     }
204                   else if (lex_match_id (lexer, "EXCLUDE"))
205                     {
206                       roc.exclude = MV_USER | MV_SYSTEM;
207                     }
208                   else
209                     {
210                       lex_error (lexer, NULL);
211                       goto error;
212                     }
213                   lex_force_match (lexer, ')');
214                 }
215               else if (lex_match_id (lexer, "TESTPOS"))
216                 {
217                   lex_force_match (lexer, '(');
218                   if (lex_match_id (lexer, "LARGE"))
219                     {
220                       roc.invert = false;
221                     }
222                   else if (lex_match_id (lexer, "SMALL"))
223                     {
224                       roc.invert = true;
225                     }
226                   else
227                     {
228                       lex_error (lexer, NULL);
229                       goto error;
230                     }
231                   lex_force_match (lexer, ')');
232                 }
233               else if (lex_match_id (lexer, "CI"))
234                 {
235                   lex_force_match (lexer, '(');
236                   lex_force_num (lexer);
237                   roc.ci = lex_number (lexer);
238                   lex_get (lexer);
239                   lex_force_match (lexer, ')');
240                 }
241               else if (lex_match_id (lexer, "DISTRIBUTION"))
242                 {
243                   lex_force_match (lexer, '(');
244                   if (lex_match_id (lexer, "FREE"))
245                     {
246                       roc.bi_neg_exp = false;
247                     }
248                   else if (lex_match_id (lexer, "NEGEXPO"))
249                     {
250                       roc.bi_neg_exp = true;
251                     }
252                   else
253                     {
254                       lex_error (lexer, NULL);
255                       goto error;
256                     }
257                   lex_force_match (lexer, ')');
258                 }
259               else
260                 {
261                   lex_error (lexer, NULL);
262                   goto error;
263                 }
264             }
265         }
266       else
267         {
268           lex_error (lexer, NULL);
269           break;
270         }
271     }
272
273   if ( ! run_roc (ds, &roc)) 
274     goto error;
275
276   free (roc.vars);
277   return CMD_SUCCESS;
278
279  error:
280   free (roc.vars);
281   return CMD_FAILURE;
282 }
283
284
285
286
287 static void
288 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
289
290
291 static int
292 run_roc (struct dataset *ds, struct cmd_roc *roc)
293 {
294   struct dictionary *dict = dataset_dict (ds);
295   bool ok;
296   struct casereader *group;
297
298   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
299   while (casegrouper_get_next_group (grouper, &group))
300     {
301       do_roc (roc, group, dataset_dict (ds));
302     }
303   ok = casegrouper_destroy (grouper);
304   ok = proc_commit (ds) && ok;
305
306   return ok;
307 }
308
309 #if 0
310 static void
311 dump_casereader (struct casereader *reader)
312 {
313   struct ccase *c;
314   struct casereader *r = casereader_clone (reader);
315
316   for ( ; (c = casereader_read (r) ); case_unref (c))
317     {
318       int i;
319       for (i = 0 ; i < case_get_value_cnt (c); ++i)
320         {
321           printf ("%g ", case_data_idx (c, i)->f);
322         }
323       printf ("\n");
324     }
325
326   casereader_destroy (r);
327 }
328 #endif
329
330
331 /* 
332    Return true iff the state variable indicates that C has positive actual state.
333
334    As a side effect, this function also accumulates the roc->{pos,neg} and 
335    roc->{pos,neg}_weighted counts.
336  */
337 static bool
338 match_positives (const struct ccase *c, void *aux)
339 {
340   struct cmd_roc *roc = aux;
341   const struct variable *wv = dict_get_weight (roc->dict);
342   const double weight = wv ? case_data (c, wv)->f : 1.0;
343
344   const bool positive =
345   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
346     var_get_width (roc->state_var)));
347
348   if ( positive )
349     {
350       roc->pos++;
351       roc->pos_weighted += weight;
352     }
353   else
354     {
355       roc->neg++;
356       roc->neg_weighted += weight;
357     }
358
359   return positive;
360 }
361
362
363 #define VALUE  0
364 #define N_EQ   1
365 #define N_PRED 2
366
367 /* Some intermediate state for calculating the cutpoints and the 
368    standard error values */
369 struct roc_state
370 {
371   double auc;  /* Area under the curve */
372
373   double n1;  /* total weight of positives */
374   double n2;  /* total weight of negatives */
375
376   /* intermediates for standard error */
377   double q1hat; 
378   double q2hat;
379
380   /* intermediates for cutpoints */
381   struct casewriter *cutpoint_wtr;
382   struct casereader *cutpoint_rdr;
383   double prev_result;
384   double min;
385   double max;
386 };
387
388 /* 
389    Return a new casereader based upon CUTPOINT_RDR.
390    The number of "positive" cases are placed into
391    the position TRUE_INDEX, and the number of "negative" cases
392    into FALSE_INDEX.
393    POS_COND and RESULT determine the semantics of what is 
394    "positive".
395    WEIGHT is the value of a single count.
396  */
397 static struct casereader *
398 accumulate_counts (struct casereader *cutpoint_rdr, 
399                    double result, double weight, 
400                    bool (*pos_cond) (double, double),
401                    int true_index, int false_index)
402 {
403   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
404   struct casewriter *w =
405     autopaging_writer_create (proto);
406   struct casereader *r = casereader_clone (cutpoint_rdr);
407   struct ccase *cpc;
408   double prev_cp = SYSMIS;
409
410   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
411     {
412       struct ccase *new_case;
413       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
414
415       assert (cp != SYSMIS);
416
417       /* We don't want duplicates here */
418       if ( cp == prev_cp )
419         continue;
420
421       new_case = case_clone (cpc);
422
423       if ( pos_cond (result, cp))
424         case_data_rw_idx (new_case, true_index)->f += weight;
425       else
426         case_data_rw_idx (new_case, false_index)->f += weight;
427
428       prev_cp = cp;
429
430       casewriter_write (w, new_case);
431     }
432   casereader_destroy (r);
433
434   return casewriter_make_reader (w);
435 }
436
437
438
439 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
440
441 /*
442   This function does 3 things:
443
444   1. Counts the number of cases which are equal to every other case in READER,
445   and those cases for which the relationship between it and every other case
446   satifies PRED (normally either > or <).  VAR is variable defining a case's value
447   for this purpose.
448
449   2. Counts the number of true and false cases in reader, and populates
450   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
451   which receive these values.  POS_COND is the condition defining true
452   and false.
453   
454   3. CC is filled with the cumulative weight of all cases of READER.
455 */
456 static struct casereader *
457 process_group (const struct variable *var, struct casereader *reader,
458                bool (*pred) (double, double),
459                const struct dictionary *dict,
460                double *cc,
461                struct casereader **cutpoint_rdr, 
462                bool (*pos_cond) (double, double),
463                int true_index,
464                int false_index)
465 {
466   const struct variable *w = dict_get_weight (dict);
467
468   struct casereader *r1 =
469     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
470
471   const int weight_idx  = w ? var_get_case_index (w) :
472     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
473   
474   struct ccase *c1;
475
476   struct casereader *rclone = casereader_clone (r1);
477   struct casewriter *wtr;
478   struct caseproto *proto = caseproto_create ();
479
480   proto = caseproto_add_width (proto, 0);
481   proto = caseproto_add_width (proto, 0);
482   proto = caseproto_add_width (proto, 0);
483
484   wtr = autopaging_writer_create (proto);  
485
486   *cc = 0;
487
488   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
489     {
490       struct ccase *new_case = case_create (proto);
491       struct ccase *c2;
492       struct casereader *r2 = casereader_clone (rclone);
493
494       const double weight1 = case_data_idx (c1, weight_idx)->f;
495       const double d1 = case_data (c1, var)->f;
496       double n_eq = 0.0;
497       double n_pred = 0.0;
498
499       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
500                                          pos_cond,
501                                          true_index, false_index);
502
503       *cc += weight1;
504
505       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
506         {
507           const double d2 = case_data (c2, var)->f;
508           const double weight2 = case_data_idx (c2, weight_idx)->f;
509
510           if ( d1 == d2 )
511             {
512               n_eq += weight2;
513               continue;
514             }
515           else  if ( pred (d2, d1))
516             {
517               n_pred += weight2;
518             }
519         }
520
521       case_data_rw_idx (new_case, VALUE)->f = d1;
522       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
523       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
524
525       casewriter_write (wtr, new_case);
526
527       casereader_destroy (r2);
528     }
529
530   casereader_destroy (r1);
531   casereader_destroy (rclone);
532
533   return casewriter_make_reader (wtr);
534 }
535
536 /* Some more indeces into case data */
537 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
538 #define N_POS_GT 2  /* number of postive cases with values greater than n */
539 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
540 #define N_NEG_LT 4  /* number of negative cases with values less than n */
541
542 static bool
543 gt (double d1, double d2)
544 {
545   return d1 > d2;
546 }
547
548
549 static bool
550 ge (double d1, double d2)
551 {
552   return d1 > d2;
553 }
554
555 static bool
556 lt (double d1, double d2)
557 {
558   return d1 < d2;
559 }
560
561
562 /*
563   Return a casereader with width 3,
564   populated with cases based upon READER.
565   The cases will have the values:
566   (N, number of cases equal to N, number of cases greater than N)
567   As a side effect, update RS->n1 with the number of positive cases.
568 */
569 static struct casereader *
570 process_positive_group (const struct variable *var, struct casereader *reader,
571                         const struct dictionary *dict,
572                         struct roc_state *rs)
573 {
574   return process_group (var, reader, gt, dict, &rs->n1,
575                         &rs->cutpoint_rdr,
576                         ge,
577                         ROC_TP, ROC_FN);
578 }
579
580 /*
581   Return a casereader with width 3,
582   populated with cases based upon READER.
583   The cases will have the values:
584   (N, number of cases equal to N, number of cases less than N)
585   As a side effect, update RS->n2 with the number of negative cases.
586 */
587 static struct casereader *
588 process_negative_group (const struct variable *var, struct casereader *reader,
589                         const struct dictionary *dict,
590                         struct roc_state *rs)
591 {
592   return process_group (var, reader, lt, dict, &rs->n2,
593                         &rs->cutpoint_rdr,
594                         lt,
595                         ROC_TN, ROC_FP);
596 }
597
598
599
600
601 static void
602 append_cutpoint (struct casewriter *writer, double cutpoint)
603 {
604   struct ccase *cc = case_create (casewriter_get_proto (writer));
605
606   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
607   case_data_rw_idx (cc, ROC_TP)->f = 0;
608   case_data_rw_idx (cc, ROC_FN)->f = 0;
609   case_data_rw_idx (cc, ROC_TN)->f = 0;
610   case_data_rw_idx (cc, ROC_FP)->f = 0;
611
612   casewriter_write (writer, cc);
613 }
614
615
616 /* 
617    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
618    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
619    reader will be populated with its final number of cases.
620    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
621    value.  The other entries will be initialised to zero.
622 */
623 static void
624 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
625 {
626   int i;
627   struct casereader *r = casereader_clone (input);
628   struct ccase *c;
629   struct caseproto *proto = caseproto_create ();
630
631   struct subcase ordering;
632   subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
633
634   proto = caseproto_add_width (proto, 0); /* cutpoint */
635   proto = caseproto_add_width (proto, 0); /* ROC_TP */
636   proto = caseproto_add_width (proto, 0); /* ROC_FN */
637   proto = caseproto_add_width (proto, 0); /* ROC_TN */
638   proto = caseproto_add_width (proto, 0); /* ROC_FP */
639
640   for (i = 0 ; i < roc->n_vars; ++i)
641     {
642       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
643       rs[i].prev_result = SYSMIS;
644       rs[i].max = -DBL_MAX;
645       rs[i].min = DBL_MAX;
646     }
647
648   for (; (c = casereader_read (r)) != NULL; case_unref (c))
649     {
650       for (i = 0 ; i < roc->n_vars; ++i)
651         {
652           const union value *v = case_data (c, roc->vars[i]); 
653           const double result = v->f;
654
655           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
656             continue;
657
658           minimize (&rs[i].min, result);
659           maximize (&rs[i].max, result);
660
661           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
662             {
663               const double mean = (result + rs[i].prev_result ) / 2.0;
664               append_cutpoint (rs[i].cutpoint_wtr, mean);
665             }
666
667           rs[i].prev_result = result;
668         }
669     }
670   casereader_destroy (r);
671
672
673   /* Append the min and max cutpoints */
674   for (i = 0 ; i < roc->n_vars; ++i)
675     {
676       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
677       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
678
679       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
680     }
681 }
682
683 static void
684 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
685 {
686   int i;
687
688   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
689
690   struct casereader *negatives = NULL;
691   struct casereader *positives = NULL;
692
693   struct caseproto *n_proto = caseproto_create ();
694
695   struct subcase up_ordering;
696   struct subcase down_ordering;
697
698   struct casewriter *neg_wtr = NULL;
699
700   struct casereader *input = casereader_create_filter_missing (reader,
701                                                                roc->vars, roc->n_vars,
702                                                                roc->exclude,
703                                                                NULL,
704                                                                NULL);
705
706   input = casereader_create_filter_missing (input,
707                                             &roc->state_var, 1,
708                                             roc->exclude,
709                                             NULL,
710                                             NULL);
711
712   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
713
714   prepare_cutpoints (roc, rs, input);
715
716
717   /* Separate the positive actual state cases from the negative ones */
718   positives = 
719     casereader_create_filter_func (input,
720                                    match_positives,
721                                    NULL,
722                                    roc,
723                                    neg_wtr);
724
725   n_proto = caseproto_create ();
726       
727   n_proto = caseproto_add_width (n_proto, 0);
728   n_proto = caseproto_add_width (n_proto, 0);
729   n_proto = caseproto_add_width (n_proto, 0);
730   n_proto = caseproto_add_width (n_proto, 0);
731   n_proto = caseproto_add_width (n_proto, 0);
732
733   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
734   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
735
736   for (i = 0 ; i < roc->n_vars; ++i)
737     {
738       struct casewriter *w = NULL;
739       struct casereader *r = NULL;
740
741       struct ccase *c;
742
743       struct ccase *cpos;
744       struct casereader *n_neg ;
745       const struct variable *var = roc->vars[i];
746
747       struct casereader *neg ;
748       struct casereader *pos = casereader_clone (positives);
749
750
751       struct casereader *n_pos =
752         process_positive_group (var, pos, dict, &rs[i]);
753
754       if ( negatives == NULL)
755         {
756           negatives = casewriter_make_reader (neg_wtr);
757         }
758
759       neg = casereader_clone (negatives);
760
761       n_neg = process_negative_group (var, neg, dict, &rs[i]);
762
763
764       /* Merge the n_pos and n_neg casereaders */
765       w = sort_create_writer (&up_ordering, n_proto);
766       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
767         {
768           struct ccase *pos_case = case_create (n_proto);
769           struct ccase *cneg;
770           const double jpos = case_data_idx (cpos, VALUE)->f;
771
772           while ((cneg = casereader_read (n_neg)))
773             {
774               struct ccase *nc = case_create (n_proto);
775
776               const double jneg = case_data_idx (cneg, VALUE)->f;
777
778               case_data_rw_idx (nc, VALUE)->f = jneg;
779               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
780
781               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
782
783               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
784               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
785
786               casewriter_write (w, nc);
787
788               case_unref (cneg);
789               if ( jneg > jpos)
790                 break;
791             }
792
793           case_data_rw_idx (pos_case, VALUE)->f = jpos;
794           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
795           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
796           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
797           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
798
799           casewriter_write (w, pos_case);
800         }
801
802 /* These aren't used anymore */
803 #undef N_EQ
804 #undef N_PRED
805
806       r = casewriter_make_reader (w);
807
808       /* Propagate the N_POS_GT values from the positive cases
809          to the negative ones */
810       {
811         double prev_pos_gt = rs[i].n1;
812         w = sort_create_writer (&down_ordering, n_proto);
813
814         for ( ; (c = casereader_read (r) ); case_unref (c))
815           {
816             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
817             struct ccase *nc = case_clone (c);
818
819             if ( n_pos_gt == SYSMIS)
820               {
821                 n_pos_gt = prev_pos_gt;
822                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
823               }
824             
825             casewriter_write (w, nc);
826             prev_pos_gt = n_pos_gt;
827           }
828
829         r = casewriter_make_reader (w);
830       }
831
832       /* Propagate the N_NEG_LT values from the negative cases
833          to the positive ones */
834       {
835         double prev_neg_lt = rs[i].n2;
836         w = sort_create_writer (&up_ordering, n_proto);
837
838         for ( ; (c = casereader_read (r) ); case_unref (c))
839           {
840             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
841             struct ccase *nc = case_clone (c);
842
843             if ( n_neg_lt == SYSMIS)
844               {
845                 n_neg_lt = prev_neg_lt;
846                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
847               }
848             
849             casewriter_write (w, nc);
850             prev_neg_lt = n_neg_lt;
851           }
852
853         r = casewriter_make_reader (w);
854       }
855
856       {
857         struct ccase *prev_case = NULL;
858         for ( ; (c = casereader_read (r) ); case_unref (c))
859           {
860             const struct ccase *next_case = casereader_peek (r, 0);
861
862             const double j = case_data_idx (c, VALUE)->f;
863             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
864             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
865             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
866             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
867
868             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
869               {
870                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
871                   {
872                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
873                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
874                   }
875
876                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
877                   {
878                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
879                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
880                   }
881               }
882
883             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
884               {
885                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
886
887                 rs[i].q1hat +=
888                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
889                 rs[i].q2hat +=
890                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
891
892               }
893
894             case_unref (prev_case);
895             prev_case = case_clone (c);
896           }
897
898         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
899         if ( roc->invert ) 
900           rs[i].auc = 1 - rs[i].auc;
901
902         if ( roc->bi_neg_exp )
903           {
904             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
905             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
906           }
907         else
908           {
909             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
910             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
911           }
912       }
913     }
914
915   casereader_destroy (positives);
916   casereader_destroy (negatives);
917
918   output_roc (rs, roc);
919
920   free (rs);
921 }
922
923 static void
924 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
925 {
926   int i;
927   const int n_fields = roc->print_se ? 5 : 1;
928   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
929   const int n_rows = 2 + roc->n_vars;
930   struct tab_table *tbl = tab_create (n_cols, n_rows);
931
932   if ( roc->n_vars > 1)
933     tab_title (tbl, _("Area Under the Curve"));
934   else
935     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
936
937   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
938
939   tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
940
941   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
942
943   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
944
945   tab_box (tbl,
946            TAL_2, TAL_2,
947            -1, TAL_1,
948            0, 0,
949            n_cols - 1,
950            n_rows - 1);
951
952   if ( roc->print_se )
953     {
954       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
955       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
956
957       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
958       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
959
960       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
961                              TAT_TITLE | TAB_CENTER,
962                              _("Asymp. %g%% Confidence Interval"), roc->ci);
963       tab_vline (tbl, 0, n_cols - 1, 0, 0);
964       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
965     }
966
967   if ( roc->n_vars > 1)
968     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
969
970   if ( roc->n_vars > 1)
971     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
972
973
974   for ( i = 0 ; i < roc->n_vars ; ++i )
975     {
976       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
977
978       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
979
980       if ( roc->print_se )
981         {
982           double se ;
983           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
984                                       (12 * rs[i].n1 * rs[i].n2));
985           double ci ;
986           double yy ;
987
988           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
989             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
990
991           se /= rs[i].n1 * rs[i].n2;
992
993           se = sqrt (se);
994
995           tab_double (tbl, n_cols - 4, 2 + i, 0,
996                       se,
997                       NULL);
998
999           ci = 1 - roc->ci / 100.0;
1000           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1001
1002           tab_double (tbl, n_cols - 2, 2 + i, 0,
1003                       rs[i].auc - yy,
1004                       NULL);
1005
1006           tab_double (tbl, n_cols - 1, 2 + i, 0,
1007                       rs[i].auc + yy,
1008                       NULL);
1009
1010           tab_double (tbl, n_cols - 3, 2 + i, 0,
1011                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1012                       NULL);
1013         }
1014     }
1015
1016   tab_submit (tbl);
1017 }
1018
1019
1020 static void
1021 show_summary (const struct cmd_roc *roc)
1022 {
1023   const int n_cols = 3;
1024   const int n_rows = 4;
1025   struct tab_table *tbl = tab_create (n_cols, n_rows);
1026
1027   tab_title (tbl, _("Case Summary"));
1028
1029   tab_headers (tbl, 1, 0, 2, 0);
1030
1031   tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1032
1033   tab_box (tbl,
1034            TAL_2, TAL_2,
1035            -1, -1,
1036            0, 0,
1037            n_cols - 1,
1038            n_rows - 1);
1039
1040   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1041   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1042
1043
1044   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1045   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1046
1047
1048   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1049   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1050   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1051
1052   tab_joint_text (tbl, 1, 0, 2, 0,
1053                   TAT_TITLE | TAB_CENTER,
1054                   _("Valid N (listwise)"));
1055
1056
1057   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1058   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1059
1060
1061   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1062   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1063
1064   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1065   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1066
1067   tab_submit (tbl);
1068 }
1069
1070
1071 static void
1072 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1073 {
1074   int x = 1;
1075   int i;
1076   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1077   int n_rows = 1;
1078   struct tab_table *tbl ;
1079
1080   for (i = 0; i < roc->n_vars; ++i)
1081     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1082
1083   tbl = tab_create (n_cols, n_rows);
1084
1085   if ( roc->n_vars > 1)
1086     tab_title (tbl, _("Coordinates of the Curve"));
1087   else
1088     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1089
1090
1091   tab_headers (tbl, 1, 0, 1, 0);
1092
1093   tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1094
1095   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1096
1097   if ( roc->n_vars > 1)
1098     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1099
1100   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1101   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1102   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1103
1104   tab_box (tbl,
1105            TAL_2, TAL_2,
1106            -1, TAL_1,
1107            0, 0,
1108            n_cols - 1,
1109            n_rows - 1);
1110
1111   if ( roc->n_vars > 1)
1112     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1113
1114   for (i = 0; i < roc->n_vars; ++i)
1115     {
1116       struct ccase *cc;
1117       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1118
1119       if ( roc->n_vars > 1)
1120         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1121
1122       if ( i > 0)
1123         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1124
1125
1126       for (; (cc = casereader_read (r)) != NULL;
1127            case_unref (cc), x++)
1128         {
1129           const double se = case_data_idx (cc, ROC_TP)->f /
1130             (
1131              case_data_idx (cc, ROC_TP)->f
1132              +
1133              case_data_idx (cc, ROC_FN)->f
1134              );
1135
1136           const double sp = case_data_idx (cc, ROC_TN)->f /
1137             (
1138              case_data_idx (cc, ROC_TN)->f
1139              +
1140              case_data_idx (cc, ROC_FP)->f
1141              );
1142
1143           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1144                       var_get_print_format (roc->vars[i]));
1145
1146           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1147           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1148         }
1149
1150       casereader_destroy (r);
1151     }
1152
1153   tab_submit (tbl);
1154 }
1155
1156
1157 static void
1158 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1159 {
1160   show_summary (roc);
1161
1162   if ( roc->curve )
1163     {
1164       struct roc_chart *rc;
1165       size_t i;
1166
1167       rc = roc_chart_create (roc->reference);
1168       for (i = 0; i < roc->n_vars; i++)
1169         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1170                            rs[i].cutpoint_rdr);
1171       chart_submit (roc_chart_get_chart (rc));
1172     }
1173
1174   show_auc (rs, roc);
1175
1176   if ( roc->print_coords )
1177     show_coords (rs, roc);
1178 }
1179