Fix memory leaks in ROC command
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <language/stats/roc.h>
20
21 #include <data/casegrouper.h>
22 #include <data/casereader.h>
23 #include <data/casewriter.h>
24 #include <data/dictionary.h>
25 #include <data/format.h>
26 #include <data/procedure.h>
27 #include <data/subcase.h>
28 #include <language/command.h>
29 #include <language/lexer/lexer.h>
30 #include <language/lexer/value-parser.h>
31 #include <language/lexer/variable-parser.h>
32 #include <libpspp/misc.h>
33 #include <math/sort.h>
34 #include <output/chart-item.h>
35 #include <output/charts/roc-chart.h>
36 #include <output/tab.h>
37
38 #include <gsl/gsl_cdf.h>
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert ; /* True iff a smaller test result variable indicates
67                    a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
76
77 int
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
79 {
80   struct cmd_roc roc ;
81   const struct dictionary *dict = dataset_dict (ds);
82
83   roc.vars = NULL;
84   roc.n_vars = 0;
85   roc.print_se = false;
86   roc.print_coords = false;
87   roc.exclude = MV_ANY;
88   roc.curve = true;
89   roc.reference = false;
90   roc.ci = 95;
91   roc.bi_neg_exp = false;
92   roc.invert = false;
93   roc.pos = roc.pos_weighted = 0;
94   roc.neg = roc.neg_weighted = 0;
95   roc.dict = dataset_dict (ds);
96   roc.state_var = NULL;
97
98   lex_match (lexer, '/');
99   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
101     goto error;
102
103   if ( ! lex_force_match (lexer, T_BY))
104     {
105       goto error;
106     }
107
108   roc.state_var = parse_variable (lexer, dict);
109
110   if ( !lex_force_match (lexer, '('))
111     {
112       goto error;
113     }
114
115   value_init (&roc.state_value, var_get_width (roc.state_var));
116   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
117
118
119   if ( !lex_force_match (lexer, ')'))
120     {
121       goto error;
122     }
123
124
125   while (lex_token (lexer) != '.')
126     {
127       lex_match (lexer, '/');
128       if (lex_match_id (lexer, "MISSING"))
129         {
130           lex_match (lexer, '=');
131           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
132             {
133               if (lex_match_id (lexer, "INCLUDE"))
134                 {
135                   roc.exclude = MV_SYSTEM;
136                 }
137               else if (lex_match_id (lexer, "EXCLUDE"))
138                 {
139                   roc.exclude = MV_ANY;
140                 }
141               else
142                 {
143                   lex_error (lexer, NULL);
144                   goto error;
145                 }
146             }
147         }
148       else if (lex_match_id (lexer, "PLOT"))
149         {
150           lex_match (lexer, '=');
151           if (lex_match_id (lexer, "CURVE"))
152             {
153               roc.curve = true;
154               if (lex_match (lexer, '('))
155                 {
156                   roc.reference = true;
157                   lex_force_match_id (lexer, "REFERENCE");
158                   lex_force_match (lexer, ')');
159                 }
160             }
161           else if (lex_match_id (lexer, "NONE"))
162             {
163               roc.curve = false;
164             }
165           else
166             {
167               lex_error (lexer, NULL);
168               goto error;
169             }
170         }
171       else if (lex_match_id (lexer, "PRINT"))
172         {
173           lex_match (lexer, '=');
174           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
175             {
176               if (lex_match_id (lexer, "SE"))
177                 {
178                   roc.print_se = true;
179                 }
180               else if (lex_match_id (lexer, "COORDINATES"))
181                 {
182                   roc.print_coords = true;
183                 }
184               else
185                 {
186                   lex_error (lexer, NULL);
187                   goto error;
188                 }
189             }
190         }
191       else if (lex_match_id (lexer, "CRITERIA"))
192         {
193           lex_match (lexer, '=');
194           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
195             {
196               if (lex_match_id (lexer, "CUTOFF"))
197                 {
198                   lex_force_match (lexer, '(');
199                   if (lex_match_id (lexer, "INCLUDE"))
200                     {
201                       roc.exclude = MV_SYSTEM;
202                     }
203                   else if (lex_match_id (lexer, "EXCLUDE"))
204                     {
205                       roc.exclude = MV_USER | MV_SYSTEM;
206                     }
207                   else
208                     {
209                       lex_error (lexer, NULL);
210                       goto error;
211                     }
212                   lex_force_match (lexer, ')');
213                 }
214               else if (lex_match_id (lexer, "TESTPOS"))
215                 {
216                   lex_force_match (lexer, '(');
217                   if (lex_match_id (lexer, "LARGE"))
218                     {
219                       roc.invert = false;
220                     }
221                   else if (lex_match_id (lexer, "SMALL"))
222                     {
223                       roc.invert = true;
224                     }
225                   else
226                     {
227                       lex_error (lexer, NULL);
228                       goto error;
229                     }
230                   lex_force_match (lexer, ')');
231                 }
232               else if (lex_match_id (lexer, "CI"))
233                 {
234                   lex_force_match (lexer, '(');
235                   lex_force_num (lexer);
236                   roc.ci = lex_number (lexer);
237                   lex_get (lexer);
238                   lex_force_match (lexer, ')');
239                 }
240               else if (lex_match_id (lexer, "DISTRIBUTION"))
241                 {
242                   lex_force_match (lexer, '(');
243                   if (lex_match_id (lexer, "FREE"))
244                     {
245                       roc.bi_neg_exp = false;
246                     }
247                   else if (lex_match_id (lexer, "NEGEXPO"))
248                     {
249                       roc.bi_neg_exp = true;
250                     }
251                   else
252                     {
253                       lex_error (lexer, NULL);
254                       goto error;
255                     }
256                   lex_force_match (lexer, ')');
257                 }
258               else
259                 {
260                   lex_error (lexer, NULL);
261                   goto error;
262                 }
263             }
264         }
265       else
266         {
267           lex_error (lexer, NULL);
268           break;
269         }
270     }
271
272   if ( ! run_roc (ds, &roc)) 
273     goto error;
274
275   value_destroy (&roc.state_value, var_get_width (roc.state_var));
276   free (roc.vars);
277   return CMD_SUCCESS;
278
279  error:
280   if ( roc.state_var)
281     value_destroy (&roc.state_value, var_get_width (roc.state_var));
282   free (roc.vars);
283   return CMD_FAILURE;
284 }
285
286
287
288
289 static void
290 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
291
292
293 static int
294 run_roc (struct dataset *ds, struct cmd_roc *roc)
295 {
296   struct dictionary *dict = dataset_dict (ds);
297   bool ok;
298   struct casereader *group;
299
300   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
301   while (casegrouper_get_next_group (grouper, &group))
302     {
303       do_roc (roc, group, dataset_dict (ds));
304     }
305   ok = casegrouper_destroy (grouper);
306   ok = proc_commit (ds) && ok;
307
308   return ok;
309 }
310
311 #if 0
312 static void
313 dump_casereader (struct casereader *reader)
314 {
315   struct ccase *c;
316   struct casereader *r = casereader_clone (reader);
317
318   for ( ; (c = casereader_read (r) ); case_unref (c))
319     {
320       int i;
321       for (i = 0 ; i < case_get_value_cnt (c); ++i)
322         {
323           printf ("%g ", case_data_idx (c, i)->f);
324         }
325       printf ("\n");
326     }
327
328   casereader_destroy (r);
329 }
330 #endif
331
332
333 /* 
334    Return true iff the state variable indicates that C has positive actual state.
335
336    As a side effect, this function also accumulates the roc->{pos,neg} and 
337    roc->{pos,neg}_weighted counts.
338  */
339 static bool
340 match_positives (const struct ccase *c, void *aux)
341 {
342   struct cmd_roc *roc = aux;
343   const struct variable *wv = dict_get_weight (roc->dict);
344   const double weight = wv ? case_data (c, wv)->f : 1.0;
345
346   const bool positive =
347   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
348     var_get_width (roc->state_var)));
349
350   if ( positive )
351     {
352       roc->pos++;
353       roc->pos_weighted += weight;
354     }
355   else
356     {
357       roc->neg++;
358       roc->neg_weighted += weight;
359     }
360
361   return positive;
362 }
363
364
365 #define VALUE  0
366 #define N_EQ   1
367 #define N_PRED 2
368
369 /* Some intermediate state for calculating the cutpoints and the 
370    standard error values */
371 struct roc_state
372 {
373   double auc;  /* Area under the curve */
374
375   double n1;  /* total weight of positives */
376   double n2;  /* total weight of negatives */
377
378   /* intermediates for standard error */
379   double q1hat; 
380   double q2hat;
381
382   /* intermediates for cutpoints */
383   struct casewriter *cutpoint_wtr;
384   struct casereader *cutpoint_rdr;
385   double prev_result;
386   double min;
387   double max;
388 };
389
390 /* 
391    Return a new casereader based upon CUTPOINT_RDR.
392    The number of "positive" cases are placed into
393    the position TRUE_INDEX, and the number of "negative" cases
394    into FALSE_INDEX.
395    POS_COND and RESULT determine the semantics of what is 
396    "positive".
397    WEIGHT is the value of a single count.
398  */
399 static struct casereader *
400 accumulate_counts (struct casereader *input,
401                    double result, double weight, 
402                    bool (*pos_cond) (double, double),
403                    int true_index, int false_index)
404 {
405   const struct caseproto *proto = casereader_get_proto (input);
406   struct casewriter *w =
407     autopaging_writer_create (proto);
408   struct ccase *cpc;
409   double prev_cp = SYSMIS;
410
411   for ( ; (cpc = casereader_read (input) ); case_unref (cpc))
412     {
413       struct ccase *new_case;
414       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
415
416       assert (cp != SYSMIS);
417
418       /* We don't want duplicates here */
419       if ( cp == prev_cp )
420         continue;
421
422       new_case = case_clone (cpc);
423
424       if ( pos_cond (result, cp))
425         case_data_rw_idx (new_case, true_index)->f += weight;
426       else
427         case_data_rw_idx (new_case, false_index)->f += weight;
428
429       prev_cp = cp;
430
431       casewriter_write (w, new_case);
432     }
433   casereader_destroy (input);
434
435   return casewriter_make_reader (w);
436 }
437
438
439
440 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
441
442 /*
443   This function does 3 things:
444
445   1. Counts the number of cases which are equal to every other case in READER,
446   and those cases for which the relationship between it and every other case
447   satifies PRED (normally either > or <).  VAR is variable defining a case's value
448   for this purpose.
449
450   2. Counts the number of true and false cases in reader, and populates
451   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
452   which receive these values.  POS_COND is the condition defining true
453   and false.
454   
455   3. CC is filled with the cumulative weight of all cases of READER.
456 */
457 static struct casereader *
458 process_group (const struct variable *var, struct casereader *reader,
459                bool (*pred) (double, double),
460                const struct dictionary *dict,
461                double *cc,
462                struct casereader **cutpoint_rdr, 
463                bool (*pos_cond) (double, double),
464                int true_index,
465                int false_index)
466 {
467   const struct variable *w = dict_get_weight (dict);
468
469   struct casereader *r1 =
470     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
471
472   const int weight_idx  = w ? var_get_case_index (w) :
473     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
474   
475   struct ccase *c1;
476
477   struct casereader *rclone = casereader_clone (r1);
478   struct casewriter *wtr;
479   struct caseproto *proto = caseproto_create ();
480
481   proto = caseproto_add_width (proto, 0);
482   proto = caseproto_add_width (proto, 0);
483   proto = caseproto_add_width (proto, 0);
484
485   wtr = autopaging_writer_create (proto);  
486
487   *cc = 0;
488
489   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
490     {
491       struct ccase *new_case = case_create (proto);
492       struct ccase *c2;
493       struct casereader *r2 = casereader_clone (rclone);
494
495       const double weight1 = case_data_idx (c1, weight_idx)->f;
496       const double d1 = case_data (c1, var)->f;
497       double n_eq = 0.0;
498       double n_pred = 0.0;
499
500       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
501                                          pos_cond,
502                                          true_index, false_index);
503
504       *cc += weight1;
505
506       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
507         {
508           const double d2 = case_data (c2, var)->f;
509           const double weight2 = case_data_idx (c2, weight_idx)->f;
510
511           if ( d1 == d2 )
512             {
513               n_eq += weight2;
514               continue;
515             }
516           else  if ( pred (d2, d1))
517             {
518               n_pred += weight2;
519             }
520         }
521
522       case_data_rw_idx (new_case, VALUE)->f = d1;
523       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
524       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
525
526       casewriter_write (wtr, new_case);
527
528       casereader_destroy (r2);
529     }
530
531   
532   casereader_destroy (r1);
533   casereader_destroy (rclone);
534
535   caseproto_unref (proto);
536
537   return casewriter_make_reader (wtr);
538 }
539
540 /* Some more indeces into case data */
541 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
542 #define N_POS_GT 2  /* number of postive cases with values greater than n */
543 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
544 #define N_NEG_LT 4  /* number of negative cases with values less than n */
545
546 static bool
547 gt (double d1, double d2)
548 {
549   return d1 > d2;
550 }
551
552
553 static bool
554 ge (double d1, double d2)
555 {
556   return d1 > d2;
557 }
558
559 static bool
560 lt (double d1, double d2)
561 {
562   return d1 < d2;
563 }
564
565
566 /*
567   Return a casereader with width 3,
568   populated with cases based upon READER.
569   The cases will have the values:
570   (N, number of cases equal to N, number of cases greater than N)
571   As a side effect, update RS->n1 with the number of positive cases.
572 */
573 static struct casereader *
574 process_positive_group (const struct variable *var, struct casereader *reader,
575                         const struct dictionary *dict,
576                         struct roc_state *rs)
577 {
578   return process_group (var, reader, gt, dict, &rs->n1,
579                         &rs->cutpoint_rdr,
580                         ge,
581                         ROC_TP, ROC_FN);
582 }
583
584 /*
585   Return a casereader with width 3,
586   populated with cases based upon READER.
587   The cases will have the values:
588   (N, number of cases equal to N, number of cases less than N)
589   As a side effect, update RS->n2 with the number of negative cases.
590 */
591 static struct casereader *
592 process_negative_group (const struct variable *var, struct casereader *reader,
593                         const struct dictionary *dict,
594                         struct roc_state *rs)
595 {
596   return process_group (var, reader, lt, dict, &rs->n2,
597                         &rs->cutpoint_rdr,
598                         lt,
599                         ROC_TN, ROC_FP);
600 }
601
602
603
604
605 static void
606 append_cutpoint (struct casewriter *writer, double cutpoint)
607 {
608   struct ccase *cc = case_create (casewriter_get_proto (writer));
609
610   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
611   case_data_rw_idx (cc, ROC_TP)->f = 0;
612   case_data_rw_idx (cc, ROC_FN)->f = 0;
613   case_data_rw_idx (cc, ROC_TN)->f = 0;
614   case_data_rw_idx (cc, ROC_FP)->f = 0;
615
616   casewriter_write (writer, cc);
617 }
618
619
620 /* 
621    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
622    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
623    reader will be populated with its final number of cases.
624    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
625    value.  The other entries will be initialised to zero.
626 */
627 static void
628 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
629 {
630   int i;
631   struct casereader *r = casereader_clone (input);
632   struct ccase *c;
633
634   {
635     struct caseproto *proto = caseproto_create ();
636     struct subcase ordering;
637     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
638
639     proto = caseproto_add_width (proto, 0); /* cutpoint */
640     proto = caseproto_add_width (proto, 0); /* ROC_TP */
641     proto = caseproto_add_width (proto, 0); /* ROC_FN */
642     proto = caseproto_add_width (proto, 0); /* ROC_TN */
643     proto = caseproto_add_width (proto, 0); /* ROC_FP */
644
645     for (i = 0 ; i < roc->n_vars; ++i)
646       {
647         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
648         rs[i].prev_result = SYSMIS;
649         rs[i].max = -DBL_MAX;
650         rs[i].min = DBL_MAX;
651       }
652
653     caseproto_unref (proto);
654     subcase_destroy (&ordering);
655   }
656
657   for (; (c = casereader_read (r)) != NULL; case_unref (c))
658     {
659       for (i = 0 ; i < roc->n_vars; ++i)
660         {
661           const union value *v = case_data (c, roc->vars[i]); 
662           const double result = v->f;
663
664           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
665             continue;
666
667           minimize (&rs[i].min, result);
668           maximize (&rs[i].max, result);
669
670           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
671             {
672               const double mean = (result + rs[i].prev_result ) / 2.0;
673               append_cutpoint (rs[i].cutpoint_wtr, mean);
674             }
675
676           rs[i].prev_result = result;
677         }
678     }
679   casereader_destroy (r);
680
681
682   /* Append the min and max cutpoints */
683   for (i = 0 ; i < roc->n_vars; ++i)
684     {
685       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
686       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
687
688       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
689     }
690 }
691
692 static void
693 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
694 {
695   int i;
696
697   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
698
699   struct casereader *negatives = NULL;
700   struct casereader *positives = NULL;
701
702   struct caseproto *n_proto = NULL;
703
704   struct subcase up_ordering;
705   struct subcase down_ordering;
706
707   struct casewriter *neg_wtr = NULL;
708
709   struct casereader *input = casereader_create_filter_missing (reader,
710                                                                roc->vars, roc->n_vars,
711                                                                roc->exclude,
712                                                                NULL,
713                                                                NULL);
714
715   input = casereader_create_filter_missing (input,
716                                             &roc->state_var, 1,
717                                             roc->exclude,
718                                             NULL,
719                                             NULL);
720
721   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
722
723   prepare_cutpoints (roc, rs, input);
724
725
726   /* Separate the positive actual state cases from the negative ones */
727   positives = 
728     casereader_create_filter_func (input,
729                                    match_positives,
730                                    NULL,
731                                    roc,
732                                    neg_wtr);
733
734   n_proto = caseproto_create ();
735       
736   n_proto = caseproto_add_width (n_proto, 0);
737   n_proto = caseproto_add_width (n_proto, 0);
738   n_proto = caseproto_add_width (n_proto, 0);
739   n_proto = caseproto_add_width (n_proto, 0);
740   n_proto = caseproto_add_width (n_proto, 0);
741
742   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
743   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
744
745   for (i = 0 ; i < roc->n_vars; ++i)
746     {
747       struct casewriter *w = NULL;
748       struct casereader *r = NULL;
749
750       struct ccase *c;
751
752       struct ccase *cpos;
753       struct casereader *n_neg_reader ;
754       const struct variable *var = roc->vars[i];
755
756       struct casereader *neg ;
757       struct casereader *pos = casereader_clone (positives);
758
759       struct casereader *n_pos_reader =
760         process_positive_group (var, pos, dict, &rs[i]);
761
762       if ( negatives == NULL)
763         {
764           negatives = casewriter_make_reader (neg_wtr);
765         }
766
767       neg = casereader_clone (negatives);
768
769       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
770
771       /* Merge the n_pos and n_neg casereaders */
772       w = sort_create_writer (&up_ordering, n_proto);
773       for ( ; (cpos = casereader_read (n_pos_reader) ); case_unref (cpos))
774         {
775           struct ccase *pos_case = case_create (n_proto);
776           struct ccase *cneg;
777           const double jpos = case_data_idx (cpos, VALUE)->f;
778
779           while ((cneg = casereader_read (n_neg_reader)))
780             {
781               struct ccase *nc = case_create (n_proto);
782
783               const double jneg = case_data_idx (cneg, VALUE)->f;
784
785               case_data_rw_idx (nc, VALUE)->f = jneg;
786               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
787
788               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
789
790               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
791               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
792
793               casewriter_write (w, nc);
794
795               case_unref (cneg);
796               if ( jneg > jpos)
797                 break;
798             }
799
800           case_data_rw_idx (pos_case, VALUE)->f = jpos;
801           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
802           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
803           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
804           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
805
806           casewriter_write (w, pos_case);
807         }
808
809       casereader_destroy (n_pos_reader);
810       casereader_destroy (n_neg_reader);
811
812 /* These aren't used anymore */
813 #undef N_EQ
814 #undef N_PRED
815
816       r = casewriter_make_reader (w);
817
818       /* Propagate the N_POS_GT values from the positive cases
819          to the negative ones */
820       {
821         double prev_pos_gt = rs[i].n1;
822         w = sort_create_writer (&down_ordering, n_proto);
823
824         for ( ; (c = casereader_read (r) ); case_unref (c))
825           {
826             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
827             struct ccase *nc = case_clone (c);
828
829             if ( n_pos_gt == SYSMIS)
830               {
831                 n_pos_gt = prev_pos_gt;
832                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
833               }
834             
835             casewriter_write (w, nc);
836             prev_pos_gt = n_pos_gt;
837           }
838
839         casereader_destroy (r);
840         r = casewriter_make_reader (w);
841       }
842
843       /* Propagate the N_NEG_LT values from the negative cases
844          to the positive ones */
845       {
846         double prev_neg_lt = rs[i].n2;
847         w = sort_create_writer (&up_ordering, n_proto);
848
849         for ( ; (c = casereader_read (r) ); case_unref (c))
850           {
851             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
852             struct ccase *nc = case_clone (c);
853
854             if ( n_neg_lt == SYSMIS)
855               {
856                 n_neg_lt = prev_neg_lt;
857                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
858               }
859             
860             casewriter_write (w, nc);
861             prev_neg_lt = n_neg_lt;
862           }
863
864         casereader_destroy (r);
865         r = casewriter_make_reader (w);
866       }
867
868       {
869         struct ccase *prev_case = NULL;
870         for ( ; (c = casereader_read (r) ); case_unref (c))
871           {
872             struct ccase *next_case = casereader_peek (r, 0);
873
874             const double j = case_data_idx (c, VALUE)->f;
875             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
876             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
877             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
878             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
879
880             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
881               {
882                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
883                   {
884                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
885                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
886                   }
887
888                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
889                   {
890                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
891                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
892                   }
893               }
894
895             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
896               {
897                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
898
899                 rs[i].q1hat +=
900                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
901                 rs[i].q2hat +=
902                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
903
904               }
905
906             case_unref (next_case);
907             case_unref (prev_case);
908             prev_case = case_clone (c);
909           }
910         casereader_destroy (r);
911         case_unref (prev_case);
912
913         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
914         if ( roc->invert ) 
915           rs[i].auc = 1 - rs[i].auc;
916
917         if ( roc->bi_neg_exp )
918           {
919             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
920             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
921           }
922         else
923           {
924             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
925             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
926           }
927       }
928     }
929
930   casereader_destroy (positives);
931   casereader_destroy (negatives);
932
933   caseproto_unref (n_proto);
934   subcase_destroy (&up_ordering);
935   subcase_destroy (&down_ordering);
936
937   output_roc (rs, roc);
938  
939   for (i = 0 ; i < roc->n_vars; ++i)
940     casereader_destroy (rs[i].cutpoint_rdr);
941
942   free (rs);
943 }
944
945 static void
946 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
947 {
948   int i;
949   const int n_fields = roc->print_se ? 5 : 1;
950   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
951   const int n_rows = 2 + roc->n_vars;
952   struct tab_table *tbl = tab_create (n_cols, n_rows);
953
954   if ( roc->n_vars > 1)
955     tab_title (tbl, _("Area Under the Curve"));
956   else
957     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
958
959   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
960
961
962   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
963
964   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
965
966   tab_box (tbl,
967            TAL_2, TAL_2,
968            -1, TAL_1,
969            0, 0,
970            n_cols - 1,
971            n_rows - 1);
972
973   if ( roc->print_se )
974     {
975       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
976       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
977
978       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
979       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
980
981       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
982                              TAT_TITLE | TAB_CENTER,
983                              _("Asymp. %g%% Confidence Interval"), roc->ci);
984       tab_vline (tbl, 0, n_cols - 1, 0, 0);
985       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
986     }
987
988   if ( roc->n_vars > 1)
989     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
990
991   if ( roc->n_vars > 1)
992     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
993
994
995   for ( i = 0 ; i < roc->n_vars ; ++i )
996     {
997       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
998
999       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
1000
1001       if ( roc->print_se )
1002         {
1003           double se ;
1004           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1005                                       (12 * rs[i].n1 * rs[i].n2));
1006           double ci ;
1007           double yy ;
1008
1009           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
1010             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
1011
1012           se /= rs[i].n1 * rs[i].n2;
1013
1014           se = sqrt (se);
1015
1016           tab_double (tbl, n_cols - 4, 2 + i, 0,
1017                       se,
1018                       NULL);
1019
1020           ci = 1 - roc->ci / 100.0;
1021           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1022
1023           tab_double (tbl, n_cols - 2, 2 + i, 0,
1024                       rs[i].auc - yy,
1025                       NULL);
1026
1027           tab_double (tbl, n_cols - 1, 2 + i, 0,
1028                       rs[i].auc + yy,
1029                       NULL);
1030
1031           tab_double (tbl, n_cols - 3, 2 + i, 0,
1032                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1033                       NULL);
1034         }
1035     }
1036
1037   tab_submit (tbl);
1038 }
1039
1040
1041 static void
1042 show_summary (const struct cmd_roc *roc)
1043 {
1044   const int n_cols = 3;
1045   const int n_rows = 4;
1046   struct tab_table *tbl = tab_create (n_cols, n_rows);
1047
1048   tab_title (tbl, _("Case Summary"));
1049
1050   tab_headers (tbl, 1, 0, 2, 0);
1051
1052   tab_box (tbl,
1053            TAL_2, TAL_2,
1054            -1, -1,
1055            0, 0,
1056            n_cols - 1,
1057            n_rows - 1);
1058
1059   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1060   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1061
1062
1063   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1064   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1065
1066
1067   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1068   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1069   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1070
1071   tab_joint_text (tbl, 1, 0, 2, 0,
1072                   TAT_TITLE | TAB_CENTER,
1073                   _("Valid N (listwise)"));
1074
1075
1076   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1077   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1078
1079
1080   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1081   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1082
1083   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1084   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1085
1086   tab_submit (tbl);
1087 }
1088
1089
1090 static void
1091 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1092 {
1093   int x = 1;
1094   int i;
1095   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1096   int n_rows = 1;
1097   struct tab_table *tbl ;
1098
1099   for (i = 0; i < roc->n_vars; ++i)
1100     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1101
1102   tbl = tab_create (n_cols, n_rows);
1103
1104   if ( roc->n_vars > 1)
1105     tab_title (tbl, _("Coordinates of the Curve"));
1106   else
1107     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1108
1109
1110   tab_headers (tbl, 1, 0, 1, 0);
1111
1112   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1113
1114   if ( roc->n_vars > 1)
1115     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1116
1117   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1118   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1119   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1120
1121   tab_box (tbl,
1122            TAL_2, TAL_2,
1123            -1, TAL_1,
1124            0, 0,
1125            n_cols - 1,
1126            n_rows - 1);
1127
1128   if ( roc->n_vars > 1)
1129     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1130
1131   for (i = 0; i < roc->n_vars; ++i)
1132     {
1133       struct ccase *cc;
1134       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1135
1136       if ( roc->n_vars > 1)
1137         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1138
1139       if ( i > 0)
1140         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1141
1142
1143       for (; (cc = casereader_read (r)) != NULL;
1144            case_unref (cc), x++)
1145         {
1146           const double se = case_data_idx (cc, ROC_TP)->f /
1147             (
1148              case_data_idx (cc, ROC_TP)->f
1149              +
1150              case_data_idx (cc, ROC_FN)->f
1151              );
1152
1153           const double sp = case_data_idx (cc, ROC_TN)->f /
1154             (
1155              case_data_idx (cc, ROC_TN)->f
1156              +
1157              case_data_idx (cc, ROC_FP)->f
1158              );
1159
1160           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1161                       var_get_print_format (roc->vars[i]));
1162
1163           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1164           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1165         }
1166
1167       casereader_destroy (r);
1168     }
1169
1170   tab_submit (tbl);
1171 }
1172
1173
1174 static void
1175 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1176 {
1177   show_summary (roc);
1178
1179   if ( roc->curve )
1180     {
1181       struct roc_chart *rc;
1182       size_t i;
1183
1184       rc = roc_chart_create (roc->reference);
1185       for (i = 0; i < roc->n_vars; i++)
1186         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1187                            rs[i].cutpoint_rdr);
1188       roc_chart_submit (rc);
1189     }
1190
1191   show_auc (rs, roc);
1192
1193   if ( roc->print_coords )
1194     show_coords (rs, roc);
1195 }
1196