Fix bug in ROC parsing long string variables
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/procedure.h>
20 #include <language/lexer/variable-parser.h>
21 #include <language/lexer/value-parser.h>
22 #include <language/command.h>
23 #include <language/lexer/lexer.h>
24
25 #include <data/casegrouper.h>
26 #include <data/casereader.h>
27 #include <data/casewriter.h>
28 #include <data/dictionary.h>
29 #include <data/format.h>
30 #include <math/sort.h>
31 #include <data/subcase.h>
32
33
34 #include <libpspp/misc.h>
35
36 #include <gsl/gsl_cdf.h>
37 #include <output/table.h>
38
39 #include <output/charts/plot-chart.h>
40 #include <output/charts/cartesian.h>
41
42 #include "gettext.h"
43 #define _(msgid) gettext (msgid)
44 #define N_(msgid) msgid
45
46 struct cmd_roc
47 {
48   size_t n_vars;
49   const struct variable **vars;
50   const struct dictionary *dict;
51
52   const struct variable *state_var ;
53   union value state_value;
54
55   /* Plot the roc curve */
56   bool curve;
57   /* Plot the reference line */
58   bool reference;
59
60   double ci;
61
62   bool print_coords;
63   bool print_se;
64   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
65                       should be used */
66   enum mv_class exclude;
67
68   bool invert ; /* True iff a smaller test result variable indicates
69                    a positive result */
70
71   double pos;
72   double neg;
73   double pos_weighted;
74   double neg_weighted;
75 };
76
77 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
78
79 int
80 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 {
82   struct cmd_roc roc ;
83   const struct dictionary *dict = dataset_dict (ds);
84
85   roc.vars = NULL;
86   roc.n_vars = 0;
87   roc.print_se = false;
88   roc.print_coords = false;
89   roc.exclude = MV_ANY;
90   roc.curve = true;
91   roc.reference = false;
92   roc.ci = 95;
93   roc.bi_neg_exp = false;
94   roc.invert = false;
95   roc.pos = roc.pos_weighted = 0;
96   roc.neg = roc.neg_weighted = 0;
97   roc.dict = dataset_dict (ds);
98
99   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
101     goto error;
102
103   if ( ! lex_force_match (lexer, T_BY))
104     {
105       goto error;
106     }
107
108   roc.state_var = parse_variable (lexer, dict);
109
110   if ( !lex_force_match (lexer, '('))
111     {
112       goto error;
113     }
114
115   value_init (&roc.state_value, var_get_width (roc.state_var));
116   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
117
118
119   if ( !lex_force_match (lexer, ')'))
120     {
121       goto error;
122     }
123
124
125   while (lex_token (lexer) != '.')
126     {
127       lex_match (lexer, '/');
128       if (lex_match_id (lexer, "MISSING"))
129         {
130           lex_match (lexer, '=');
131           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
132             {
133               if (lex_match_id (lexer, "INCLUDE"))
134                 {
135                   roc.exclude = MV_SYSTEM;
136                 }
137               else if (lex_match_id (lexer, "EXCLUDE"))
138                 {
139                   roc.exclude = MV_ANY;
140                 }
141               else
142                 {
143                   lex_error (lexer, NULL);
144                   goto error;
145                 }
146             }
147         }
148       else if (lex_match_id (lexer, "PLOT"))
149         {
150           lex_match (lexer, '=');
151           if (lex_match_id (lexer, "CURVE"))
152             {
153               roc.curve = true;
154               if (lex_match (lexer, '('))
155                 {
156                   roc.reference = true;
157                   lex_force_match_id (lexer, "REFERENCE");
158                   lex_force_match (lexer, ')');
159                 }
160             }
161           else if (lex_match_id (lexer, "NONE"))
162             {
163               roc.curve = false;
164             }
165           else
166             {
167               lex_error (lexer, NULL);
168               goto error;
169             }
170         }
171       else if (lex_match_id (lexer, "PRINT"))
172         {
173           lex_match (lexer, '=');
174           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
175             {
176               if (lex_match_id (lexer, "SE"))
177                 {
178                   roc.print_se = true;
179                 }
180               else if (lex_match_id (lexer, "COORDINATES"))
181                 {
182                   roc.print_coords = true;
183                 }
184               else
185                 {
186                   lex_error (lexer, NULL);
187                   goto error;
188                 }
189             }
190         }
191       else if (lex_match_id (lexer, "CRITERIA"))
192         {
193           lex_match (lexer, '=');
194           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
195             {
196               if (lex_match_id (lexer, "CUTOFF"))
197                 {
198                   lex_force_match (lexer, '(');
199                   if (lex_match_id (lexer, "INCLUDE"))
200                     {
201                       roc.exclude = MV_SYSTEM;
202                     }
203                   else if (lex_match_id (lexer, "EXCLUDE"))
204                     {
205                       roc.exclude = MV_USER | MV_SYSTEM;
206                     }
207                   else
208                     {
209                       lex_error (lexer, NULL);
210                       goto error;
211                     }
212                   lex_force_match (lexer, ')');
213                 }
214               else if (lex_match_id (lexer, "TESTPOS"))
215                 {
216                   lex_force_match (lexer, '(');
217                   if (lex_match_id (lexer, "LARGE"))
218                     {
219                       roc.invert = false;
220                     }
221                   else if (lex_match_id (lexer, "SMALL"))
222                     {
223                       roc.invert = true;
224                     }
225                   else
226                     {
227                       lex_error (lexer, NULL);
228                       goto error;
229                     }
230                   lex_force_match (lexer, ')');
231                 }
232               else if (lex_match_id (lexer, "CI"))
233                 {
234                   lex_force_match (lexer, '(');
235                   lex_force_num (lexer);
236                   roc.ci = lex_number (lexer);
237                   lex_get (lexer);
238                   lex_force_match (lexer, ')');
239                 }
240               else if (lex_match_id (lexer, "DISTRIBUTION"))
241                 {
242                   lex_force_match (lexer, '(');
243                   if (lex_match_id (lexer, "FREE"))
244                     {
245                       roc.bi_neg_exp = false;
246                     }
247                   else if (lex_match_id (lexer, "NEGEXPO"))
248                     {
249                       roc.bi_neg_exp = true;
250                     }
251                   else
252                     {
253                       lex_error (lexer, NULL);
254                       goto error;
255                     }
256                   lex_force_match (lexer, ')');
257                 }
258               else
259                 {
260                   lex_error (lexer, NULL);
261                   goto error;
262                 }
263             }
264         }
265       else
266         {
267           lex_error (lexer, NULL);
268           break;
269         }
270     }
271
272   if ( ! run_roc (ds, &roc)) 
273     goto error;
274
275   value_destroy (&roc.state_value, var_get_width (roc.state_var));
276   free (roc.vars);
277   return CMD_SUCCESS;
278
279  error:
280   value_destroy (&roc.state_value, var_get_width (roc.state_var));
281   free (roc.vars);
282   return CMD_FAILURE;
283 }
284
285
286
287
288 static void
289 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
290
291
292 static int
293 run_roc (struct dataset *ds, struct cmd_roc *roc)
294 {
295   struct dictionary *dict = dataset_dict (ds);
296   bool ok;
297   struct casereader *group;
298
299   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
300   while (casegrouper_get_next_group (grouper, &group))
301     {
302       do_roc (roc, group, dataset_dict (ds));
303     }
304   ok = casegrouper_destroy (grouper);
305   ok = proc_commit (ds) && ok;
306
307   return ok;
308 }
309
310 #if 0
311 static void
312 dump_casereader (struct casereader *reader)
313 {
314   struct ccase *c;
315   struct casereader *r = casereader_clone (reader);
316
317   for ( ; (c = casereader_read (r) ); case_unref (c))
318     {
319       int i;
320       for (i = 0 ; i < case_get_value_cnt (c); ++i)
321         {
322           printf ("%g ", case_data_idx (c, i)->f);
323         }
324       printf ("\n");
325     }
326
327   casereader_destroy (r);
328 }
329 #endif
330
331
332 /* 
333    Return true iff the state variable indicates that C has positive actual state.
334
335    As a side effect, this function also accumulates the roc->{pos,neg} and 
336    roc->{pos,neg}_weighted counts.
337  */
338 static bool
339 match_positives (const struct ccase *c, void *aux)
340 {
341   struct cmd_roc *roc = aux;
342   const struct variable *wv = dict_get_weight (roc->dict);
343   const double weight = wv ? case_data (c, wv)->f : 1.0;
344
345   const bool positive =
346   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
347     var_get_width (roc->state_var)));
348
349   if ( positive )
350     {
351       roc->pos++;
352       roc->pos_weighted += weight;
353     }
354   else
355     {
356       roc->neg++;
357       roc->neg_weighted += weight;
358     }
359
360   return positive;
361 }
362
363
364 #define VALUE  0
365 #define N_EQ   1
366 #define N_PRED 2
367
368 /* Some intermediate state for calculating the cutpoints and the 
369    standard error values */
370 struct roc_state
371 {
372   double auc;  /* Area under the curve */
373
374   double n1;  /* total weight of positives */
375   double n2;  /* total weight of negatives */
376
377   /* intermediates for standard error */
378   double q1hat; 
379   double q2hat;
380
381   /* intermediates for cutpoints */
382   struct casewriter *cutpoint_wtr;
383   struct casereader *cutpoint_rdr;
384   double prev_result;
385   double min;
386   double max;
387 };
388
389 #define CUTPOINT 0
390 #define TP 1
391 #define FN 2
392 #define TN 3
393 #define FP 4
394
395
396 /* 
397    Return a new casereader based upon CUTPOINT_RDR.
398    The number of "positive" cases are placed into
399    the position TRUE_INDEX, and the number of "negative" cases
400    into FALSE_INDEX.
401    POS_COND and RESULT determine the semantics of what is 
402    "positive".
403    WEIGHT is the value of a single count.
404  */
405 static struct casereader *
406 accumulate_counts (struct casereader *cutpoint_rdr, 
407                    double result, double weight, 
408                    bool (*pos_cond) (double, double),
409                    int true_index, int false_index)
410 {
411   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
412   struct casewriter *w =
413     autopaging_writer_create (proto);
414   struct casereader *r = casereader_clone (cutpoint_rdr);
415   struct ccase *cpc;
416   double prev_cp = SYSMIS;
417
418   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
419     {
420       struct ccase *new_case;
421       const double cp = case_data_idx (cpc, CUTPOINT)->f;
422
423       assert (cp != SYSMIS);
424
425       /* We don't want duplicates here */
426       if ( cp == prev_cp )
427         continue;
428
429       new_case = case_clone (cpc);
430
431       if ( pos_cond (result, cp))
432         case_data_rw_idx (new_case, true_index)->f += weight;
433       else
434         case_data_rw_idx (new_case, false_index)->f += weight;
435
436       prev_cp = cp;
437
438       casewriter_write (w, new_case);
439     }
440   casereader_destroy (r);
441
442   return casewriter_make_reader (w);
443 }
444
445
446
447 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
448
449 /*
450   This function does 3 things:
451
452   1. Counts the number of cases which are equal to every other case in READER,
453   and those cases for which the relationship between it and every other case
454   satifies PRED (normally either > or <).  VAR is variable defining a case's value
455   for this purpose.
456
457   2. Counts the number of true and false cases in reader, and populates
458   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
459   which receive these values.  POS_COND is the condition defining true
460   and false.
461   
462   3. CC is filled with the cumulative weight of all cases of READER.
463 */
464 static struct casereader *
465 process_group (const struct variable *var, struct casereader *reader,
466                bool (*pred) (double, double),
467                const struct dictionary *dict,
468                double *cc,
469                struct casereader **cutpoint_rdr, 
470                bool (*pos_cond) (double, double),
471                int true_index,
472                int false_index)
473 {
474   const struct variable *w = dict_get_weight (dict);
475
476   struct casereader *r1 =
477     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
478
479   const int weight_idx  = w ? var_get_case_index (w) :
480     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
481   
482   struct ccase *c1;
483
484   struct casereader *rclone = casereader_clone (r1);
485   struct casewriter *wtr;
486   struct caseproto *proto = caseproto_create ();
487
488   proto = caseproto_add_width (proto, 0);
489   proto = caseproto_add_width (proto, 0);
490   proto = caseproto_add_width (proto, 0);
491
492   wtr = autopaging_writer_create (proto);  
493
494   *cc = 0;
495
496   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
497     {
498       struct ccase *new_case = case_create (proto);
499       struct ccase *c2;
500       struct casereader *r2 = casereader_clone (rclone);
501
502       const double weight1 = case_data_idx (c1, weight_idx)->f;
503       const double d1 = case_data (c1, var)->f;
504       double n_eq = 0.0;
505       double n_pred = 0.0;
506
507       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
508                                          pos_cond,
509                                          true_index, false_index);
510
511       *cc += weight1;
512
513       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
514         {
515           const double d2 = case_data (c2, var)->f;
516           const double weight2 = case_data_idx (c2, weight_idx)->f;
517
518           if ( d1 == d2 )
519             {
520               n_eq += weight2;
521               continue;
522             }
523           else  if ( pred (d2, d1))
524             {
525               n_pred += weight2;
526             }
527         }
528
529       case_data_rw_idx (new_case, VALUE)->f = d1;
530       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
531       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
532
533       casewriter_write (wtr, new_case);
534
535       casereader_destroy (r2);
536     }
537
538   casereader_destroy (r1);
539   casereader_destroy (rclone);
540
541   return casewriter_make_reader (wtr);
542 }
543
544 /* Some more indeces into case data */
545 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
546 #define N_POS_GT 2  /* number of postive cases with values greater than n */
547 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
548 #define N_NEG_LT 4  /* number of negative cases with values less than n */
549
550 static bool
551 gt (double d1, double d2)
552 {
553   return d1 > d2;
554 }
555
556
557 static bool
558 ge (double d1, double d2)
559 {
560   return d1 > d2;
561 }
562
563 static bool
564 lt (double d1, double d2)
565 {
566   return d1 < d2;
567 }
568
569
570 /*
571   Return a casereader with width 3,
572   populated with cases based upon READER.
573   The cases will have the values:
574   (N, number of cases equal to N, number of cases greater than N)
575   As a side effect, update RS->n1 with the number of positive cases.
576 */
577 static struct casereader *
578 process_positive_group (const struct variable *var, struct casereader *reader,
579                         const struct dictionary *dict,
580                         struct roc_state *rs)
581 {
582   return process_group (var, reader, gt, dict, &rs->n1,
583                         &rs->cutpoint_rdr,
584                         ge,
585                         TP, FN);
586 }
587
588 /*
589   Return a casereader with width 3,
590   populated with cases based upon READER.
591   The cases will have the values:
592   (N, number of cases equal to N, number of cases less than N)
593   As a side effect, update RS->n2 with the number of negative cases.
594 */
595 static struct casereader *
596 process_negative_group (const struct variable *var, struct casereader *reader,
597                         const struct dictionary *dict,
598                         struct roc_state *rs)
599 {
600   return process_group (var, reader, lt, dict, &rs->n2,
601                         &rs->cutpoint_rdr,
602                         lt,
603                         TN, FP);
604 }
605
606
607
608
609 static void
610 append_cutpoint (struct casewriter *writer, double cutpoint)
611 {
612   struct ccase *cc = case_create (casewriter_get_proto (writer));
613
614   case_data_rw_idx (cc, CUTPOINT)->f = cutpoint;
615   case_data_rw_idx (cc, TP)->f = 0;
616   case_data_rw_idx (cc, FN)->f = 0;
617   case_data_rw_idx (cc, TN)->f = 0;
618   case_data_rw_idx (cc, FP)->f = 0;
619
620   casewriter_write (writer, cc);
621 }
622
623
624 /* 
625    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
626    be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the
627    reader will be populated with its final number of cases.
628    However on exit from this function, only CUTPOINT entries will be set to their final
629    value.  The other entries will be initialised to zero.
630 */
631 static void
632 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
633 {
634   int i;
635   struct casereader *r = casereader_clone (input);
636   struct ccase *c;
637   struct caseproto *proto = caseproto_create ();
638
639   struct subcase ordering;
640   subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND);
641
642   proto = caseproto_add_width (proto, 0); /* cutpoint */
643   proto = caseproto_add_width (proto, 0); /* TP */
644   proto = caseproto_add_width (proto, 0); /* FN */
645   proto = caseproto_add_width (proto, 0); /* TN */
646   proto = caseproto_add_width (proto, 0); /* FP */
647
648   for (i = 0 ; i < roc->n_vars; ++i)
649     {
650       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
651       rs[i].prev_result = SYSMIS;
652       rs[i].max = -DBL_MAX;
653       rs[i].min = DBL_MAX;
654     }
655
656   for (; (c = casereader_read (r)) != NULL; case_unref (c))
657     {
658       for (i = 0 ; i < roc->n_vars; ++i)
659         {
660           const union value *v = case_data (c, roc->vars[i]); 
661           const double result = v->f;
662
663           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
664             continue;
665
666           minimize (&rs[i].min, result);
667           maximize (&rs[i].max, result);
668
669           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
670             {
671               const double mean = (result + rs[i].prev_result ) / 2.0;
672               append_cutpoint (rs[i].cutpoint_wtr, mean);
673             }
674
675           rs[i].prev_result = result;
676         }
677     }
678   casereader_destroy (r);
679
680
681   /* Append the min and max cutpoints */
682   for (i = 0 ; i < roc->n_vars; ++i)
683     {
684       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
685       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
686
687       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
688     }
689 }
690
691 static void
692 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
693 {
694   int i;
695
696   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
697
698   struct casereader *negatives = NULL;
699   struct casereader *positives = NULL;
700
701   struct caseproto *n_proto = caseproto_create ();
702
703   struct subcase up_ordering;
704   struct subcase down_ordering;
705
706   struct casewriter *neg_wtr = NULL;
707
708   struct casereader *input = casereader_create_filter_missing (reader,
709                                                                roc->vars, roc->n_vars,
710                                                                roc->exclude,
711                                                                NULL,
712                                                                NULL);
713
714   input = casereader_create_filter_missing (input,
715                                             &roc->state_var, 1,
716                                             roc->exclude,
717                                             NULL,
718                                             NULL);
719
720   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
721
722   prepare_cutpoints (roc, rs, input);
723
724
725   /* Separate the positive actual state cases from the negative ones */
726   positives = 
727     casereader_create_filter_func (input,
728                                    match_positives,
729                                    NULL,
730                                    roc,
731                                    neg_wtr);
732
733   n_proto = caseproto_create ();
734       
735   n_proto = caseproto_add_width (n_proto, 0);
736   n_proto = caseproto_add_width (n_proto, 0);
737   n_proto = caseproto_add_width (n_proto, 0);
738   n_proto = caseproto_add_width (n_proto, 0);
739   n_proto = caseproto_add_width (n_proto, 0);
740
741   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
742   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
743
744   for (i = 0 ; i < roc->n_vars; ++i)
745     {
746       struct casewriter *w = NULL;
747       struct casereader *r = NULL;
748
749       struct ccase *c;
750
751       struct ccase *cpos;
752       struct casereader *n_neg ;
753       const struct variable *var = roc->vars[i];
754
755       struct casereader *neg ;
756       struct casereader *pos = casereader_clone (positives);
757
758
759       struct casereader *n_pos =
760         process_positive_group (var, pos, dict, &rs[i]);
761
762       if ( negatives == NULL)
763         {
764           negatives = casewriter_make_reader (neg_wtr);
765         }
766
767       neg = casereader_clone (negatives);
768
769       n_neg = process_negative_group (var, neg, dict, &rs[i]);
770
771
772       /* Merge the n_pos and n_neg casereaders */
773       w = sort_create_writer (&up_ordering, n_proto);
774       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
775         {
776           struct ccase *pos_case = case_create (n_proto);
777           struct ccase *cneg;
778           const double jpos = case_data_idx (cpos, VALUE)->f;
779
780           while ((cneg = casereader_read (n_neg)))
781             {
782               struct ccase *nc = case_create (n_proto);
783
784               const double jneg = case_data_idx (cneg, VALUE)->f;
785
786               case_data_rw_idx (nc, VALUE)->f = jneg;
787               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
788
789               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
790
791               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
792               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
793
794               casewriter_write (w, nc);
795
796               case_unref (cneg);
797               if ( jneg > jpos)
798                 break;
799             }
800
801           case_data_rw_idx (pos_case, VALUE)->f = jpos;
802           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
803           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
804           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
805           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
806
807           casewriter_write (w, pos_case);
808         }
809
810 /* These aren't used anymore */
811 #undef N_EQ
812 #undef N_PRED
813
814       r = casewriter_make_reader (w);
815
816       /* Propagate the N_POS_GT values from the positive cases
817          to the negative ones */
818       {
819         double prev_pos_gt = rs[i].n1;
820         w = sort_create_writer (&down_ordering, n_proto);
821
822         for ( ; (c = casereader_read (r) ); case_unref (c))
823           {
824             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
825             struct ccase *nc = case_clone (c);
826
827             if ( n_pos_gt == SYSMIS)
828               {
829                 n_pos_gt = prev_pos_gt;
830                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
831               }
832             
833             casewriter_write (w, nc);
834             prev_pos_gt = n_pos_gt;
835           }
836
837         r = casewriter_make_reader (w);
838       }
839
840       /* Propagate the N_NEG_LT values from the negative cases
841          to the positive ones */
842       {
843         double prev_neg_lt = rs[i].n2;
844         w = sort_create_writer (&up_ordering, n_proto);
845
846         for ( ; (c = casereader_read (r) ); case_unref (c))
847           {
848             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
849             struct ccase *nc = case_clone (c);
850
851             if ( n_neg_lt == SYSMIS)
852               {
853                 n_neg_lt = prev_neg_lt;
854                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
855               }
856             
857             casewriter_write (w, nc);
858             prev_neg_lt = n_neg_lt;
859           }
860
861         r = casewriter_make_reader (w);
862       }
863
864       {
865         struct ccase *prev_case = NULL;
866         for ( ; (c = casereader_read (r) ); case_unref (c))
867           {
868             const struct ccase *next_case = casereader_peek (r, 0);
869
870             const double j = case_data_idx (c, VALUE)->f;
871             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
872             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
873             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
874             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
875
876             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
877               {
878                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
879                   {
880                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
881                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
882                   }
883
884                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
885                   {
886                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
887                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
888                   }
889               }
890
891             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
892               {
893                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
894
895                 rs[i].q1hat +=
896                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
897                 rs[i].q2hat +=
898                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
899
900               }
901
902             case_unref (prev_case);
903             prev_case = case_clone (c);
904           }
905
906         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
907         if ( roc->invert ) 
908           rs[i].auc = 1 - rs[i].auc;
909
910         if ( roc->bi_neg_exp )
911           {
912             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
913             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
914           }
915         else
916           {
917             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
918             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
919           }
920       }
921     }
922
923   casereader_destroy (positives);
924   casereader_destroy (negatives);
925
926   output_roc (rs, roc);
927
928   free (rs);
929 }
930
931 static void
932 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
933 {
934   int i;
935   const int n_fields = roc->print_se ? 5 : 1;
936   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
937   const int n_rows = 2 + roc->n_vars;
938   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
939
940   if ( roc->n_vars > 1)
941     tab_title (tbl, _("Area Under the Curve"));
942   else
943     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
944
945   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
946
947   tab_dim (tbl, tab_natural_dimensions, NULL);
948
949   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
950
951   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
952
953   tab_box (tbl,
954            TAL_2, TAL_2,
955            -1, TAL_1,
956            0, 0,
957            n_cols - 1,
958            n_rows - 1);
959
960   if ( roc->print_se )
961     {
962       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
963       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
964
965       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
966       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
967
968       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
969                              TAT_TITLE | TAB_CENTER,
970                              _("Asymp. %g%% Confidence Interval"), roc->ci);
971       tab_vline (tbl, 0, n_cols - 1, 0, 0);
972       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
973     }
974
975   if ( roc->n_vars > 1)
976     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
977
978   if ( roc->n_vars > 1)
979     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
980
981
982   for ( i = 0 ; i < roc->n_vars ; ++i )
983     {
984       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
985
986       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
987
988       if ( roc->print_se )
989         {
990           double se ;
991           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
992                                       (12 * rs[i].n1 * rs[i].n2));
993           double ci ;
994           double yy ;
995
996           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
997             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
998
999           se /= rs[i].n1 * rs[i].n2;
1000
1001           se = sqrt (se);
1002
1003           tab_double (tbl, n_cols - 4, 2 + i, 0,
1004                       se,
1005                       NULL);
1006
1007           ci = 1 - roc->ci / 100.0;
1008           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1009
1010           tab_double (tbl, n_cols - 2, 2 + i, 0,
1011                       rs[i].auc - yy,
1012                       NULL);
1013
1014           tab_double (tbl, n_cols - 1, 2 + i, 0,
1015                       rs[i].auc + yy,
1016                       NULL);
1017
1018           tab_double (tbl, n_cols - 3, 2 + i, 0,
1019                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1020                       NULL);
1021         }
1022     }
1023
1024   tab_submit (tbl);
1025 }
1026
1027
1028 static void
1029 show_summary (const struct cmd_roc *roc)
1030 {
1031   const int n_cols = 3;
1032   const int n_rows = 4;
1033   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
1034
1035   tab_title (tbl, _("Case Summary"));
1036
1037   tab_headers (tbl, 1, 0, 2, 0);
1038
1039   tab_dim (tbl, tab_natural_dimensions, NULL);
1040
1041   tab_box (tbl,
1042            TAL_2, TAL_2,
1043            -1, -1,
1044            0, 0,
1045            n_cols - 1,
1046            n_rows - 1);
1047
1048   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1049   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1050
1051
1052   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1053   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1054
1055
1056   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1057   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1058   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1059
1060   tab_joint_text (tbl, 1, 0, 2, 0,
1061                   TAT_TITLE | TAB_CENTER,
1062                   _("Valid N (listwise)"));
1063
1064
1065   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1066   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1067
1068
1069   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1070   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1071
1072   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1073   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1074
1075   tab_submit (tbl);
1076 }
1077
1078
1079 static void
1080 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1081 {
1082   int x = 1;
1083   int i;
1084   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1085   int n_rows = 1;
1086   struct tab_table *tbl ;
1087
1088   for (i = 0; i < roc->n_vars; ++i)
1089     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1090
1091   tbl = tab_create (n_cols, n_rows, 0);
1092
1093   if ( roc->n_vars > 1)
1094     tab_title (tbl, _("Coordinates of the Curve"));
1095   else
1096     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1097
1098
1099   tab_headers (tbl, 1, 0, 1, 0);
1100
1101   tab_dim (tbl, tab_natural_dimensions, NULL);
1102
1103   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1104
1105   if ( roc->n_vars > 1)
1106     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1107
1108   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1109   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1110   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1111
1112   tab_box (tbl,
1113            TAL_2, TAL_2,
1114            -1, TAL_1,
1115            0, 0,
1116            n_cols - 1,
1117            n_rows - 1);
1118
1119   if ( roc->n_vars > 1)
1120     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1121
1122   for (i = 0; i < roc->n_vars; ++i)
1123     {
1124       struct ccase *cc;
1125       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1126
1127       if ( roc->n_vars > 1)
1128         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1129
1130       if ( i > 0)
1131         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1132
1133
1134       for (; (cc = casereader_read (r)) != NULL;
1135            case_unref (cc), x++)
1136         {
1137           const double se = case_data_idx (cc, TP)->f /
1138             (
1139              case_data_idx (cc, TP)->f
1140              +
1141              case_data_idx (cc, FN)->f
1142              );
1143
1144           const double sp = case_data_idx (cc, TN)->f /
1145             (
1146              case_data_idx (cc, TN)->f
1147              +
1148              case_data_idx (cc, FP)->f
1149              );
1150
1151           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f,
1152                       var_get_print_format (roc->vars[i]));
1153
1154           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1155           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1156         }
1157
1158       casereader_destroy (r);
1159     }
1160
1161   tab_submit (tbl);
1162 }
1163
1164
1165 static void
1166 draw_roc (struct roc_state *rs, const struct cmd_roc *roc)
1167 {
1168   int i;
1169
1170   struct chart *roc_chart = chart_create ();
1171
1172   chart_write_title (roc_chart, _("ROC Curve"));
1173   chart_write_xlabel (roc_chart, _("1 - Specificity"));
1174   chart_write_ylabel (roc_chart, _("Sensitivity"));
1175
1176   chart_write_xscale (roc_chart, 0, 1, 5);
1177   chart_write_yscale (roc_chart, 0, 1, 5);
1178
1179   if ( roc->reference )
1180     {
1181       chart_line (roc_chart, 1.0, 0,
1182                   0.0, 1.0,
1183                   CHART_DIM_X);
1184     }
1185
1186   for (i = 0; i < roc->n_vars; ++i)
1187     {
1188       struct ccase *cc;
1189       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1190
1191       chart_vector_start (roc_chart, var_get_name (roc->vars[i]));
1192       for (; (cc = casereader_read (r)) != NULL;
1193            case_unref (cc))
1194         {
1195           double se = case_data_idx (cc, TP)->f;
1196           double sp = case_data_idx (cc, TN)->f;
1197
1198           se /= case_data_idx (cc, FN)->f +
1199             case_data_idx (cc, TP)->f ;
1200
1201           sp /= case_data_idx (cc, TN)->f +
1202             case_data_idx (cc, FP)->f ;
1203
1204           chart_vector (roc_chart, 1 - sp, se);
1205         }
1206       chart_vector_end (roc_chart);
1207       casereader_destroy (r);
1208     }
1209
1210   chart_write_legend (roc_chart);
1211
1212   chart_submit (roc_chart);
1213 }
1214
1215
1216 static void
1217 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1218 {
1219   show_summary (roc);
1220
1221   if ( roc->curve )
1222     draw_roc (rs, roc);
1223
1224   show_auc (rs, roc);
1225
1226
1227   if ( roc->print_coords )
1228     show_coords (rs, roc);
1229 }
1230