Fix crash in ROC command when no valid state variable is given.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/chart-item.h"
37 #include "output/charts/roc-chart.h"
38 #include "output/tab.h"
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52   size_t state_var_width;
53
54   /* Plot the roc curve */
55   bool curve;
56   /* Plot the reference line */
57   bool reference;
58
59   double ci;
60
61   bool print_coords;
62   bool print_se;
63   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
64                       should be used */
65   enum mv_class exclude;
66
67   bool invert ; /* True iff a smaller test result variable indicates
68                    a positive result */
69
70   double pos;
71   double neg;
72   double pos_weighted;
73   double neg_weighted;
74 };
75
76 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
77
78 int
79 cmd_roc (struct lexer *lexer, struct dataset *ds)
80 {
81   struct cmd_roc roc ;
82   const struct dictionary *dict = dataset_dict (ds);
83
84   roc.vars = NULL;
85   roc.n_vars = 0;
86   roc.print_se = false;
87   roc.print_coords = false;
88   roc.exclude = MV_ANY;
89   roc.curve = true;
90   roc.reference = false;
91   roc.ci = 95;
92   roc.bi_neg_exp = false;
93   roc.invert = false;
94   roc.pos = roc.pos_weighted = 0;
95   roc.neg = roc.neg_weighted = 0;
96   roc.dict = dataset_dict (ds);
97   roc.state_var = NULL;
98   roc.state_var_width = -1;
99
100   lex_match (lexer, T_SLASH);
101   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103     goto error;
104
105   if ( ! lex_force_match (lexer, T_BY))
106     {
107       goto error;
108     }
109
110   roc.state_var = parse_variable (lexer, dict);
111   if (! roc.state_var)
112     {
113       goto error;
114     }
115
116   if ( !lex_force_match (lexer, T_LPAREN))
117     {
118       goto error;
119     }
120
121   roc.state_var_width = var_get_width (roc.state_var);
122   value_init (&roc.state_value, roc.state_var_width);
123   parse_value (lexer, &roc.state_value, roc.state_var);
124
125
126   if ( !lex_force_match (lexer, T_RPAREN))
127     {
128       goto error;
129     }
130
131   while (lex_token (lexer) != T_ENDCMD)
132     {
133       lex_match (lexer, T_SLASH);
134       if (lex_match_id (lexer, "MISSING"))
135         {
136           lex_match (lexer, T_EQUALS);
137           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
138             {
139               if (lex_match_id (lexer, "INCLUDE"))
140                 {
141                   roc.exclude = MV_SYSTEM;
142                 }
143               else if (lex_match_id (lexer, "EXCLUDE"))
144                 {
145                   roc.exclude = MV_ANY;
146                 }
147               else
148                 {
149                   lex_error (lexer, NULL);
150                   goto error;
151                 }
152             }
153         }
154       else if (lex_match_id (lexer, "PLOT"))
155         {
156           lex_match (lexer, T_EQUALS);
157           if (lex_match_id (lexer, "CURVE"))
158             {
159               roc.curve = true;
160               if (lex_match (lexer, T_LPAREN))
161                 {
162                   roc.reference = true;
163                   lex_force_match_id (lexer, "REFERENCE");
164                   lex_force_match (lexer, T_RPAREN);
165                 }
166             }
167           else if (lex_match_id (lexer, "NONE"))
168             {
169               roc.curve = false;
170             }
171           else
172             {
173               lex_error (lexer, NULL);
174               goto error;
175             }
176         }
177       else if (lex_match_id (lexer, "PRINT"))
178         {
179           lex_match (lexer, T_EQUALS);
180           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
181             {
182               if (lex_match_id (lexer, "SE"))
183                 {
184                   roc.print_se = true;
185                 }
186               else if (lex_match_id (lexer, "COORDINATES"))
187                 {
188                   roc.print_coords = true;
189                 }
190               else
191                 {
192                   lex_error (lexer, NULL);
193                   goto error;
194                 }
195             }
196         }
197       else if (lex_match_id (lexer, "CRITERIA"))
198         {
199           lex_match (lexer, T_EQUALS);
200           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
201             {
202               if (lex_match_id (lexer, "CUTOFF"))
203                 {
204                   lex_force_match (lexer, T_LPAREN);
205                   if (lex_match_id (lexer, "INCLUDE"))
206                     {
207                       roc.exclude = MV_SYSTEM;
208                     }
209                   else if (lex_match_id (lexer, "EXCLUDE"))
210                     {
211                       roc.exclude = MV_USER | MV_SYSTEM;
212                     }
213                   else
214                     {
215                       lex_error (lexer, NULL);
216                       goto error;
217                     }
218                   lex_force_match (lexer, T_RPAREN);
219                 }
220               else if (lex_match_id (lexer, "TESTPOS"))
221                 {
222                   lex_force_match (lexer, T_LPAREN);
223                   if (lex_match_id (lexer, "LARGE"))
224                     {
225                       roc.invert = false;
226                     }
227                   else if (lex_match_id (lexer, "SMALL"))
228                     {
229                       roc.invert = true;
230                     }
231                   else
232                     {
233                       lex_error (lexer, NULL);
234                       goto error;
235                     }
236                   lex_force_match (lexer, T_RPAREN);
237                 }
238               else if (lex_match_id (lexer, "CI"))
239                 {
240                   lex_force_match (lexer, T_LPAREN);
241                   lex_force_num (lexer);
242                   roc.ci = lex_number (lexer);
243                   lex_get (lexer);
244                   lex_force_match (lexer, T_RPAREN);
245                 }
246               else if (lex_match_id (lexer, "DISTRIBUTION"))
247                 {
248                   lex_force_match (lexer, T_LPAREN);
249                   if (lex_match_id (lexer, "FREE"))
250                     {
251                       roc.bi_neg_exp = false;
252                     }
253                   else if (lex_match_id (lexer, "NEGEXPO"))
254                     {
255                       roc.bi_neg_exp = true;
256                     }
257                   else
258                     {
259                       lex_error (lexer, NULL);
260                       goto error;
261                     }
262                   lex_force_match (lexer, T_RPAREN);
263                 }
264               else
265                 {
266                   lex_error (lexer, NULL);
267                   goto error;
268                 }
269             }
270         }
271       else
272         {
273           lex_error (lexer, NULL);
274           break;
275         }
276     }
277
278   if ( ! run_roc (ds, &roc)) 
279     goto error;
280
281   if ( roc.state_var)
282     value_destroy (&roc.state_value, roc.state_var_width);
283   free (roc.vars);
284   return CMD_SUCCESS;
285
286  error:
287   if ( roc.state_var)
288     value_destroy (&roc.state_value, roc.state_var_width);
289   free (roc.vars);
290   return CMD_FAILURE;
291 }
292
293
294
295
296 static void
297 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
298
299
300 static int
301 run_roc (struct dataset *ds, struct cmd_roc *roc)
302 {
303   struct dictionary *dict = dataset_dict (ds);
304   bool ok;
305   struct casereader *group;
306
307   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
308   while (casegrouper_get_next_group (grouper, &group))
309     {
310       do_roc (roc, group, dataset_dict (ds));
311     }
312   ok = casegrouper_destroy (grouper);
313   ok = proc_commit (ds) && ok;
314
315   return ok;
316 }
317
318 #if 0
319 static void
320 dump_casereader (struct casereader *reader)
321 {
322   struct ccase *c;
323   struct casereader *r = casereader_clone (reader);
324
325   for ( ; (c = casereader_read (r) ); case_unref (c))
326     {
327       int i;
328       for (i = 0 ; i < case_get_value_cnt (c); ++i)
329         {
330           printf ("%g ", case_data_idx (c, i)->f);
331         }
332       printf ("\n");
333     }
334
335   casereader_destroy (r);
336 }
337 #endif
338
339
340 /* 
341    Return true iff the state variable indicates that C has positive actual state.
342
343    As a side effect, this function also accumulates the roc->{pos,neg} and 
344    roc->{pos,neg}_weighted counts.
345  */
346 static bool
347 match_positives (const struct ccase *c, void *aux)
348 {
349   struct cmd_roc *roc = aux;
350   const struct variable *wv = dict_get_weight (roc->dict);
351   const double weight = wv ? case_data (c, wv)->f : 1.0;
352
353   const bool positive =
354   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
355     var_get_width (roc->state_var)));
356
357   if ( positive )
358     {
359       roc->pos++;
360       roc->pos_weighted += weight;
361     }
362   else
363     {
364       roc->neg++;
365       roc->neg_weighted += weight;
366     }
367
368   return positive;
369 }
370
371
372 #define VALUE  0
373 #define N_EQ   1
374 #define N_PRED 2
375
376 /* Some intermediate state for calculating the cutpoints and the 
377    standard error values */
378 struct roc_state
379 {
380   double auc;  /* Area under the curve */
381
382   double n1;  /* total weight of positives */
383   double n2;  /* total weight of negatives */
384
385   /* intermediates for standard error */
386   double q1hat; 
387   double q2hat;
388
389   /* intermediates for cutpoints */
390   struct casewriter *cutpoint_wtr;
391   struct casereader *cutpoint_rdr;
392   double prev_result;
393   double min;
394   double max;
395 };
396
397 /* 
398    Return a new casereader based upon CUTPOINT_RDR.
399    The number of "positive" cases are placed into
400    the position TRUE_INDEX, and the number of "negative" cases
401    into FALSE_INDEX.
402    POS_COND and RESULT determine the semantics of what is 
403    "positive".
404    WEIGHT is the value of a single count.
405  */
406 static struct casereader *
407 accumulate_counts (struct casereader *input,
408                    double result, double weight, 
409                    bool (*pos_cond) (double, double),
410                    int true_index, int false_index)
411 {
412   const struct caseproto *proto = casereader_get_proto (input);
413   struct casewriter *w =
414     autopaging_writer_create (proto);
415   struct ccase *cpc;
416   double prev_cp = SYSMIS;
417
418   for ( ; (cpc = casereader_read (input) ); case_unref (cpc))
419     {
420       struct ccase *new_case;
421       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
422
423       assert (cp != SYSMIS);
424
425       /* We don't want duplicates here */
426       if ( cp == prev_cp )
427         continue;
428
429       new_case = case_clone (cpc);
430
431       if ( pos_cond (result, cp))
432         case_data_rw_idx (new_case, true_index)->f += weight;
433       else
434         case_data_rw_idx (new_case, false_index)->f += weight;
435
436       prev_cp = cp;
437
438       casewriter_write (w, new_case);
439     }
440   casereader_destroy (input);
441
442   return casewriter_make_reader (w);
443 }
444
445
446
447 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
448
449 /*
450   This function does 3 things:
451
452   1. Counts the number of cases which are equal to every other case in READER,
453   and those cases for which the relationship between it and every other case
454   satifies PRED (normally either > or <).  VAR is variable defining a case's value
455   for this purpose.
456
457   2. Counts the number of true and false cases in reader, and populates
458   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
459   which receive these values.  POS_COND is the condition defining true
460   and false.
461   
462   3. CC is filled with the cumulative weight of all cases of READER.
463 */
464 static struct casereader *
465 process_group (const struct variable *var, struct casereader *reader,
466                bool (*pred) (double, double),
467                const struct dictionary *dict,
468                double *cc,
469                struct casereader **cutpoint_rdr, 
470                bool (*pos_cond) (double, double),
471                int true_index,
472                int false_index)
473 {
474   const struct variable *w = dict_get_weight (dict);
475
476   struct casereader *r1 =
477     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
478
479   const int weight_idx  = w ? var_get_case_index (w) :
480     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
481   
482   struct ccase *c1;
483
484   struct casereader *rclone = casereader_clone (r1);
485   struct casewriter *wtr;
486   struct caseproto *proto = caseproto_create ();
487
488   proto = caseproto_add_width (proto, 0);
489   proto = caseproto_add_width (proto, 0);
490   proto = caseproto_add_width (proto, 0);
491
492   wtr = autopaging_writer_create (proto);  
493
494   *cc = 0;
495
496   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
497     {
498       struct ccase *new_case = case_create (proto);
499       struct ccase *c2;
500       struct casereader *r2 = casereader_clone (rclone);
501
502       const double weight1 = case_data_idx (c1, weight_idx)->f;
503       const double d1 = case_data (c1, var)->f;
504       double n_eq = 0.0;
505       double n_pred = 0.0;
506
507       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
508                                          pos_cond,
509                                          true_index, false_index);
510
511       *cc += weight1;
512
513       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
514         {
515           const double d2 = case_data (c2, var)->f;
516           const double weight2 = case_data_idx (c2, weight_idx)->f;
517
518           if ( d1 == d2 )
519             {
520               n_eq += weight2;
521               continue;
522             }
523           else  if ( pred (d2, d1))
524             {
525               n_pred += weight2;
526             }
527         }
528
529       case_data_rw_idx (new_case, VALUE)->f = d1;
530       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
531       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
532
533       casewriter_write (wtr, new_case);
534
535       casereader_destroy (r2);
536     }
537
538   
539   casereader_destroy (r1);
540   casereader_destroy (rclone);
541
542   caseproto_unref (proto);
543
544   return casewriter_make_reader (wtr);
545 }
546
547 /* Some more indeces into case data */
548 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
549 #define N_POS_GT 2  /* number of postive cases with values greater than n */
550 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
551 #define N_NEG_LT 4  /* number of negative cases with values less than n */
552
553 static bool
554 gt (double d1, double d2)
555 {
556   return d1 > d2;
557 }
558
559
560 static bool
561 ge (double d1, double d2)
562 {
563   return d1 > d2;
564 }
565
566 static bool
567 lt (double d1, double d2)
568 {
569   return d1 < d2;
570 }
571
572
573 /*
574   Return a casereader with width 3,
575   populated with cases based upon READER.
576   The cases will have the values:
577   (N, number of cases equal to N, number of cases greater than N)
578   As a side effect, update RS->n1 with the number of positive cases.
579 */
580 static struct casereader *
581 process_positive_group (const struct variable *var, struct casereader *reader,
582                         const struct dictionary *dict,
583                         struct roc_state *rs)
584 {
585   return process_group (var, reader, gt, dict, &rs->n1,
586                         &rs->cutpoint_rdr,
587                         ge,
588                         ROC_TP, ROC_FN);
589 }
590
591 /*
592   Return a casereader with width 3,
593   populated with cases based upon READER.
594   The cases will have the values:
595   (N, number of cases equal to N, number of cases less than N)
596   As a side effect, update RS->n2 with the number of negative cases.
597 */
598 static struct casereader *
599 process_negative_group (const struct variable *var, struct casereader *reader,
600                         const struct dictionary *dict,
601                         struct roc_state *rs)
602 {
603   return process_group (var, reader, lt, dict, &rs->n2,
604                         &rs->cutpoint_rdr,
605                         lt,
606                         ROC_TN, ROC_FP);
607 }
608
609
610
611
612 static void
613 append_cutpoint (struct casewriter *writer, double cutpoint)
614 {
615   struct ccase *cc = case_create (casewriter_get_proto (writer));
616
617   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
618   case_data_rw_idx (cc, ROC_TP)->f = 0;
619   case_data_rw_idx (cc, ROC_FN)->f = 0;
620   case_data_rw_idx (cc, ROC_TN)->f = 0;
621   case_data_rw_idx (cc, ROC_FP)->f = 0;
622
623   casewriter_write (writer, cc);
624 }
625
626
627 /* 
628    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
629    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
630    reader will be populated with its final number of cases.
631    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
632    value.  The other entries will be initialised to zero.
633 */
634 static void
635 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
636 {
637   int i;
638   struct casereader *r = casereader_clone (input);
639   struct ccase *c;
640
641   {
642     struct caseproto *proto = caseproto_create ();
643     struct subcase ordering;
644     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
645
646     proto = caseproto_add_width (proto, 0); /* cutpoint */
647     proto = caseproto_add_width (proto, 0); /* ROC_TP */
648     proto = caseproto_add_width (proto, 0); /* ROC_FN */
649     proto = caseproto_add_width (proto, 0); /* ROC_TN */
650     proto = caseproto_add_width (proto, 0); /* ROC_FP */
651
652     for (i = 0 ; i < roc->n_vars; ++i)
653       {
654         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
655         rs[i].prev_result = SYSMIS;
656         rs[i].max = -DBL_MAX;
657         rs[i].min = DBL_MAX;
658       }
659
660     caseproto_unref (proto);
661     subcase_destroy (&ordering);
662   }
663
664   for (; (c = casereader_read (r)) != NULL; case_unref (c))
665     {
666       for (i = 0 ; i < roc->n_vars; ++i)
667         {
668           const union value *v = case_data (c, roc->vars[i]); 
669           const double result = v->f;
670
671           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
672             continue;
673
674           minimize (&rs[i].min, result);
675           maximize (&rs[i].max, result);
676
677           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
678             {
679               const double mean = (result + rs[i].prev_result ) / 2.0;
680               append_cutpoint (rs[i].cutpoint_wtr, mean);
681             }
682
683           rs[i].prev_result = result;
684         }
685     }
686   casereader_destroy (r);
687
688
689   /* Append the min and max cutpoints */
690   for (i = 0 ; i < roc->n_vars; ++i)
691     {
692       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
693       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
694
695       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
696     }
697 }
698
699 static void
700 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
701 {
702   int i;
703
704   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
705
706   struct casereader *negatives = NULL;
707   struct casereader *positives = NULL;
708
709   struct caseproto *n_proto = NULL;
710
711   struct subcase up_ordering;
712   struct subcase down_ordering;
713
714   struct casewriter *neg_wtr = NULL;
715
716   struct casereader *input = casereader_create_filter_missing (reader,
717                                                                roc->vars, roc->n_vars,
718                                                                roc->exclude,
719                                                                NULL,
720                                                                NULL);
721
722   input = casereader_create_filter_missing (input,
723                                             &roc->state_var, 1,
724                                             roc->exclude,
725                                             NULL,
726                                             NULL);
727
728   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
729
730   prepare_cutpoints (roc, rs, input);
731
732
733   /* Separate the positive actual state cases from the negative ones */
734   positives = 
735     casereader_create_filter_func (input,
736                                    match_positives,
737                                    NULL,
738                                    roc,
739                                    neg_wtr);
740
741   n_proto = caseproto_create ();
742       
743   n_proto = caseproto_add_width (n_proto, 0);
744   n_proto = caseproto_add_width (n_proto, 0);
745   n_proto = caseproto_add_width (n_proto, 0);
746   n_proto = caseproto_add_width (n_proto, 0);
747   n_proto = caseproto_add_width (n_proto, 0);
748
749   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
750   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
751
752   for (i = 0 ; i < roc->n_vars; ++i)
753     {
754       struct casewriter *w = NULL;
755       struct casereader *r = NULL;
756
757       struct ccase *c;
758
759       struct ccase *cpos;
760       struct casereader *n_neg_reader ;
761       const struct variable *var = roc->vars[i];
762
763       struct casereader *neg ;
764       struct casereader *pos = casereader_clone (positives);
765
766       struct casereader *n_pos_reader =
767         process_positive_group (var, pos, dict, &rs[i]);
768
769       if ( negatives == NULL)
770         {
771           negatives = casewriter_make_reader (neg_wtr);
772         }
773
774       neg = casereader_clone (negatives);
775
776       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
777
778       /* Merge the n_pos and n_neg casereaders */
779       w = sort_create_writer (&up_ordering, n_proto);
780       for ( ; (cpos = casereader_read (n_pos_reader) ); case_unref (cpos))
781         {
782           struct ccase *pos_case = case_create (n_proto);
783           struct ccase *cneg;
784           const double jpos = case_data_idx (cpos, VALUE)->f;
785
786           while ((cneg = casereader_read (n_neg_reader)))
787             {
788               struct ccase *nc = case_create (n_proto);
789
790               const double jneg = case_data_idx (cneg, VALUE)->f;
791
792               case_data_rw_idx (nc, VALUE)->f = jneg;
793               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
794
795               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
796
797               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
798               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
799
800               casewriter_write (w, nc);
801
802               case_unref (cneg);
803               if ( jneg > jpos)
804                 break;
805             }
806
807           case_data_rw_idx (pos_case, VALUE)->f = jpos;
808           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
809           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
810           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
811           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
812
813           casewriter_write (w, pos_case);
814         }
815
816       casereader_destroy (n_pos_reader);
817       casereader_destroy (n_neg_reader);
818
819 /* These aren't used anymore */
820 #undef N_EQ
821 #undef N_PRED
822
823       r = casewriter_make_reader (w);
824
825       /* Propagate the N_POS_GT values from the positive cases
826          to the negative ones */
827       {
828         double prev_pos_gt = rs[i].n1;
829         w = sort_create_writer (&down_ordering, n_proto);
830
831         for ( ; (c = casereader_read (r) ); case_unref (c))
832           {
833             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
834             struct ccase *nc = case_clone (c);
835
836             if ( n_pos_gt == SYSMIS)
837               {
838                 n_pos_gt = prev_pos_gt;
839                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
840               }
841             
842             casewriter_write (w, nc);
843             prev_pos_gt = n_pos_gt;
844           }
845
846         casereader_destroy (r);
847         r = casewriter_make_reader (w);
848       }
849
850       /* Propagate the N_NEG_LT values from the negative cases
851          to the positive ones */
852       {
853         double prev_neg_lt = rs[i].n2;
854         w = sort_create_writer (&up_ordering, n_proto);
855
856         for ( ; (c = casereader_read (r) ); case_unref (c))
857           {
858             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
859             struct ccase *nc = case_clone (c);
860
861             if ( n_neg_lt == SYSMIS)
862               {
863                 n_neg_lt = prev_neg_lt;
864                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
865               }
866             
867             casewriter_write (w, nc);
868             prev_neg_lt = n_neg_lt;
869           }
870
871         casereader_destroy (r);
872         r = casewriter_make_reader (w);
873       }
874
875       {
876         struct ccase *prev_case = NULL;
877         for ( ; (c = casereader_read (r) ); case_unref (c))
878           {
879             struct ccase *next_case = casereader_peek (r, 0);
880
881             const double j = case_data_idx (c, VALUE)->f;
882             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
883             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
884             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
885             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
886
887             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
888               {
889                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
890                   {
891                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
892                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
893                   }
894
895                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
896                   {
897                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
898                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
899                   }
900               }
901
902             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
903               {
904                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
905
906                 rs[i].q1hat +=
907                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
908                 rs[i].q2hat +=
909                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
910
911               }
912
913             case_unref (next_case);
914             case_unref (prev_case);
915             prev_case = case_clone (c);
916           }
917         casereader_destroy (r);
918         case_unref (prev_case);
919
920         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
921         if ( roc->invert ) 
922           rs[i].auc = 1 - rs[i].auc;
923
924         if ( roc->bi_neg_exp )
925           {
926             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
927             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
928           }
929         else
930           {
931             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
932             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
933           }
934       }
935     }
936
937   casereader_destroy (positives);
938   casereader_destroy (negatives);
939
940   caseproto_unref (n_proto);
941   subcase_destroy (&up_ordering);
942   subcase_destroy (&down_ordering);
943
944   output_roc (rs, roc);
945  
946   for (i = 0 ; i < roc->n_vars; ++i)
947     casereader_destroy (rs[i].cutpoint_rdr);
948
949   free (rs);
950 }
951
952 static void
953 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
954 {
955   int i;
956   const int n_fields = roc->print_se ? 5 : 1;
957   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
958   const int n_rows = 2 + roc->n_vars;
959   struct tab_table *tbl = tab_create (n_cols, n_rows);
960
961   if ( roc->n_vars > 1)
962     tab_title (tbl, _("Area Under the Curve"));
963   else
964     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
965
966   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
967
968
969   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
970
971   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
972
973   tab_box (tbl,
974            TAL_2, TAL_2,
975            -1, TAL_1,
976            0, 0,
977            n_cols - 1,
978            n_rows - 1);
979
980   if ( roc->print_se )
981     {
982       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
983       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
984
985       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
986       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
987
988       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
989                              TAT_TITLE | TAB_CENTER,
990                              _("Asymp. %g%% Confidence Interval"), roc->ci);
991       tab_vline (tbl, 0, n_cols - 1, 0, 0);
992       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
993     }
994
995   if ( roc->n_vars > 1)
996     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
997
998   if ( roc->n_vars > 1)
999     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1000
1001
1002   for ( i = 0 ; i < roc->n_vars ; ++i )
1003     {
1004       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
1005
1006       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL, RC_OTHER);
1007
1008       if ( roc->print_se )
1009         {
1010           double se ;
1011           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1012                                       (12 * rs[i].n1 * rs[i].n2));
1013           double ci ;
1014           double yy ;
1015
1016           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
1017             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
1018
1019           se /= rs[i].n1 * rs[i].n2;
1020
1021           se = sqrt (se);
1022
1023           tab_double (tbl, n_cols - 4, 2 + i, 0,
1024                       se,
1025                       NULL, RC_OTHER);
1026
1027           ci = 1 - roc->ci / 100.0;
1028           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1029
1030           tab_double (tbl, n_cols - 2, 2 + i, 0,
1031                       rs[i].auc - yy,
1032                       NULL, RC_OTHER);
1033
1034           tab_double (tbl, n_cols - 1, 2 + i, 0,
1035                       rs[i].auc + yy,
1036                       NULL, RC_OTHER);
1037
1038           tab_double (tbl, n_cols - 3, 2 + i, 0,
1039                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1040                       NULL, RC_PVALUE);
1041         }
1042     }
1043
1044   tab_submit (tbl);
1045 }
1046
1047
1048 static void
1049 show_summary (const struct cmd_roc *roc)
1050 {
1051   const int n_cols = 3;
1052   const int n_rows = 4;
1053   struct tab_table *tbl = tab_create (n_cols, n_rows);
1054
1055   tab_title (tbl, _("Case Summary"));
1056
1057   tab_headers (tbl, 1, 0, 2, 0);
1058
1059   tab_box (tbl,
1060            TAL_2, TAL_2,
1061            -1, -1,
1062            0, 0,
1063            n_cols - 1,
1064            n_rows - 1);
1065
1066   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1067   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1068
1069
1070   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1071   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1072
1073
1074   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1075   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1076   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1077
1078   tab_joint_text (tbl, 1, 0, 2, 0,
1079                   TAT_TITLE | TAB_CENTER,
1080                   _("Valid N (listwise)"));
1081
1082
1083   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1084   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1085
1086
1087   tab_double (tbl, 1, 2, 0, roc->pos, NULL, RC_INTEGER);
1088   tab_double (tbl, 1, 3, 0, roc->neg, NULL, RC_INTEGER);
1089
1090   tab_double (tbl, 2, 2, 0, roc->pos_weighted, NULL, RC_OTHER);
1091   tab_double (tbl, 2, 3, 0, roc->neg_weighted, NULL, RC_OTHER);
1092
1093   tab_submit (tbl);
1094 }
1095
1096
1097 static void
1098 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1099 {
1100   int x = 1;
1101   int i;
1102   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1103   int n_rows = 1;
1104   struct tab_table *tbl ;
1105
1106   for (i = 0; i < roc->n_vars; ++i)
1107     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1108
1109   tbl = tab_create (n_cols, n_rows);
1110
1111   if ( roc->n_vars > 1)
1112     tab_title (tbl, _("Coordinates of the Curve"));
1113   else
1114     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1115
1116
1117   tab_headers (tbl, 1, 0, 1, 0);
1118
1119   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1120
1121   if ( roc->n_vars > 1)
1122     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1123
1124   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1125   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1126   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1127
1128   tab_box (tbl,
1129            TAL_2, TAL_2,
1130            -1, TAL_1,
1131            0, 0,
1132            n_cols - 1,
1133            n_rows - 1);
1134
1135   if ( roc->n_vars > 1)
1136     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1137
1138   for (i = 0; i < roc->n_vars; ++i)
1139     {
1140       struct ccase *cc;
1141       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1142
1143       if ( roc->n_vars > 1)
1144         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1145
1146       if ( i > 0)
1147         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1148
1149
1150       for (; (cc = casereader_read (r)) != NULL;
1151            case_unref (cc), x++)
1152         {
1153           const double se = case_data_idx (cc, ROC_TP)->f /
1154             (
1155              case_data_idx (cc, ROC_TP)->f
1156              +
1157              case_data_idx (cc, ROC_FN)->f
1158              );
1159
1160           const double sp = case_data_idx (cc, ROC_TN)->f /
1161             (
1162              case_data_idx (cc, ROC_TN)->f
1163              +
1164              case_data_idx (cc, ROC_FP)->f
1165              );
1166
1167           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1168                       var_get_print_format (roc->vars[i]), RC_OTHER);
1169
1170           tab_double (tbl, n_cols - 2, x, 0, se, NULL, RC_OTHER);
1171           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL, RC_OTHER);
1172         }
1173
1174       casereader_destroy (r);
1175     }
1176
1177   tab_submit (tbl);
1178 }
1179
1180
1181 static void
1182 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1183 {
1184   show_summary (roc);
1185
1186   if ( roc->curve )
1187     {
1188       struct roc_chart *rc;
1189       size_t i;
1190
1191       rc = roc_chart_create (roc->reference);
1192       for (i = 0; i < roc->n_vars; i++)
1193         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1194                            rs[i].cutpoint_rdr);
1195       roc_chart_submit (rc);
1196     }
1197
1198   show_auc (rs, roc);
1199
1200   if ( roc->print_coords )
1201     show_coords (rs, roc);
1202 }
1203