Merge branch 'master' into output
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <language/stats/roc.h>
20
21 #include <data/procedure.h>
22 #include <language/lexer/variable-parser.h>
23 #include <language/lexer/value-parser.h>
24 #include <language/command.h>
25 #include <language/lexer/lexer.h>
26
27 #include <data/casegrouper.h>
28 #include <data/casereader.h>
29 #include <data/casewriter.h>
30 #include <data/dictionary.h>
31 #include <data/format.h>
32 #include <math/sort.h>
33 #include <data/subcase.h>
34
35
36 #include <libpspp/misc.h>
37
38 #include <gsl/gsl_cdf.h>
39 #include <output/table.h>
40
41 #include <output/chart.h>
42 #include <output/charts/roc-chart.h>
43
44 #include "gettext.h"
45 #define _(msgid) gettext (msgid)
46 #define N_(msgid) msgid
47
48 struct cmd_roc
49 {
50   size_t n_vars;
51   const struct variable **vars;
52   const struct dictionary *dict;
53
54   const struct variable *state_var ;
55   union value state_value;
56
57   /* Plot the roc curve */
58   bool curve;
59   /* Plot the reference line */
60   bool reference;
61
62   double ci;
63
64   bool print_coords;
65   bool print_se;
66   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
67                       should be used */
68   enum mv_class exclude;
69
70   bool invert ; /* True iff a smaller test result variable indicates
71                    a positive result */
72
73   double pos;
74   double neg;
75   double pos_weighted;
76   double neg_weighted;
77 };
78
79 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
80
81 int
82 cmd_roc (struct lexer *lexer, struct dataset *ds)
83 {
84   struct cmd_roc roc ;
85   const struct dictionary *dict = dataset_dict (ds);
86
87   roc.vars = NULL;
88   roc.n_vars = 0;
89   roc.print_se = false;
90   roc.print_coords = false;
91   roc.exclude = MV_ANY;
92   roc.curve = true;
93   roc.reference = false;
94   roc.ci = 95;
95   roc.bi_neg_exp = false;
96   roc.invert = false;
97   roc.pos = roc.pos_weighted = 0;
98   roc.neg = roc.neg_weighted = 0;
99   roc.dict = dataset_dict (ds);
100
101   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103     goto error;
104
105   if ( ! lex_force_match (lexer, T_BY))
106     {
107       goto error;
108     }
109
110   roc.state_var = parse_variable (lexer, dict);
111
112   if ( !lex_force_match (lexer, '('))
113     {
114       goto error;
115     }
116
117   value_init (&roc.state_value, var_get_width (roc.state_var));
118   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
119
120
121   if ( !lex_force_match (lexer, ')'))
122     {
123       goto error;
124     }
125
126
127   while (lex_token (lexer) != '.')
128     {
129       lex_match (lexer, '/');
130       if (lex_match_id (lexer, "MISSING"))
131         {
132           lex_match (lexer, '=');
133           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
134             {
135               if (lex_match_id (lexer, "INCLUDE"))
136                 {
137                   roc.exclude = MV_SYSTEM;
138                 }
139               else if (lex_match_id (lexer, "EXCLUDE"))
140                 {
141                   roc.exclude = MV_ANY;
142                 }
143               else
144                 {
145                   lex_error (lexer, NULL);
146                   goto error;
147                 }
148             }
149         }
150       else if (lex_match_id (lexer, "PLOT"))
151         {
152           lex_match (lexer, '=');
153           if (lex_match_id (lexer, "CURVE"))
154             {
155               roc.curve = true;
156               if (lex_match (lexer, '('))
157                 {
158                   roc.reference = true;
159                   lex_force_match_id (lexer, "REFERENCE");
160                   lex_force_match (lexer, ')');
161                 }
162             }
163           else if (lex_match_id (lexer, "NONE"))
164             {
165               roc.curve = false;
166             }
167           else
168             {
169               lex_error (lexer, NULL);
170               goto error;
171             }
172         }
173       else if (lex_match_id (lexer, "PRINT"))
174         {
175           lex_match (lexer, '=');
176           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
177             {
178               if (lex_match_id (lexer, "SE"))
179                 {
180                   roc.print_se = true;
181                 }
182               else if (lex_match_id (lexer, "COORDINATES"))
183                 {
184                   roc.print_coords = true;
185                 }
186               else
187                 {
188                   lex_error (lexer, NULL);
189                   goto error;
190                 }
191             }
192         }
193       else if (lex_match_id (lexer, "CRITERIA"))
194         {
195           lex_match (lexer, '=');
196           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
197             {
198               if (lex_match_id (lexer, "CUTOFF"))
199                 {
200                   lex_force_match (lexer, '(');
201                   if (lex_match_id (lexer, "INCLUDE"))
202                     {
203                       roc.exclude = MV_SYSTEM;
204                     }
205                   else if (lex_match_id (lexer, "EXCLUDE"))
206                     {
207                       roc.exclude = MV_USER | MV_SYSTEM;
208                     }
209                   else
210                     {
211                       lex_error (lexer, NULL);
212                       goto error;
213                     }
214                   lex_force_match (lexer, ')');
215                 }
216               else if (lex_match_id (lexer, "TESTPOS"))
217                 {
218                   lex_force_match (lexer, '(');
219                   if (lex_match_id (lexer, "LARGE"))
220                     {
221                       roc.invert = false;
222                     }
223                   else if (lex_match_id (lexer, "SMALL"))
224                     {
225                       roc.invert = true;
226                     }
227                   else
228                     {
229                       lex_error (lexer, NULL);
230                       goto error;
231                     }
232                   lex_force_match (lexer, ')');
233                 }
234               else if (lex_match_id (lexer, "CI"))
235                 {
236                   lex_force_match (lexer, '(');
237                   lex_force_num (lexer);
238                   roc.ci = lex_number (lexer);
239                   lex_get (lexer);
240                   lex_force_match (lexer, ')');
241                 }
242               else if (lex_match_id (lexer, "DISTRIBUTION"))
243                 {
244                   lex_force_match (lexer, '(');
245                   if (lex_match_id (lexer, "FREE"))
246                     {
247                       roc.bi_neg_exp = false;
248                     }
249                   else if (lex_match_id (lexer, "NEGEXPO"))
250                     {
251                       roc.bi_neg_exp = true;
252                     }
253                   else
254                     {
255                       lex_error (lexer, NULL);
256                       goto error;
257                     }
258                   lex_force_match (lexer, ')');
259                 }
260               else
261                 {
262                   lex_error (lexer, NULL);
263                   goto error;
264                 }
265             }
266         }
267       else
268         {
269           lex_error (lexer, NULL);
270           break;
271         }
272     }
273
274   if ( ! run_roc (ds, &roc)) 
275     goto error;
276
277   value_destroy (&roc.state_value, var_get_width (roc.state_var));
278   free (roc.vars);
279   return CMD_SUCCESS;
280
281  error:
282   value_destroy (&roc.state_value, var_get_width (roc.state_var));
283   free (roc.vars);
284   return CMD_FAILURE;
285 }
286
287
288
289
290 static void
291 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
292
293
294 static int
295 run_roc (struct dataset *ds, struct cmd_roc *roc)
296 {
297   struct dictionary *dict = dataset_dict (ds);
298   bool ok;
299   struct casereader *group;
300
301   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
302   while (casegrouper_get_next_group (grouper, &group))
303     {
304       do_roc (roc, group, dataset_dict (ds));
305     }
306   ok = casegrouper_destroy (grouper);
307   ok = proc_commit (ds) && ok;
308
309   return ok;
310 }
311
312 #if 0
313 static void
314 dump_casereader (struct casereader *reader)
315 {
316   struct ccase *c;
317   struct casereader *r = casereader_clone (reader);
318
319   for ( ; (c = casereader_read (r) ); case_unref (c))
320     {
321       int i;
322       for (i = 0 ; i < case_get_value_cnt (c); ++i)
323         {
324           printf ("%g ", case_data_idx (c, i)->f);
325         }
326       printf ("\n");
327     }
328
329   casereader_destroy (r);
330 }
331 #endif
332
333
334 /* 
335    Return true iff the state variable indicates that C has positive actual state.
336
337    As a side effect, this function also accumulates the roc->{pos,neg} and 
338    roc->{pos,neg}_weighted counts.
339  */
340 static bool
341 match_positives (const struct ccase *c, void *aux)
342 {
343   struct cmd_roc *roc = aux;
344   const struct variable *wv = dict_get_weight (roc->dict);
345   const double weight = wv ? case_data (c, wv)->f : 1.0;
346
347   const bool positive =
348   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
349     var_get_width (roc->state_var)));
350
351   if ( positive )
352     {
353       roc->pos++;
354       roc->pos_weighted += weight;
355     }
356   else
357     {
358       roc->neg++;
359       roc->neg_weighted += weight;
360     }
361
362   return positive;
363 }
364
365
366 #define VALUE  0
367 #define N_EQ   1
368 #define N_PRED 2
369
370 /* Some intermediate state for calculating the cutpoints and the 
371    standard error values */
372 struct roc_state
373 {
374   double auc;  /* Area under the curve */
375
376   double n1;  /* total weight of positives */
377   double n2;  /* total weight of negatives */
378
379   /* intermediates for standard error */
380   double q1hat; 
381   double q2hat;
382
383   /* intermediates for cutpoints */
384   struct casewriter *cutpoint_wtr;
385   struct casereader *cutpoint_rdr;
386   double prev_result;
387   double min;
388   double max;
389 };
390
391 /* 
392    Return a new casereader based upon CUTPOINT_RDR.
393    The number of "positive" cases are placed into
394    the position TRUE_INDEX, and the number of "negative" cases
395    into FALSE_INDEX.
396    POS_COND and RESULT determine the semantics of what is 
397    "positive".
398    WEIGHT is the value of a single count.
399  */
400 static struct casereader *
401 accumulate_counts (struct casereader *cutpoint_rdr, 
402                    double result, double weight, 
403                    bool (*pos_cond) (double, double),
404                    int true_index, int false_index)
405 {
406   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
407   struct casewriter *w =
408     autopaging_writer_create (proto);
409   struct casereader *r = casereader_clone (cutpoint_rdr);
410   struct ccase *cpc;
411   double prev_cp = SYSMIS;
412
413   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
414     {
415       struct ccase *new_case;
416       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
417
418       assert (cp != SYSMIS);
419
420       /* We don't want duplicates here */
421       if ( cp == prev_cp )
422         continue;
423
424       new_case = case_clone (cpc);
425
426       if ( pos_cond (result, cp))
427         case_data_rw_idx (new_case, true_index)->f += weight;
428       else
429         case_data_rw_idx (new_case, false_index)->f += weight;
430
431       prev_cp = cp;
432
433       casewriter_write (w, new_case);
434     }
435   casereader_destroy (r);
436
437   return casewriter_make_reader (w);
438 }
439
440
441
442 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
443
444 /*
445   This function does 3 things:
446
447   1. Counts the number of cases which are equal to every other case in READER,
448   and those cases for which the relationship between it and every other case
449   satifies PRED (normally either > or <).  VAR is variable defining a case's value
450   for this purpose.
451
452   2. Counts the number of true and false cases in reader, and populates
453   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
454   which receive these values.  POS_COND is the condition defining true
455   and false.
456   
457   3. CC is filled with the cumulative weight of all cases of READER.
458 */
459 static struct casereader *
460 process_group (const struct variable *var, struct casereader *reader,
461                bool (*pred) (double, double),
462                const struct dictionary *dict,
463                double *cc,
464                struct casereader **cutpoint_rdr, 
465                bool (*pos_cond) (double, double),
466                int true_index,
467                int false_index)
468 {
469   const struct variable *w = dict_get_weight (dict);
470
471   struct casereader *r1 =
472     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
473
474   const int weight_idx  = w ? var_get_case_index (w) :
475     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
476   
477   struct ccase *c1;
478
479   struct casereader *rclone = casereader_clone (r1);
480   struct casewriter *wtr;
481   struct caseproto *proto = caseproto_create ();
482
483   proto = caseproto_add_width (proto, 0);
484   proto = caseproto_add_width (proto, 0);
485   proto = caseproto_add_width (proto, 0);
486
487   wtr = autopaging_writer_create (proto);  
488
489   *cc = 0;
490
491   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
492     {
493       struct ccase *new_case = case_create (proto);
494       struct ccase *c2;
495       struct casereader *r2 = casereader_clone (rclone);
496
497       const double weight1 = case_data_idx (c1, weight_idx)->f;
498       const double d1 = case_data (c1, var)->f;
499       double n_eq = 0.0;
500       double n_pred = 0.0;
501
502       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
503                                          pos_cond,
504                                          true_index, false_index);
505
506       *cc += weight1;
507
508       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
509         {
510           const double d2 = case_data (c2, var)->f;
511           const double weight2 = case_data_idx (c2, weight_idx)->f;
512
513           if ( d1 == d2 )
514             {
515               n_eq += weight2;
516               continue;
517             }
518           else  if ( pred (d2, d1))
519             {
520               n_pred += weight2;
521             }
522         }
523
524       case_data_rw_idx (new_case, VALUE)->f = d1;
525       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
526       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
527
528       casewriter_write (wtr, new_case);
529
530       casereader_destroy (r2);
531     }
532
533   casereader_destroy (r1);
534   casereader_destroy (rclone);
535
536   return casewriter_make_reader (wtr);
537 }
538
539 /* Some more indeces into case data */
540 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
541 #define N_POS_GT 2  /* number of postive cases with values greater than n */
542 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
543 #define N_NEG_LT 4  /* number of negative cases with values less than n */
544
545 static bool
546 gt (double d1, double d2)
547 {
548   return d1 > d2;
549 }
550
551
552 static bool
553 ge (double d1, double d2)
554 {
555   return d1 > d2;
556 }
557
558 static bool
559 lt (double d1, double d2)
560 {
561   return d1 < d2;
562 }
563
564
565 /*
566   Return a casereader with width 3,
567   populated with cases based upon READER.
568   The cases will have the values:
569   (N, number of cases equal to N, number of cases greater than N)
570   As a side effect, update RS->n1 with the number of positive cases.
571 */
572 static struct casereader *
573 process_positive_group (const struct variable *var, struct casereader *reader,
574                         const struct dictionary *dict,
575                         struct roc_state *rs)
576 {
577   return process_group (var, reader, gt, dict, &rs->n1,
578                         &rs->cutpoint_rdr,
579                         ge,
580                         ROC_TP, ROC_FN);
581 }
582
583 /*
584   Return a casereader with width 3,
585   populated with cases based upon READER.
586   The cases will have the values:
587   (N, number of cases equal to N, number of cases less than N)
588   As a side effect, update RS->n2 with the number of negative cases.
589 */
590 static struct casereader *
591 process_negative_group (const struct variable *var, struct casereader *reader,
592                         const struct dictionary *dict,
593                         struct roc_state *rs)
594 {
595   return process_group (var, reader, lt, dict, &rs->n2,
596                         &rs->cutpoint_rdr,
597                         lt,
598                         ROC_TN, ROC_FP);
599 }
600
601
602
603
604 static void
605 append_cutpoint (struct casewriter *writer, double cutpoint)
606 {
607   struct ccase *cc = case_create (casewriter_get_proto (writer));
608
609   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
610   case_data_rw_idx (cc, ROC_TP)->f = 0;
611   case_data_rw_idx (cc, ROC_FN)->f = 0;
612   case_data_rw_idx (cc, ROC_TN)->f = 0;
613   case_data_rw_idx (cc, ROC_FP)->f = 0;
614
615   casewriter_write (writer, cc);
616 }
617
618
619 /* 
620    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
621    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
622    reader will be populated with its final number of cases.
623    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
624    value.  The other entries will be initialised to zero.
625 */
626 static void
627 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
628 {
629   int i;
630   struct casereader *r = casereader_clone (input);
631   struct ccase *c;
632   struct caseproto *proto = caseproto_create ();
633
634   struct subcase ordering;
635   subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
636
637   proto = caseproto_add_width (proto, 0); /* cutpoint */
638   proto = caseproto_add_width (proto, 0); /* ROC_TP */
639   proto = caseproto_add_width (proto, 0); /* ROC_FN */
640   proto = caseproto_add_width (proto, 0); /* ROC_TN */
641   proto = caseproto_add_width (proto, 0); /* ROC_FP */
642
643   for (i = 0 ; i < roc->n_vars; ++i)
644     {
645       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
646       rs[i].prev_result = SYSMIS;
647       rs[i].max = -DBL_MAX;
648       rs[i].min = DBL_MAX;
649     }
650
651   for (; (c = casereader_read (r)) != NULL; case_unref (c))
652     {
653       for (i = 0 ; i < roc->n_vars; ++i)
654         {
655           const union value *v = case_data (c, roc->vars[i]); 
656           const double result = v->f;
657
658           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
659             continue;
660
661           minimize (&rs[i].min, result);
662           maximize (&rs[i].max, result);
663
664           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
665             {
666               const double mean = (result + rs[i].prev_result ) / 2.0;
667               append_cutpoint (rs[i].cutpoint_wtr, mean);
668             }
669
670           rs[i].prev_result = result;
671         }
672     }
673   casereader_destroy (r);
674
675
676   /* Append the min and max cutpoints */
677   for (i = 0 ; i < roc->n_vars; ++i)
678     {
679       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
680       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
681
682       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
683     }
684 }
685
686 static void
687 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
688 {
689   int i;
690
691   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
692
693   struct casereader *negatives = NULL;
694   struct casereader *positives = NULL;
695
696   struct caseproto *n_proto = caseproto_create ();
697
698   struct subcase up_ordering;
699   struct subcase down_ordering;
700
701   struct casewriter *neg_wtr = NULL;
702
703   struct casereader *input = casereader_create_filter_missing (reader,
704                                                                roc->vars, roc->n_vars,
705                                                                roc->exclude,
706                                                                NULL,
707                                                                NULL);
708
709   input = casereader_create_filter_missing (input,
710                                             &roc->state_var, 1,
711                                             roc->exclude,
712                                             NULL,
713                                             NULL);
714
715   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
716
717   prepare_cutpoints (roc, rs, input);
718
719
720   /* Separate the positive actual state cases from the negative ones */
721   positives = 
722     casereader_create_filter_func (input,
723                                    match_positives,
724                                    NULL,
725                                    roc,
726                                    neg_wtr);
727
728   n_proto = caseproto_create ();
729       
730   n_proto = caseproto_add_width (n_proto, 0);
731   n_proto = caseproto_add_width (n_proto, 0);
732   n_proto = caseproto_add_width (n_proto, 0);
733   n_proto = caseproto_add_width (n_proto, 0);
734   n_proto = caseproto_add_width (n_proto, 0);
735
736   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
737   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
738
739   for (i = 0 ; i < roc->n_vars; ++i)
740     {
741       struct casewriter *w = NULL;
742       struct casereader *r = NULL;
743
744       struct ccase *c;
745
746       struct ccase *cpos;
747       struct casereader *n_neg ;
748       const struct variable *var = roc->vars[i];
749
750       struct casereader *neg ;
751       struct casereader *pos = casereader_clone (positives);
752
753
754       struct casereader *n_pos =
755         process_positive_group (var, pos, dict, &rs[i]);
756
757       if ( negatives == NULL)
758         {
759           negatives = casewriter_make_reader (neg_wtr);
760         }
761
762       neg = casereader_clone (negatives);
763
764       n_neg = process_negative_group (var, neg, dict, &rs[i]);
765
766
767       /* Merge the n_pos and n_neg casereaders */
768       w = sort_create_writer (&up_ordering, n_proto);
769       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
770         {
771           struct ccase *pos_case = case_create (n_proto);
772           struct ccase *cneg;
773           const double jpos = case_data_idx (cpos, VALUE)->f;
774
775           while ((cneg = casereader_read (n_neg)))
776             {
777               struct ccase *nc = case_create (n_proto);
778
779               const double jneg = case_data_idx (cneg, VALUE)->f;
780
781               case_data_rw_idx (nc, VALUE)->f = jneg;
782               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
783
784               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
785
786               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
787               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
788
789               casewriter_write (w, nc);
790
791               case_unref (cneg);
792               if ( jneg > jpos)
793                 break;
794             }
795
796           case_data_rw_idx (pos_case, VALUE)->f = jpos;
797           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
798           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
799           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
800           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
801
802           casewriter_write (w, pos_case);
803         }
804
805 /* These aren't used anymore */
806 #undef N_EQ
807 #undef N_PRED
808
809       r = casewriter_make_reader (w);
810
811       /* Propagate the N_POS_GT values from the positive cases
812          to the negative ones */
813       {
814         double prev_pos_gt = rs[i].n1;
815         w = sort_create_writer (&down_ordering, n_proto);
816
817         for ( ; (c = casereader_read (r) ); case_unref (c))
818           {
819             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
820             struct ccase *nc = case_clone (c);
821
822             if ( n_pos_gt == SYSMIS)
823               {
824                 n_pos_gt = prev_pos_gt;
825                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
826               }
827             
828             casewriter_write (w, nc);
829             prev_pos_gt = n_pos_gt;
830           }
831
832         r = casewriter_make_reader (w);
833       }
834
835       /* Propagate the N_NEG_LT values from the negative cases
836          to the positive ones */
837       {
838         double prev_neg_lt = rs[i].n2;
839         w = sort_create_writer (&up_ordering, n_proto);
840
841         for ( ; (c = casereader_read (r) ); case_unref (c))
842           {
843             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
844             struct ccase *nc = case_clone (c);
845
846             if ( n_neg_lt == SYSMIS)
847               {
848                 n_neg_lt = prev_neg_lt;
849                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
850               }
851             
852             casewriter_write (w, nc);
853             prev_neg_lt = n_neg_lt;
854           }
855
856         r = casewriter_make_reader (w);
857       }
858
859       {
860         struct ccase *prev_case = NULL;
861         for ( ; (c = casereader_read (r) ); case_unref (c))
862           {
863             const struct ccase *next_case = casereader_peek (r, 0);
864
865             const double j = case_data_idx (c, VALUE)->f;
866             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
867             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
868             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
869             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
870
871             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
872               {
873                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
874                   {
875                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
876                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
877                   }
878
879                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
880                   {
881                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
882                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
883                   }
884               }
885
886             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
887               {
888                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
889
890                 rs[i].q1hat +=
891                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
892                 rs[i].q2hat +=
893                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
894
895               }
896
897             case_unref (prev_case);
898             prev_case = case_clone (c);
899           }
900
901         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
902         if ( roc->invert ) 
903           rs[i].auc = 1 - rs[i].auc;
904
905         if ( roc->bi_neg_exp )
906           {
907             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
908             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
909           }
910         else
911           {
912             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
913             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
914           }
915       }
916     }
917
918   casereader_destroy (positives);
919   casereader_destroy (negatives);
920
921   output_roc (rs, roc);
922
923   free (rs);
924 }
925
926 static void
927 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
928 {
929   int i;
930   const int n_fields = roc->print_se ? 5 : 1;
931   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
932   const int n_rows = 2 + roc->n_vars;
933   struct tab_table *tbl = tab_create (n_cols, n_rows);
934
935   if ( roc->n_vars > 1)
936     tab_title (tbl, _("Area Under the Curve"));
937   else
938     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
939
940   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
941
942   tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
943
944   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
945
946   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
947
948   tab_box (tbl,
949            TAL_2, TAL_2,
950            -1, TAL_1,
951            0, 0,
952            n_cols - 1,
953            n_rows - 1);
954
955   if ( roc->print_se )
956     {
957       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
958       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
959
960       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
961       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
962
963       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
964                              TAT_TITLE | TAB_CENTER,
965                              _("Asymp. %g%% Confidence Interval"), roc->ci);
966       tab_vline (tbl, 0, n_cols - 1, 0, 0);
967       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
968     }
969
970   if ( roc->n_vars > 1)
971     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
972
973   if ( roc->n_vars > 1)
974     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
975
976
977   for ( i = 0 ; i < roc->n_vars ; ++i )
978     {
979       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
980
981       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
982
983       if ( roc->print_se )
984         {
985           double se ;
986           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
987                                       (12 * rs[i].n1 * rs[i].n2));
988           double ci ;
989           double yy ;
990
991           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
992             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
993
994           se /= rs[i].n1 * rs[i].n2;
995
996           se = sqrt (se);
997
998           tab_double (tbl, n_cols - 4, 2 + i, 0,
999                       se,
1000                       NULL);
1001
1002           ci = 1 - roc->ci / 100.0;
1003           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1004
1005           tab_double (tbl, n_cols - 2, 2 + i, 0,
1006                       rs[i].auc - yy,
1007                       NULL);
1008
1009           tab_double (tbl, n_cols - 1, 2 + i, 0,
1010                       rs[i].auc + yy,
1011                       NULL);
1012
1013           tab_double (tbl, n_cols - 3, 2 + i, 0,
1014                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1015                       NULL);
1016         }
1017     }
1018
1019   tab_submit (tbl);
1020 }
1021
1022
1023 static void
1024 show_summary (const struct cmd_roc *roc)
1025 {
1026   const int n_cols = 3;
1027   const int n_rows = 4;
1028   struct tab_table *tbl = tab_create (n_cols, n_rows);
1029
1030   tab_title (tbl, _("Case Summary"));
1031
1032   tab_headers (tbl, 1, 0, 2, 0);
1033
1034   tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1035
1036   tab_box (tbl,
1037            TAL_2, TAL_2,
1038            -1, -1,
1039            0, 0,
1040            n_cols - 1,
1041            n_rows - 1);
1042
1043   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1044   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1045
1046
1047   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1048   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1049
1050
1051   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1052   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1053   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1054
1055   tab_joint_text (tbl, 1, 0, 2, 0,
1056                   TAT_TITLE | TAB_CENTER,
1057                   _("Valid N (listwise)"));
1058
1059
1060   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1061   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1062
1063
1064   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1065   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1066
1067   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1068   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1069
1070   tab_submit (tbl);
1071 }
1072
1073
1074 static void
1075 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1076 {
1077   int x = 1;
1078   int i;
1079   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1080   int n_rows = 1;
1081   struct tab_table *tbl ;
1082
1083   for (i = 0; i < roc->n_vars; ++i)
1084     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1085
1086   tbl = tab_create (n_cols, n_rows);
1087
1088   if ( roc->n_vars > 1)
1089     tab_title (tbl, _("Coordinates of the Curve"));
1090   else
1091     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1092
1093
1094   tab_headers (tbl, 1, 0, 1, 0);
1095
1096   tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1097
1098   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1099
1100   if ( roc->n_vars > 1)
1101     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1102
1103   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1104   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1105   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1106
1107   tab_box (tbl,
1108            TAL_2, TAL_2,
1109            -1, TAL_1,
1110            0, 0,
1111            n_cols - 1,
1112            n_rows - 1);
1113
1114   if ( roc->n_vars > 1)
1115     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1116
1117   for (i = 0; i < roc->n_vars; ++i)
1118     {
1119       struct ccase *cc;
1120       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1121
1122       if ( roc->n_vars > 1)
1123         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1124
1125       if ( i > 0)
1126         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1127
1128
1129       for (; (cc = casereader_read (r)) != NULL;
1130            case_unref (cc), x++)
1131         {
1132           const double se = case_data_idx (cc, ROC_TP)->f /
1133             (
1134              case_data_idx (cc, ROC_TP)->f
1135              +
1136              case_data_idx (cc, ROC_FN)->f
1137              );
1138
1139           const double sp = case_data_idx (cc, ROC_TN)->f /
1140             (
1141              case_data_idx (cc, ROC_TN)->f
1142              +
1143              case_data_idx (cc, ROC_FP)->f
1144              );
1145
1146           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1147                       var_get_print_format (roc->vars[i]));
1148
1149           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1150           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1151         }
1152
1153       casereader_destroy (r);
1154     }
1155
1156   tab_submit (tbl);
1157 }
1158
1159
1160 static void
1161 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1162 {
1163   show_summary (roc);
1164
1165   if ( roc->curve )
1166     {
1167       struct roc_chart *rc;
1168       size_t i;
1169
1170       rc = roc_chart_create (roc->reference);
1171       for (i = 0; i < roc->n_vars; i++)
1172         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1173                            rs[i].cutpoint_rdr);
1174       chart_submit (roc_chart_get_chart (rc));
1175     }
1176
1177   show_auc (rs, roc);
1178
1179   if ( roc->print_coords )
1180     show_coords (rs, roc);
1181 }
1182