Fix crash when ROC was passed a non-number where a number was expected.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/chart-item.h"
37 #include "output/charts/roc-chart.h"
38 #include "output/tab.h"
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52   size_t state_var_width;
53
54   /* Plot the roc curve */
55   bool curve;
56   /* Plot the reference line */
57   bool reference;
58
59   double ci;
60
61   bool print_coords;
62   bool print_se;
63   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
64                       should be used */
65   enum mv_class exclude;
66
67   bool invert ; /* True iff a smaller test result variable indicates
68                    a positive result */
69
70   double pos;
71   double neg;
72   double pos_weighted;
73   double neg_weighted;
74 };
75
76 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
77
78 int
79 cmd_roc (struct lexer *lexer, struct dataset *ds)
80 {
81   struct cmd_roc roc ;
82   const struct dictionary *dict = dataset_dict (ds);
83
84   roc.vars = NULL;
85   roc.n_vars = 0;
86   roc.print_se = false;
87   roc.print_coords = false;
88   roc.exclude = MV_ANY;
89   roc.curve = true;
90   roc.reference = false;
91   roc.ci = 95;
92   roc.bi_neg_exp = false;
93   roc.invert = false;
94   roc.pos = roc.pos_weighted = 0;
95   roc.neg = roc.neg_weighted = 0;
96   roc.dict = dataset_dict (ds);
97   roc.state_var = NULL;
98   roc.state_var_width = -1;
99
100   lex_match (lexer, T_SLASH);
101   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103     goto error;
104
105   if ( ! lex_force_match (lexer, T_BY))
106     {
107       goto error;
108     }
109
110   roc.state_var = parse_variable (lexer, dict);
111   if (! roc.state_var)
112     {
113       goto error;
114     }
115
116   if ( !lex_force_match (lexer, T_LPAREN))
117     {
118       goto error;
119     }
120
121   roc.state_var_width = var_get_width (roc.state_var);
122   value_init (&roc.state_value, roc.state_var_width);
123   parse_value (lexer, &roc.state_value, roc.state_var);
124
125
126   if ( !lex_force_match (lexer, T_RPAREN))
127     {
128       goto error;
129     }
130
131   while (lex_token (lexer) != T_ENDCMD)
132     {
133       lex_match (lexer, T_SLASH);
134       if (lex_match_id (lexer, "MISSING"))
135         {
136           lex_match (lexer, T_EQUALS);
137           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
138             {
139               if (lex_match_id (lexer, "INCLUDE"))
140                 {
141                   roc.exclude = MV_SYSTEM;
142                 }
143               else if (lex_match_id (lexer, "EXCLUDE"))
144                 {
145                   roc.exclude = MV_ANY;
146                 }
147               else
148                 {
149                   lex_error (lexer, NULL);
150                   goto error;
151                 }
152             }
153         }
154       else if (lex_match_id (lexer, "PLOT"))
155         {
156           lex_match (lexer, T_EQUALS);
157           if (lex_match_id (lexer, "CURVE"))
158             {
159               roc.curve = true;
160               if (lex_match (lexer, T_LPAREN))
161                 {
162                   roc.reference = true;
163                   lex_force_match_id (lexer, "REFERENCE");
164                   lex_force_match (lexer, T_RPAREN);
165                 }
166             }
167           else if (lex_match_id (lexer, "NONE"))
168             {
169               roc.curve = false;
170             }
171           else
172             {
173               lex_error (lexer, NULL);
174               goto error;
175             }
176         }
177       else if (lex_match_id (lexer, "PRINT"))
178         {
179           lex_match (lexer, T_EQUALS);
180           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
181             {
182               if (lex_match_id (lexer, "SE"))
183                 {
184                   roc.print_se = true;
185                 }
186               else if (lex_match_id (lexer, "COORDINATES"))
187                 {
188                   roc.print_coords = true;
189                 }
190               else
191                 {
192                   lex_error (lexer, NULL);
193                   goto error;
194                 }
195             }
196         }
197       else if (lex_match_id (lexer, "CRITERIA"))
198         {
199           lex_match (lexer, T_EQUALS);
200           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
201             {
202               if (lex_match_id (lexer, "CUTOFF"))
203                 {
204                   lex_force_match (lexer, T_LPAREN);
205                   if (lex_match_id (lexer, "INCLUDE"))
206                     {
207                       roc.exclude = MV_SYSTEM;
208                     }
209                   else if (lex_match_id (lexer, "EXCLUDE"))
210                     {
211                       roc.exclude = MV_USER | MV_SYSTEM;
212                     }
213                   else
214                     {
215                       lex_error (lexer, NULL);
216                       goto error;
217                     }
218                   lex_force_match (lexer, T_RPAREN);
219                 }
220               else if (lex_match_id (lexer, "TESTPOS"))
221                 {
222                   lex_force_match (lexer, T_LPAREN);
223                   if (lex_match_id (lexer, "LARGE"))
224                     {
225                       roc.invert = false;
226                     }
227                   else if (lex_match_id (lexer, "SMALL"))
228                     {
229                       roc.invert = true;
230                     }
231                   else
232                     {
233                       lex_error (lexer, NULL);
234                       goto error;
235                     }
236                   lex_force_match (lexer, T_RPAREN);
237                 }
238               else if (lex_match_id (lexer, "CI"))
239                 {
240                   if (!lex_force_match (lexer, T_LPAREN))
241                     goto error;
242                   if (! lex_force_num (lexer))
243                     goto error;
244                   roc.ci = lex_number (lexer);
245                   lex_get (lexer);
246                   if (!lex_force_match (lexer, T_RPAREN))
247                     goto error;
248                 }
249               else if (lex_match_id (lexer, "DISTRIBUTION"))
250                 {
251                   if (!lex_force_match (lexer, T_LPAREN))
252                     goto error;
253                   if (lex_match_id (lexer, "FREE"))
254                     {
255                       roc.bi_neg_exp = false;
256                     }
257                   else if (lex_match_id (lexer, "NEGEXPO"))
258                     {
259                       roc.bi_neg_exp = true;
260                     }
261                   else
262                     {
263                       lex_error (lexer, NULL);
264                       goto error;
265                     }
266                   if (!lex_force_match (lexer, T_RPAREN))
267                     goto error;
268                 }
269               else
270                 {
271                   lex_error (lexer, NULL);
272                   goto error;
273                 }
274             }
275         }
276       else
277         {
278           lex_error (lexer, NULL);
279           break;
280         }
281     }
282
283   if ( ! run_roc (ds, &roc)) 
284     goto error;
285
286   if ( roc.state_var)
287     value_destroy (&roc.state_value, roc.state_var_width);
288   free (roc.vars);
289   return CMD_SUCCESS;
290
291  error:
292   if ( roc.state_var)
293     value_destroy (&roc.state_value, roc.state_var_width);
294   free (roc.vars);
295   return CMD_FAILURE;
296 }
297
298
299
300
301 static void
302 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
303
304
305 static int
306 run_roc (struct dataset *ds, struct cmd_roc *roc)
307 {
308   struct dictionary *dict = dataset_dict (ds);
309   bool ok;
310   struct casereader *group;
311
312   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
313   while (casegrouper_get_next_group (grouper, &group))
314     {
315       do_roc (roc, group, dataset_dict (ds));
316     }
317   ok = casegrouper_destroy (grouper);
318   ok = proc_commit (ds) && ok;
319
320   return ok;
321 }
322
323 #if 0
324 static void
325 dump_casereader (struct casereader *reader)
326 {
327   struct ccase *c;
328   struct casereader *r = casereader_clone (reader);
329
330   for ( ; (c = casereader_read (r) ); case_unref (c))
331     {
332       int i;
333       for (i = 0 ; i < case_get_value_cnt (c); ++i)
334         {
335           printf ("%g ", case_data_idx (c, i)->f);
336         }
337       printf ("\n");
338     }
339
340   casereader_destroy (r);
341 }
342 #endif
343
344
345 /* 
346    Return true iff the state variable indicates that C has positive actual state.
347
348    As a side effect, this function also accumulates the roc->{pos,neg} and 
349    roc->{pos,neg}_weighted counts.
350  */
351 static bool
352 match_positives (const struct ccase *c, void *aux)
353 {
354   struct cmd_roc *roc = aux;
355   const struct variable *wv = dict_get_weight (roc->dict);
356   const double weight = wv ? case_data (c, wv)->f : 1.0;
357
358   const bool positive =
359   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
360     var_get_width (roc->state_var)));
361
362   if ( positive )
363     {
364       roc->pos++;
365       roc->pos_weighted += weight;
366     }
367   else
368     {
369       roc->neg++;
370       roc->neg_weighted += weight;
371     }
372
373   return positive;
374 }
375
376
377 #define VALUE  0
378 #define N_EQ   1
379 #define N_PRED 2
380
381 /* Some intermediate state for calculating the cutpoints and the 
382    standard error values */
383 struct roc_state
384 {
385   double auc;  /* Area under the curve */
386
387   double n1;  /* total weight of positives */
388   double n2;  /* total weight of negatives */
389
390   /* intermediates for standard error */
391   double q1hat; 
392   double q2hat;
393
394   /* intermediates for cutpoints */
395   struct casewriter *cutpoint_wtr;
396   struct casereader *cutpoint_rdr;
397   double prev_result;
398   double min;
399   double max;
400 };
401
402 /* 
403    Return a new casereader based upon CUTPOINT_RDR.
404    The number of "positive" cases are placed into
405    the position TRUE_INDEX, and the number of "negative" cases
406    into FALSE_INDEX.
407    POS_COND and RESULT determine the semantics of what is 
408    "positive".
409    WEIGHT is the value of a single count.
410  */
411 static struct casereader *
412 accumulate_counts (struct casereader *input,
413                    double result, double weight, 
414                    bool (*pos_cond) (double, double),
415                    int true_index, int false_index)
416 {
417   const struct caseproto *proto = casereader_get_proto (input);
418   struct casewriter *w =
419     autopaging_writer_create (proto);
420   struct ccase *cpc;
421   double prev_cp = SYSMIS;
422
423   for ( ; (cpc = casereader_read (input) ); case_unref (cpc))
424     {
425       struct ccase *new_case;
426       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
427
428       assert (cp != SYSMIS);
429
430       /* We don't want duplicates here */
431       if ( cp == prev_cp )
432         continue;
433
434       new_case = case_clone (cpc);
435
436       if ( pos_cond (result, cp))
437         case_data_rw_idx (new_case, true_index)->f += weight;
438       else
439         case_data_rw_idx (new_case, false_index)->f += weight;
440
441       prev_cp = cp;
442
443       casewriter_write (w, new_case);
444     }
445   casereader_destroy (input);
446
447   return casewriter_make_reader (w);
448 }
449
450
451
452 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
453
454 /*
455   This function does 3 things:
456
457   1. Counts the number of cases which are equal to every other case in READER,
458   and those cases for which the relationship between it and every other case
459   satifies PRED (normally either > or <).  VAR is variable defining a case's value
460   for this purpose.
461
462   2. Counts the number of true and false cases in reader, and populates
463   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
464   which receive these values.  POS_COND is the condition defining true
465   and false.
466   
467   3. CC is filled with the cumulative weight of all cases of READER.
468 */
469 static struct casereader *
470 process_group (const struct variable *var, struct casereader *reader,
471                bool (*pred) (double, double),
472                const struct dictionary *dict,
473                double *cc,
474                struct casereader **cutpoint_rdr, 
475                bool (*pos_cond) (double, double),
476                int true_index,
477                int false_index)
478 {
479   const struct variable *w = dict_get_weight (dict);
480
481   struct casereader *r1 =
482     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
483
484   const int weight_idx  = w ? var_get_case_index (w) :
485     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
486   
487   struct ccase *c1;
488
489   struct casereader *rclone = casereader_clone (r1);
490   struct casewriter *wtr;
491   struct caseproto *proto = caseproto_create ();
492
493   proto = caseproto_add_width (proto, 0);
494   proto = caseproto_add_width (proto, 0);
495   proto = caseproto_add_width (proto, 0);
496
497   wtr = autopaging_writer_create (proto);  
498
499   *cc = 0;
500
501   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
502     {
503       struct ccase *new_case = case_create (proto);
504       struct ccase *c2;
505       struct casereader *r2 = casereader_clone (rclone);
506
507       const double weight1 = case_data_idx (c1, weight_idx)->f;
508       const double d1 = case_data (c1, var)->f;
509       double n_eq = 0.0;
510       double n_pred = 0.0;
511
512       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
513                                          pos_cond,
514                                          true_index, false_index);
515
516       *cc += weight1;
517
518       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
519         {
520           const double d2 = case_data (c2, var)->f;
521           const double weight2 = case_data_idx (c2, weight_idx)->f;
522
523           if ( d1 == d2 )
524             {
525               n_eq += weight2;
526               continue;
527             }
528           else  if ( pred (d2, d1))
529             {
530               n_pred += weight2;
531             }
532         }
533
534       case_data_rw_idx (new_case, VALUE)->f = d1;
535       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
536       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
537
538       casewriter_write (wtr, new_case);
539
540       casereader_destroy (r2);
541     }
542
543   
544   casereader_destroy (r1);
545   casereader_destroy (rclone);
546
547   caseproto_unref (proto);
548
549   return casewriter_make_reader (wtr);
550 }
551
552 /* Some more indeces into case data */
553 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
554 #define N_POS_GT 2  /* number of postive cases with values greater than n */
555 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
556 #define N_NEG_LT 4  /* number of negative cases with values less than n */
557
558 static bool
559 gt (double d1, double d2)
560 {
561   return d1 > d2;
562 }
563
564
565 static bool
566 ge (double d1, double d2)
567 {
568   return d1 > d2;
569 }
570
571 static bool
572 lt (double d1, double d2)
573 {
574   return d1 < d2;
575 }
576
577
578 /*
579   Return a casereader with width 3,
580   populated with cases based upon READER.
581   The cases will have the values:
582   (N, number of cases equal to N, number of cases greater than N)
583   As a side effect, update RS->n1 with the number of positive cases.
584 */
585 static struct casereader *
586 process_positive_group (const struct variable *var, struct casereader *reader,
587                         const struct dictionary *dict,
588                         struct roc_state *rs)
589 {
590   return process_group (var, reader, gt, dict, &rs->n1,
591                         &rs->cutpoint_rdr,
592                         ge,
593                         ROC_TP, ROC_FN);
594 }
595
596 /*
597   Return a casereader with width 3,
598   populated with cases based upon READER.
599   The cases will have the values:
600   (N, number of cases equal to N, number of cases less than N)
601   As a side effect, update RS->n2 with the number of negative cases.
602 */
603 static struct casereader *
604 process_negative_group (const struct variable *var, struct casereader *reader,
605                         const struct dictionary *dict,
606                         struct roc_state *rs)
607 {
608   return process_group (var, reader, lt, dict, &rs->n2,
609                         &rs->cutpoint_rdr,
610                         lt,
611                         ROC_TN, ROC_FP);
612 }
613
614
615
616
617 static void
618 append_cutpoint (struct casewriter *writer, double cutpoint)
619 {
620   struct ccase *cc = case_create (casewriter_get_proto (writer));
621
622   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
623   case_data_rw_idx (cc, ROC_TP)->f = 0;
624   case_data_rw_idx (cc, ROC_FN)->f = 0;
625   case_data_rw_idx (cc, ROC_TN)->f = 0;
626   case_data_rw_idx (cc, ROC_FP)->f = 0;
627
628   casewriter_write (writer, cc);
629 }
630
631
632 /* 
633    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
634    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
635    reader will be populated with its final number of cases.
636    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
637    value.  The other entries will be initialised to zero.
638 */
639 static void
640 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
641 {
642   int i;
643   struct casereader *r = casereader_clone (input);
644   struct ccase *c;
645
646   {
647     struct caseproto *proto = caseproto_create ();
648     struct subcase ordering;
649     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
650
651     proto = caseproto_add_width (proto, 0); /* cutpoint */
652     proto = caseproto_add_width (proto, 0); /* ROC_TP */
653     proto = caseproto_add_width (proto, 0); /* ROC_FN */
654     proto = caseproto_add_width (proto, 0); /* ROC_TN */
655     proto = caseproto_add_width (proto, 0); /* ROC_FP */
656
657     for (i = 0 ; i < roc->n_vars; ++i)
658       {
659         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
660         rs[i].prev_result = SYSMIS;
661         rs[i].max = -DBL_MAX;
662         rs[i].min = DBL_MAX;
663       }
664
665     caseproto_unref (proto);
666     subcase_destroy (&ordering);
667   }
668
669   for (; (c = casereader_read (r)) != NULL; case_unref (c))
670     {
671       for (i = 0 ; i < roc->n_vars; ++i)
672         {
673           const union value *v = case_data (c, roc->vars[i]); 
674           const double result = v->f;
675
676           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
677             continue;
678
679           minimize (&rs[i].min, result);
680           maximize (&rs[i].max, result);
681
682           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
683             {
684               const double mean = (result + rs[i].prev_result ) / 2.0;
685               append_cutpoint (rs[i].cutpoint_wtr, mean);
686             }
687
688           rs[i].prev_result = result;
689         }
690     }
691   casereader_destroy (r);
692
693
694   /* Append the min and max cutpoints */
695   for (i = 0 ; i < roc->n_vars; ++i)
696     {
697       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
698       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
699
700       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
701     }
702 }
703
704 static void
705 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
706 {
707   int i;
708
709   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
710
711   struct casereader *negatives = NULL;
712   struct casereader *positives = NULL;
713
714   struct caseproto *n_proto = NULL;
715
716   struct subcase up_ordering;
717   struct subcase down_ordering;
718
719   struct casewriter *neg_wtr = NULL;
720
721   struct casereader *input = casereader_create_filter_missing (reader,
722                                                                roc->vars, roc->n_vars,
723                                                                roc->exclude,
724                                                                NULL,
725                                                                NULL);
726
727   input = casereader_create_filter_missing (input,
728                                             &roc->state_var, 1,
729                                             roc->exclude,
730                                             NULL,
731                                             NULL);
732
733   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
734
735   prepare_cutpoints (roc, rs, input);
736
737
738   /* Separate the positive actual state cases from the negative ones */
739   positives = 
740     casereader_create_filter_func (input,
741                                    match_positives,
742                                    NULL,
743                                    roc,
744                                    neg_wtr);
745
746   n_proto = caseproto_create ();
747       
748   n_proto = caseproto_add_width (n_proto, 0);
749   n_proto = caseproto_add_width (n_proto, 0);
750   n_proto = caseproto_add_width (n_proto, 0);
751   n_proto = caseproto_add_width (n_proto, 0);
752   n_proto = caseproto_add_width (n_proto, 0);
753
754   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
755   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
756
757   for (i = 0 ; i < roc->n_vars; ++i)
758     {
759       struct casewriter *w = NULL;
760       struct casereader *r = NULL;
761
762       struct ccase *c;
763
764       struct ccase *cpos;
765       struct casereader *n_neg_reader ;
766       const struct variable *var = roc->vars[i];
767
768       struct casereader *neg ;
769       struct casereader *pos = casereader_clone (positives);
770
771       struct casereader *n_pos_reader =
772         process_positive_group (var, pos, dict, &rs[i]);
773
774       if ( negatives == NULL)
775         {
776           negatives = casewriter_make_reader (neg_wtr);
777         }
778
779       neg = casereader_clone (negatives);
780
781       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
782
783       /* Merge the n_pos and n_neg casereaders */
784       w = sort_create_writer (&up_ordering, n_proto);
785       for ( ; (cpos = casereader_read (n_pos_reader) ); case_unref (cpos))
786         {
787           struct ccase *pos_case = case_create (n_proto);
788           struct ccase *cneg;
789           const double jpos = case_data_idx (cpos, VALUE)->f;
790
791           while ((cneg = casereader_read (n_neg_reader)))
792             {
793               struct ccase *nc = case_create (n_proto);
794
795               const double jneg = case_data_idx (cneg, VALUE)->f;
796
797               case_data_rw_idx (nc, VALUE)->f = jneg;
798               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
799
800               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
801
802               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
803               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
804
805               casewriter_write (w, nc);
806
807               case_unref (cneg);
808               if ( jneg > jpos)
809                 break;
810             }
811
812           case_data_rw_idx (pos_case, VALUE)->f = jpos;
813           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
814           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
815           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
816           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
817
818           casewriter_write (w, pos_case);
819         }
820
821       casereader_destroy (n_pos_reader);
822       casereader_destroy (n_neg_reader);
823
824 /* These aren't used anymore */
825 #undef N_EQ
826 #undef N_PRED
827
828       r = casewriter_make_reader (w);
829
830       /* Propagate the N_POS_GT values from the positive cases
831          to the negative ones */
832       {
833         double prev_pos_gt = rs[i].n1;
834         w = sort_create_writer (&down_ordering, n_proto);
835
836         for ( ; (c = casereader_read (r) ); case_unref (c))
837           {
838             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
839             struct ccase *nc = case_clone (c);
840
841             if ( n_pos_gt == SYSMIS)
842               {
843                 n_pos_gt = prev_pos_gt;
844                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
845               }
846             
847             casewriter_write (w, nc);
848             prev_pos_gt = n_pos_gt;
849           }
850
851         casereader_destroy (r);
852         r = casewriter_make_reader (w);
853       }
854
855       /* Propagate the N_NEG_LT values from the negative cases
856          to the positive ones */
857       {
858         double prev_neg_lt = rs[i].n2;
859         w = sort_create_writer (&up_ordering, n_proto);
860
861         for ( ; (c = casereader_read (r) ); case_unref (c))
862           {
863             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
864             struct ccase *nc = case_clone (c);
865
866             if ( n_neg_lt == SYSMIS)
867               {
868                 n_neg_lt = prev_neg_lt;
869                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
870               }
871             
872             casewriter_write (w, nc);
873             prev_neg_lt = n_neg_lt;
874           }
875
876         casereader_destroy (r);
877         r = casewriter_make_reader (w);
878       }
879
880       {
881         struct ccase *prev_case = NULL;
882         for ( ; (c = casereader_read (r) ); case_unref (c))
883           {
884             struct ccase *next_case = casereader_peek (r, 0);
885
886             const double j = case_data_idx (c, VALUE)->f;
887             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
888             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
889             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
890             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
891
892             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
893               {
894                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
895                   {
896                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
897                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
898                   }
899
900                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
901                   {
902                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
903                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
904                   }
905               }
906
907             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
908               {
909                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
910
911                 rs[i].q1hat +=
912                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
913                 rs[i].q2hat +=
914                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
915
916               }
917
918             case_unref (next_case);
919             case_unref (prev_case);
920             prev_case = case_clone (c);
921           }
922         casereader_destroy (r);
923         case_unref (prev_case);
924
925         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
926         if ( roc->invert ) 
927           rs[i].auc = 1 - rs[i].auc;
928
929         if ( roc->bi_neg_exp )
930           {
931             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
932             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
933           }
934         else
935           {
936             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
937             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
938           }
939       }
940     }
941
942   casereader_destroy (positives);
943   casereader_destroy (negatives);
944
945   caseproto_unref (n_proto);
946   subcase_destroy (&up_ordering);
947   subcase_destroy (&down_ordering);
948
949   output_roc (rs, roc);
950  
951   for (i = 0 ; i < roc->n_vars; ++i)
952     casereader_destroy (rs[i].cutpoint_rdr);
953
954   free (rs);
955 }
956
957 static void
958 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
959 {
960   int i;
961   const int n_fields = roc->print_se ? 5 : 1;
962   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
963   const int n_rows = 2 + roc->n_vars;
964   struct tab_table *tbl = tab_create (n_cols, n_rows);
965
966   if ( roc->n_vars > 1)
967     tab_title (tbl, _("Area Under the Curve"));
968   else
969     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
970
971   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
972
973
974   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
975
976   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
977
978   tab_box (tbl,
979            TAL_2, TAL_2,
980            -1, TAL_1,
981            0, 0,
982            n_cols - 1,
983            n_rows - 1);
984
985   if ( roc->print_se )
986     {
987       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
988       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
989
990       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
991       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
992
993       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
994                              TAT_TITLE | TAB_CENTER,
995                              _("Asymp. %g%% Confidence Interval"), roc->ci);
996       tab_vline (tbl, 0, n_cols - 1, 0, 0);
997       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
998     }
999
1000   if ( roc->n_vars > 1)
1001     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
1002
1003   if ( roc->n_vars > 1)
1004     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1005
1006
1007   for ( i = 0 ; i < roc->n_vars ; ++i )
1008     {
1009       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
1010
1011       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL, RC_OTHER);
1012
1013       if ( roc->print_se )
1014         {
1015           double se ;
1016           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1017                                       (12 * rs[i].n1 * rs[i].n2));
1018           double ci ;
1019           double yy ;
1020
1021           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
1022             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
1023
1024           se /= rs[i].n1 * rs[i].n2;
1025
1026           se = sqrt (se);
1027
1028           tab_double (tbl, n_cols - 4, 2 + i, 0,
1029                       se,
1030                       NULL, RC_OTHER);
1031
1032           ci = 1 - roc->ci / 100.0;
1033           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1034
1035           tab_double (tbl, n_cols - 2, 2 + i, 0,
1036                       rs[i].auc - yy,
1037                       NULL, RC_OTHER);
1038
1039           tab_double (tbl, n_cols - 1, 2 + i, 0,
1040                       rs[i].auc + yy,
1041                       NULL, RC_OTHER);
1042
1043           tab_double (tbl, n_cols - 3, 2 + i, 0,
1044                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1045                       NULL, RC_PVALUE);
1046         }
1047     }
1048
1049   tab_submit (tbl);
1050 }
1051
1052
1053 static void
1054 show_summary (const struct cmd_roc *roc)
1055 {
1056   const int n_cols = 3;
1057   const int n_rows = 4;
1058   struct tab_table *tbl = tab_create (n_cols, n_rows);
1059
1060   tab_title (tbl, _("Case Summary"));
1061
1062   tab_headers (tbl, 1, 0, 2, 0);
1063
1064   tab_box (tbl,
1065            TAL_2, TAL_2,
1066            -1, -1,
1067            0, 0,
1068            n_cols - 1,
1069            n_rows - 1);
1070
1071   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1072   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1073
1074
1075   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1076   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1077
1078
1079   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1080   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1081   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1082
1083   tab_joint_text (tbl, 1, 0, 2, 0,
1084                   TAT_TITLE | TAB_CENTER,
1085                   _("Valid N (listwise)"));
1086
1087
1088   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1089   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1090
1091
1092   tab_double (tbl, 1, 2, 0, roc->pos, NULL, RC_INTEGER);
1093   tab_double (tbl, 1, 3, 0, roc->neg, NULL, RC_INTEGER);
1094
1095   tab_double (tbl, 2, 2, 0, roc->pos_weighted, NULL, RC_OTHER);
1096   tab_double (tbl, 2, 3, 0, roc->neg_weighted, NULL, RC_OTHER);
1097
1098   tab_submit (tbl);
1099 }
1100
1101
1102 static void
1103 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1104 {
1105   int x = 1;
1106   int i;
1107   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1108   int n_rows = 1;
1109   struct tab_table *tbl ;
1110
1111   for (i = 0; i < roc->n_vars; ++i)
1112     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1113
1114   tbl = tab_create (n_cols, n_rows);
1115
1116   if ( roc->n_vars > 1)
1117     tab_title (tbl, _("Coordinates of the Curve"));
1118   else
1119     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1120
1121
1122   tab_headers (tbl, 1, 0, 1, 0);
1123
1124   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1125
1126   if ( roc->n_vars > 1)
1127     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1128
1129   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1130   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1131   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1132
1133   tab_box (tbl,
1134            TAL_2, TAL_2,
1135            -1, TAL_1,
1136            0, 0,
1137            n_cols - 1,
1138            n_rows - 1);
1139
1140   if ( roc->n_vars > 1)
1141     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1142
1143   for (i = 0; i < roc->n_vars; ++i)
1144     {
1145       struct ccase *cc;
1146       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1147
1148       if ( roc->n_vars > 1)
1149         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1150
1151       if ( i > 0)
1152         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1153
1154
1155       for (; (cc = casereader_read (r)) != NULL;
1156            case_unref (cc), x++)
1157         {
1158           const double se = case_data_idx (cc, ROC_TP)->f /
1159             (
1160              case_data_idx (cc, ROC_TP)->f
1161              +
1162              case_data_idx (cc, ROC_FN)->f
1163              );
1164
1165           const double sp = case_data_idx (cc, ROC_TN)->f /
1166             (
1167              case_data_idx (cc, ROC_TN)->f
1168              +
1169              case_data_idx (cc, ROC_FP)->f
1170              );
1171
1172           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1173                       var_get_print_format (roc->vars[i]), RC_OTHER);
1174
1175           tab_double (tbl, n_cols - 2, x, 0, se, NULL, RC_OTHER);
1176           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL, RC_OTHER);
1177         }
1178
1179       casereader_destroy (r);
1180     }
1181
1182   tab_submit (tbl);
1183 }
1184
1185
1186 static void
1187 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1188 {
1189   show_summary (roc);
1190
1191   if ( roc->curve )
1192     {
1193       struct roc_chart *rc;
1194       size_t i;
1195
1196       rc = roc_chart_create (roc->reference);
1197       for (i = 0; i < roc->n_vars; i++)
1198         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1199                            rs[i].cutpoint_rdr);
1200       roc_chart_submit (rc);
1201     }
1202
1203   show_auc (rs, roc);
1204
1205   if ( roc->print_coords )
1206     show_coords (rs, roc);
1207 }
1208