Rewrite PSPP output engine.
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <language/stats/roc.h>
20
21 #include <data/casegrouper.h>
22 #include <data/casereader.h>
23 #include <data/casewriter.h>
24 #include <data/dictionary.h>
25 #include <data/format.h>
26 #include <data/procedure.h>
27 #include <data/subcase.h>
28 #include <language/command.h>
29 #include <language/lexer/lexer.h>
30 #include <language/lexer/value-parser.h>
31 #include <language/lexer/variable-parser.h>
32 #include <libpspp/misc.h>
33 #include <math/sort.h>
34 #include <output/chart-item.h>
35 #include <output/charts/roc-chart.h>
36 #include <output/tab.h>
37
38 #include <gsl/gsl_cdf.h>
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var ;
51   union value state_value;
52
53   /* Plot the roc curve */
54   bool curve;
55   /* Plot the reference line */
56   bool reference;
57
58   double ci;
59
60   bool print_coords;
61   bool print_se;
62   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
63                       should be used */
64   enum mv_class exclude;
65
66   bool invert ; /* True iff a smaller test result variable indicates
67                    a positive result */
68
69   double pos;
70   double neg;
71   double pos_weighted;
72   double neg_weighted;
73 };
74
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
76
77 int
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
79 {
80   struct cmd_roc roc ;
81   const struct dictionary *dict = dataset_dict (ds);
82
83   roc.vars = NULL;
84   roc.n_vars = 0;
85   roc.print_se = false;
86   roc.print_coords = false;
87   roc.exclude = MV_ANY;
88   roc.curve = true;
89   roc.reference = false;
90   roc.ci = 95;
91   roc.bi_neg_exp = false;
92   roc.invert = false;
93   roc.pos = roc.pos_weighted = 0;
94   roc.neg = roc.neg_weighted = 0;
95   roc.dict = dataset_dict (ds);
96
97   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
98                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
99     goto error;
100
101   if ( ! lex_force_match (lexer, T_BY))
102     {
103       goto error;
104     }
105
106   roc.state_var = parse_variable (lexer, dict);
107
108   if ( !lex_force_match (lexer, '('))
109     {
110       goto error;
111     }
112
113   value_init (&roc.state_value, var_get_width (roc.state_var));
114   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
115
116
117   if ( !lex_force_match (lexer, ')'))
118     {
119       goto error;
120     }
121
122
123   while (lex_token (lexer) != '.')
124     {
125       lex_match (lexer, '/');
126       if (lex_match_id (lexer, "MISSING"))
127         {
128           lex_match (lexer, '=');
129           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
130             {
131               if (lex_match_id (lexer, "INCLUDE"))
132                 {
133                   roc.exclude = MV_SYSTEM;
134                 }
135               else if (lex_match_id (lexer, "EXCLUDE"))
136                 {
137                   roc.exclude = MV_ANY;
138                 }
139               else
140                 {
141                   lex_error (lexer, NULL);
142                   goto error;
143                 }
144             }
145         }
146       else if (lex_match_id (lexer, "PLOT"))
147         {
148           lex_match (lexer, '=');
149           if (lex_match_id (lexer, "CURVE"))
150             {
151               roc.curve = true;
152               if (lex_match (lexer, '('))
153                 {
154                   roc.reference = true;
155                   lex_force_match_id (lexer, "REFERENCE");
156                   lex_force_match (lexer, ')');
157                 }
158             }
159           else if (lex_match_id (lexer, "NONE"))
160             {
161               roc.curve = false;
162             }
163           else
164             {
165               lex_error (lexer, NULL);
166               goto error;
167             }
168         }
169       else if (lex_match_id (lexer, "PRINT"))
170         {
171           lex_match (lexer, '=');
172           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
173             {
174               if (lex_match_id (lexer, "SE"))
175                 {
176                   roc.print_se = true;
177                 }
178               else if (lex_match_id (lexer, "COORDINATES"))
179                 {
180                   roc.print_coords = true;
181                 }
182               else
183                 {
184                   lex_error (lexer, NULL);
185                   goto error;
186                 }
187             }
188         }
189       else if (lex_match_id (lexer, "CRITERIA"))
190         {
191           lex_match (lexer, '=');
192           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
193             {
194               if (lex_match_id (lexer, "CUTOFF"))
195                 {
196                   lex_force_match (lexer, '(');
197                   if (lex_match_id (lexer, "INCLUDE"))
198                     {
199                       roc.exclude = MV_SYSTEM;
200                     }
201                   else if (lex_match_id (lexer, "EXCLUDE"))
202                     {
203                       roc.exclude = MV_USER | MV_SYSTEM;
204                     }
205                   else
206                     {
207                       lex_error (lexer, NULL);
208                       goto error;
209                     }
210                   lex_force_match (lexer, ')');
211                 }
212               else if (lex_match_id (lexer, "TESTPOS"))
213                 {
214                   lex_force_match (lexer, '(');
215                   if (lex_match_id (lexer, "LARGE"))
216                     {
217                       roc.invert = false;
218                     }
219                   else if (lex_match_id (lexer, "SMALL"))
220                     {
221                       roc.invert = true;
222                     }
223                   else
224                     {
225                       lex_error (lexer, NULL);
226                       goto error;
227                     }
228                   lex_force_match (lexer, ')');
229                 }
230               else if (lex_match_id (lexer, "CI"))
231                 {
232                   lex_force_match (lexer, '(');
233                   lex_force_num (lexer);
234                   roc.ci = lex_number (lexer);
235                   lex_get (lexer);
236                   lex_force_match (lexer, ')');
237                 }
238               else if (lex_match_id (lexer, "DISTRIBUTION"))
239                 {
240                   lex_force_match (lexer, '(');
241                   if (lex_match_id (lexer, "FREE"))
242                     {
243                       roc.bi_neg_exp = false;
244                     }
245                   else if (lex_match_id (lexer, "NEGEXPO"))
246                     {
247                       roc.bi_neg_exp = true;
248                     }
249                   else
250                     {
251                       lex_error (lexer, NULL);
252                       goto error;
253                     }
254                   lex_force_match (lexer, ')');
255                 }
256               else
257                 {
258                   lex_error (lexer, NULL);
259                   goto error;
260                 }
261             }
262         }
263       else
264         {
265           lex_error (lexer, NULL);
266           break;
267         }
268     }
269
270   if ( ! run_roc (ds, &roc)) 
271     goto error;
272
273   value_destroy (&roc.state_value, var_get_width (roc.state_var));
274   free (roc.vars);
275   return CMD_SUCCESS;
276
277  error:
278   value_destroy (&roc.state_value, var_get_width (roc.state_var));
279   free (roc.vars);
280   return CMD_FAILURE;
281 }
282
283
284
285
286 static void
287 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
288
289
290 static int
291 run_roc (struct dataset *ds, struct cmd_roc *roc)
292 {
293   struct dictionary *dict = dataset_dict (ds);
294   bool ok;
295   struct casereader *group;
296
297   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
298   while (casegrouper_get_next_group (grouper, &group))
299     {
300       do_roc (roc, group, dataset_dict (ds));
301     }
302   ok = casegrouper_destroy (grouper);
303   ok = proc_commit (ds) && ok;
304
305   return ok;
306 }
307
308 #if 0
309 static void
310 dump_casereader (struct casereader *reader)
311 {
312   struct ccase *c;
313   struct casereader *r = casereader_clone (reader);
314
315   for ( ; (c = casereader_read (r) ); case_unref (c))
316     {
317       int i;
318       for (i = 0 ; i < case_get_value_cnt (c); ++i)
319         {
320           printf ("%g ", case_data_idx (c, i)->f);
321         }
322       printf ("\n");
323     }
324
325   casereader_destroy (r);
326 }
327 #endif
328
329
330 /* 
331    Return true iff the state variable indicates that C has positive actual state.
332
333    As a side effect, this function also accumulates the roc->{pos,neg} and 
334    roc->{pos,neg}_weighted counts.
335  */
336 static bool
337 match_positives (const struct ccase *c, void *aux)
338 {
339   struct cmd_roc *roc = aux;
340   const struct variable *wv = dict_get_weight (roc->dict);
341   const double weight = wv ? case_data (c, wv)->f : 1.0;
342
343   const bool positive =
344   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
345     var_get_width (roc->state_var)));
346
347   if ( positive )
348     {
349       roc->pos++;
350       roc->pos_weighted += weight;
351     }
352   else
353     {
354       roc->neg++;
355       roc->neg_weighted += weight;
356     }
357
358   return positive;
359 }
360
361
362 #define VALUE  0
363 #define N_EQ   1
364 #define N_PRED 2
365
366 /* Some intermediate state for calculating the cutpoints and the 
367    standard error values */
368 struct roc_state
369 {
370   double auc;  /* Area under the curve */
371
372   double n1;  /* total weight of positives */
373   double n2;  /* total weight of negatives */
374
375   /* intermediates for standard error */
376   double q1hat; 
377   double q2hat;
378
379   /* intermediates for cutpoints */
380   struct casewriter *cutpoint_wtr;
381   struct casereader *cutpoint_rdr;
382   double prev_result;
383   double min;
384   double max;
385 };
386
387 /* 
388    Return a new casereader based upon CUTPOINT_RDR.
389    The number of "positive" cases are placed into
390    the position TRUE_INDEX, and the number of "negative" cases
391    into FALSE_INDEX.
392    POS_COND and RESULT determine the semantics of what is 
393    "positive".
394    WEIGHT is the value of a single count.
395  */
396 static struct casereader *
397 accumulate_counts (struct casereader *cutpoint_rdr, 
398                    double result, double weight, 
399                    bool (*pos_cond) (double, double),
400                    int true_index, int false_index)
401 {
402   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
403   struct casewriter *w =
404     autopaging_writer_create (proto);
405   struct casereader *r = casereader_clone (cutpoint_rdr);
406   struct ccase *cpc;
407   double prev_cp = SYSMIS;
408
409   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
410     {
411       struct ccase *new_case;
412       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
413
414       assert (cp != SYSMIS);
415
416       /* We don't want duplicates here */
417       if ( cp == prev_cp )
418         continue;
419
420       new_case = case_clone (cpc);
421
422       if ( pos_cond (result, cp))
423         case_data_rw_idx (new_case, true_index)->f += weight;
424       else
425         case_data_rw_idx (new_case, false_index)->f += weight;
426
427       prev_cp = cp;
428
429       casewriter_write (w, new_case);
430     }
431   casereader_destroy (r);
432
433   return casewriter_make_reader (w);
434 }
435
436
437
438 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
439
440 /*
441   This function does 3 things:
442
443   1. Counts the number of cases which are equal to every other case in READER,
444   and those cases for which the relationship between it and every other case
445   satifies PRED (normally either > or <).  VAR is variable defining a case's value
446   for this purpose.
447
448   2. Counts the number of true and false cases in reader, and populates
449   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
450   which receive these values.  POS_COND is the condition defining true
451   and false.
452   
453   3. CC is filled with the cumulative weight of all cases of READER.
454 */
455 static struct casereader *
456 process_group (const struct variable *var, struct casereader *reader,
457                bool (*pred) (double, double),
458                const struct dictionary *dict,
459                double *cc,
460                struct casereader **cutpoint_rdr, 
461                bool (*pos_cond) (double, double),
462                int true_index,
463                int false_index)
464 {
465   const struct variable *w = dict_get_weight (dict);
466
467   struct casereader *r1 =
468     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
469
470   const int weight_idx  = w ? var_get_case_index (w) :
471     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
472   
473   struct ccase *c1;
474
475   struct casereader *rclone = casereader_clone (r1);
476   struct casewriter *wtr;
477   struct caseproto *proto = caseproto_create ();
478
479   proto = caseproto_add_width (proto, 0);
480   proto = caseproto_add_width (proto, 0);
481   proto = caseproto_add_width (proto, 0);
482
483   wtr = autopaging_writer_create (proto);  
484
485   *cc = 0;
486
487   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
488     {
489       struct ccase *new_case = case_create (proto);
490       struct ccase *c2;
491       struct casereader *r2 = casereader_clone (rclone);
492
493       const double weight1 = case_data_idx (c1, weight_idx)->f;
494       const double d1 = case_data (c1, var)->f;
495       double n_eq = 0.0;
496       double n_pred = 0.0;
497
498       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
499                                          pos_cond,
500                                          true_index, false_index);
501
502       *cc += weight1;
503
504       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
505         {
506           const double d2 = case_data (c2, var)->f;
507           const double weight2 = case_data_idx (c2, weight_idx)->f;
508
509           if ( d1 == d2 )
510             {
511               n_eq += weight2;
512               continue;
513             }
514           else  if ( pred (d2, d1))
515             {
516               n_pred += weight2;
517             }
518         }
519
520       case_data_rw_idx (new_case, VALUE)->f = d1;
521       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
522       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
523
524       casewriter_write (wtr, new_case);
525
526       casereader_destroy (r2);
527     }
528
529   casereader_destroy (r1);
530   casereader_destroy (rclone);
531
532   return casewriter_make_reader (wtr);
533 }
534
535 /* Some more indeces into case data */
536 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
537 #define N_POS_GT 2  /* number of postive cases with values greater than n */
538 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
539 #define N_NEG_LT 4  /* number of negative cases with values less than n */
540
541 static bool
542 gt (double d1, double d2)
543 {
544   return d1 > d2;
545 }
546
547
548 static bool
549 ge (double d1, double d2)
550 {
551   return d1 > d2;
552 }
553
554 static bool
555 lt (double d1, double d2)
556 {
557   return d1 < d2;
558 }
559
560
561 /*
562   Return a casereader with width 3,
563   populated with cases based upon READER.
564   The cases will have the values:
565   (N, number of cases equal to N, number of cases greater than N)
566   As a side effect, update RS->n1 with the number of positive cases.
567 */
568 static struct casereader *
569 process_positive_group (const struct variable *var, struct casereader *reader,
570                         const struct dictionary *dict,
571                         struct roc_state *rs)
572 {
573   return process_group (var, reader, gt, dict, &rs->n1,
574                         &rs->cutpoint_rdr,
575                         ge,
576                         ROC_TP, ROC_FN);
577 }
578
579 /*
580   Return a casereader with width 3,
581   populated with cases based upon READER.
582   The cases will have the values:
583   (N, number of cases equal to N, number of cases less than N)
584   As a side effect, update RS->n2 with the number of negative cases.
585 */
586 static struct casereader *
587 process_negative_group (const struct variable *var, struct casereader *reader,
588                         const struct dictionary *dict,
589                         struct roc_state *rs)
590 {
591   return process_group (var, reader, lt, dict, &rs->n2,
592                         &rs->cutpoint_rdr,
593                         lt,
594                         ROC_TN, ROC_FP);
595 }
596
597
598
599
600 static void
601 append_cutpoint (struct casewriter *writer, double cutpoint)
602 {
603   struct ccase *cc = case_create (casewriter_get_proto (writer));
604
605   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
606   case_data_rw_idx (cc, ROC_TP)->f = 0;
607   case_data_rw_idx (cc, ROC_FN)->f = 0;
608   case_data_rw_idx (cc, ROC_TN)->f = 0;
609   case_data_rw_idx (cc, ROC_FP)->f = 0;
610
611   casewriter_write (writer, cc);
612 }
613
614
615 /* 
616    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
617    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
618    reader will be populated with its final number of cases.
619    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
620    value.  The other entries will be initialised to zero.
621 */
622 static void
623 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
624 {
625   int i;
626   struct casereader *r = casereader_clone (input);
627   struct ccase *c;
628   struct caseproto *proto = caseproto_create ();
629
630   struct subcase ordering;
631   subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
632
633   proto = caseproto_add_width (proto, 0); /* cutpoint */
634   proto = caseproto_add_width (proto, 0); /* ROC_TP */
635   proto = caseproto_add_width (proto, 0); /* ROC_FN */
636   proto = caseproto_add_width (proto, 0); /* ROC_TN */
637   proto = caseproto_add_width (proto, 0); /* ROC_FP */
638
639   for (i = 0 ; i < roc->n_vars; ++i)
640     {
641       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
642       rs[i].prev_result = SYSMIS;
643       rs[i].max = -DBL_MAX;
644       rs[i].min = DBL_MAX;
645     }
646
647   for (; (c = casereader_read (r)) != NULL; case_unref (c))
648     {
649       for (i = 0 ; i < roc->n_vars; ++i)
650         {
651           const union value *v = case_data (c, roc->vars[i]); 
652           const double result = v->f;
653
654           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
655             continue;
656
657           minimize (&rs[i].min, result);
658           maximize (&rs[i].max, result);
659
660           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
661             {
662               const double mean = (result + rs[i].prev_result ) / 2.0;
663               append_cutpoint (rs[i].cutpoint_wtr, mean);
664             }
665
666           rs[i].prev_result = result;
667         }
668     }
669   casereader_destroy (r);
670
671
672   /* Append the min and max cutpoints */
673   for (i = 0 ; i < roc->n_vars; ++i)
674     {
675       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
676       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
677
678       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
679     }
680 }
681
682 static void
683 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
684 {
685   int i;
686
687   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
688
689   struct casereader *negatives = NULL;
690   struct casereader *positives = NULL;
691
692   struct caseproto *n_proto = caseproto_create ();
693
694   struct subcase up_ordering;
695   struct subcase down_ordering;
696
697   struct casewriter *neg_wtr = NULL;
698
699   struct casereader *input = casereader_create_filter_missing (reader,
700                                                                roc->vars, roc->n_vars,
701                                                                roc->exclude,
702                                                                NULL,
703                                                                NULL);
704
705   input = casereader_create_filter_missing (input,
706                                             &roc->state_var, 1,
707                                             roc->exclude,
708                                             NULL,
709                                             NULL);
710
711   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
712
713   prepare_cutpoints (roc, rs, input);
714
715
716   /* Separate the positive actual state cases from the negative ones */
717   positives = 
718     casereader_create_filter_func (input,
719                                    match_positives,
720                                    NULL,
721                                    roc,
722                                    neg_wtr);
723
724   n_proto = caseproto_create ();
725       
726   n_proto = caseproto_add_width (n_proto, 0);
727   n_proto = caseproto_add_width (n_proto, 0);
728   n_proto = caseproto_add_width (n_proto, 0);
729   n_proto = caseproto_add_width (n_proto, 0);
730   n_proto = caseproto_add_width (n_proto, 0);
731
732   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
733   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
734
735   for (i = 0 ; i < roc->n_vars; ++i)
736     {
737       struct casewriter *w = NULL;
738       struct casereader *r = NULL;
739
740       struct ccase *c;
741
742       struct ccase *cpos;
743       struct casereader *n_neg ;
744       const struct variable *var = roc->vars[i];
745
746       struct casereader *neg ;
747       struct casereader *pos = casereader_clone (positives);
748
749
750       struct casereader *n_pos =
751         process_positive_group (var, pos, dict, &rs[i]);
752
753       if ( negatives == NULL)
754         {
755           negatives = casewriter_make_reader (neg_wtr);
756         }
757
758       neg = casereader_clone (negatives);
759
760       n_neg = process_negative_group (var, neg, dict, &rs[i]);
761
762
763       /* Merge the n_pos and n_neg casereaders */
764       w = sort_create_writer (&up_ordering, n_proto);
765       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
766         {
767           struct ccase *pos_case = case_create (n_proto);
768           struct ccase *cneg;
769           const double jpos = case_data_idx (cpos, VALUE)->f;
770
771           while ((cneg = casereader_read (n_neg)))
772             {
773               struct ccase *nc = case_create (n_proto);
774
775               const double jneg = case_data_idx (cneg, VALUE)->f;
776
777               case_data_rw_idx (nc, VALUE)->f = jneg;
778               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
779
780               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
781
782               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
783               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
784
785               casewriter_write (w, nc);
786
787               case_unref (cneg);
788               if ( jneg > jpos)
789                 break;
790             }
791
792           case_data_rw_idx (pos_case, VALUE)->f = jpos;
793           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
794           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
795           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
796           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
797
798           casewriter_write (w, pos_case);
799         }
800
801 /* These aren't used anymore */
802 #undef N_EQ
803 #undef N_PRED
804
805       r = casewriter_make_reader (w);
806
807       /* Propagate the N_POS_GT values from the positive cases
808          to the negative ones */
809       {
810         double prev_pos_gt = rs[i].n1;
811         w = sort_create_writer (&down_ordering, n_proto);
812
813         for ( ; (c = casereader_read (r) ); case_unref (c))
814           {
815             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
816             struct ccase *nc = case_clone (c);
817
818             if ( n_pos_gt == SYSMIS)
819               {
820                 n_pos_gt = prev_pos_gt;
821                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
822               }
823             
824             casewriter_write (w, nc);
825             prev_pos_gt = n_pos_gt;
826           }
827
828         r = casewriter_make_reader (w);
829       }
830
831       /* Propagate the N_NEG_LT values from the negative cases
832          to the positive ones */
833       {
834         double prev_neg_lt = rs[i].n2;
835         w = sort_create_writer (&up_ordering, n_proto);
836
837         for ( ; (c = casereader_read (r) ); case_unref (c))
838           {
839             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
840             struct ccase *nc = case_clone (c);
841
842             if ( n_neg_lt == SYSMIS)
843               {
844                 n_neg_lt = prev_neg_lt;
845                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
846               }
847             
848             casewriter_write (w, nc);
849             prev_neg_lt = n_neg_lt;
850           }
851
852         r = casewriter_make_reader (w);
853       }
854
855       {
856         struct ccase *prev_case = NULL;
857         for ( ; (c = casereader_read (r) ); case_unref (c))
858           {
859             const struct ccase *next_case = casereader_peek (r, 0);
860
861             const double j = case_data_idx (c, VALUE)->f;
862             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
863             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
864             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
865             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
866
867             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
868               {
869                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
870                   {
871                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
872                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
873                   }
874
875                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
876                   {
877                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
878                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
879                   }
880               }
881
882             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
883               {
884                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
885
886                 rs[i].q1hat +=
887                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
888                 rs[i].q2hat +=
889                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
890
891               }
892
893             case_unref (prev_case);
894             prev_case = case_clone (c);
895           }
896
897         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
898         if ( roc->invert ) 
899           rs[i].auc = 1 - rs[i].auc;
900
901         if ( roc->bi_neg_exp )
902           {
903             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
904             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
905           }
906         else
907           {
908             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
909             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
910           }
911       }
912     }
913
914   casereader_destroy (positives);
915   casereader_destroy (negatives);
916
917   output_roc (rs, roc);
918
919   free (rs);
920 }
921
922 static void
923 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
924 {
925   int i;
926   const int n_fields = roc->print_se ? 5 : 1;
927   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
928   const int n_rows = 2 + roc->n_vars;
929   struct tab_table *tbl = tab_create (n_cols, n_rows);
930
931   if ( roc->n_vars > 1)
932     tab_title (tbl, _("Area Under the Curve"));
933   else
934     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
935
936   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
937
938
939   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
940
941   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
942
943   tab_box (tbl,
944            TAL_2, TAL_2,
945            -1, TAL_1,
946            0, 0,
947            n_cols - 1,
948            n_rows - 1);
949
950   if ( roc->print_se )
951     {
952       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
953       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
954
955       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
956       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
957
958       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
959                              TAT_TITLE | TAB_CENTER,
960                              _("Asymp. %g%% Confidence Interval"), roc->ci);
961       tab_vline (tbl, 0, n_cols - 1, 0, 0);
962       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
963     }
964
965   if ( roc->n_vars > 1)
966     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
967
968   if ( roc->n_vars > 1)
969     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
970
971
972   for ( i = 0 ; i < roc->n_vars ; ++i )
973     {
974       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
975
976       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
977
978       if ( roc->print_se )
979         {
980           double se ;
981           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
982                                       (12 * rs[i].n1 * rs[i].n2));
983           double ci ;
984           double yy ;
985
986           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
987             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
988
989           se /= rs[i].n1 * rs[i].n2;
990
991           se = sqrt (se);
992
993           tab_double (tbl, n_cols - 4, 2 + i, 0,
994                       se,
995                       NULL);
996
997           ci = 1 - roc->ci / 100.0;
998           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
999
1000           tab_double (tbl, n_cols - 2, 2 + i, 0,
1001                       rs[i].auc - yy,
1002                       NULL);
1003
1004           tab_double (tbl, n_cols - 1, 2 + i, 0,
1005                       rs[i].auc + yy,
1006                       NULL);
1007
1008           tab_double (tbl, n_cols - 3, 2 + i, 0,
1009                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1010                       NULL);
1011         }
1012     }
1013
1014   tab_submit (tbl);
1015 }
1016
1017
1018 static void
1019 show_summary (const struct cmd_roc *roc)
1020 {
1021   const int n_cols = 3;
1022   const int n_rows = 4;
1023   struct tab_table *tbl = tab_create (n_cols, n_rows);
1024
1025   tab_title (tbl, _("Case Summary"));
1026
1027   tab_headers (tbl, 1, 0, 2, 0);
1028
1029   tab_box (tbl,
1030            TAL_2, TAL_2,
1031            -1, -1,
1032            0, 0,
1033            n_cols - 1,
1034            n_rows - 1);
1035
1036   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1037   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1038
1039
1040   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1041   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1042
1043
1044   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1045   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1046   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1047
1048   tab_joint_text (tbl, 1, 0, 2, 0,
1049                   TAT_TITLE | TAB_CENTER,
1050                   _("Valid N (listwise)"));
1051
1052
1053   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1054   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1055
1056
1057   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1058   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1059
1060   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1061   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1062
1063   tab_submit (tbl);
1064 }
1065
1066
1067 static void
1068 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1069 {
1070   int x = 1;
1071   int i;
1072   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1073   int n_rows = 1;
1074   struct tab_table *tbl ;
1075
1076   for (i = 0; i < roc->n_vars; ++i)
1077     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1078
1079   tbl = tab_create (n_cols, n_rows);
1080
1081   if ( roc->n_vars > 1)
1082     tab_title (tbl, _("Coordinates of the Curve"));
1083   else
1084     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1085
1086
1087   tab_headers (tbl, 1, 0, 1, 0);
1088
1089   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1090
1091   if ( roc->n_vars > 1)
1092     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1093
1094   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1095   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1096   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1097
1098   tab_box (tbl,
1099            TAL_2, TAL_2,
1100            -1, TAL_1,
1101            0, 0,
1102            n_cols - 1,
1103            n_rows - 1);
1104
1105   if ( roc->n_vars > 1)
1106     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1107
1108   for (i = 0; i < roc->n_vars; ++i)
1109     {
1110       struct ccase *cc;
1111       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1112
1113       if ( roc->n_vars > 1)
1114         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1115
1116       if ( i > 0)
1117         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1118
1119
1120       for (; (cc = casereader_read (r)) != NULL;
1121            case_unref (cc), x++)
1122         {
1123           const double se = case_data_idx (cc, ROC_TP)->f /
1124             (
1125              case_data_idx (cc, ROC_TP)->f
1126              +
1127              case_data_idx (cc, ROC_FN)->f
1128              );
1129
1130           const double sp = case_data_idx (cc, ROC_TN)->f /
1131             (
1132              case_data_idx (cc, ROC_TN)->f
1133              +
1134              case_data_idx (cc, ROC_FP)->f
1135              );
1136
1137           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1138                       var_get_print_format (roc->vars[i]));
1139
1140           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1141           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1142         }
1143
1144       casereader_destroy (r);
1145     }
1146
1147   tab_submit (tbl);
1148 }
1149
1150
1151 static void
1152 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1153 {
1154   show_summary (roc);
1155
1156   if ( roc->curve )
1157     {
1158       struct roc_chart *rc;
1159       size_t i;
1160
1161       rc = roc_chart_create (roc->reference);
1162       for (i = 0; i < roc->n_vars; i++)
1163         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1164                            rs[i].cutpoint_rdr);
1165       roc_chart_submit (rc);
1166     }
1167
1168   show_auc (rs, roc);
1169
1170   if ( roc->print_coords )
1171     show_coords (rs, roc);
1172 }
1173