d266d7cf7684d311db10a36ba788852a19470f4a
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/procedure.h>
20 #include <language/lexer/variable-parser.h>
21 #include <language/lexer/value-parser.h>
22 #include <language/command.h>
23 #include <language/lexer/lexer.h>
24
25 #include <data/casegrouper.h>
26 #include <data/casereader.h>
27 #include <data/casewriter.h>
28 #include <data/dictionary.h>
29 #include <data/format.h>
30 #include <math/sort.h>
31 #include <data/subcase.h>
32
33
34 #include <libpspp/misc.h>
35
36 #include <gsl/gsl_cdf.h>
37 #include <output/table.h>
38
39 #include <output/charts/plot-chart.h>
40 #include <output/charts/cartesian.h>
41
42 #include "gettext.h"
43 #define _(msgid) gettext (msgid)
44 #define N_(msgid) msgid
45
46 struct cmd_roc
47 {
48   size_t n_vars;
49   const struct variable **vars;
50   const struct dictionary *dict;
51
52   const struct variable *state_var ;
53   union value state_value;
54
55   /* Plot the roc curve */
56   bool curve;
57   /* Plot the reference line */
58   bool reference;
59
60   double ci;
61
62   bool print_coords;
63   bool print_se;
64   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
65                       should be used */
66   enum mv_class exclude;
67
68   bool invert ; /* True iff a smaller test result variable indicates
69                    a positive result */
70
71   double pos;
72   double neg;
73   double pos_weighted;
74   double neg_weighted;
75 };
76
77 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
78
79 int
80 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 {
82   struct cmd_roc roc ;
83   const struct dictionary *dict = dataset_dict (ds);
84
85   roc.vars = NULL;
86   roc.n_vars = 0;
87   roc.print_se = false;
88   roc.print_coords = false;
89   roc.exclude = MV_ANY;
90   roc.curve = true;
91   roc.reference = false;
92   roc.ci = 95;
93   roc.bi_neg_exp = false;
94   roc.invert = false;
95   roc.pos = roc.pos_weighted = 0;
96   roc.neg = roc.neg_weighted = 0;
97   roc.dict = dataset_dict (ds);
98
99   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
101     goto error;;
102
103   if ( ! lex_force_match (lexer, T_BY))
104     {
105       goto error;;
106     }
107
108   roc.state_var = parse_variable (lexer, dict);
109
110   if ( !lex_force_match (lexer, '('))
111     {
112       goto error;;
113     }
114
115   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
116
117
118   if ( !lex_force_match (lexer, ')'))
119     {
120       goto error;;
121     }
122
123
124   while (lex_token (lexer) != '.')
125     {
126       lex_match (lexer, '/');
127       if (lex_match_id (lexer, "MISSING"))
128         {
129           lex_match (lexer, '=');
130           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
131             {
132               if (lex_match_id (lexer, "INCLUDE"))
133                 {
134                   roc.exclude = MV_SYSTEM;
135                 }
136               else if (lex_match_id (lexer, "EXCLUDE"))
137                 {
138                   roc.exclude = MV_ANY;
139                 }
140               else
141                 {
142                   lex_error (lexer, NULL);
143                   goto error;;
144                 }
145             }
146         }
147       else if (lex_match_id (lexer, "PLOT"))
148         {
149           lex_match (lexer, '=');
150           if (lex_match_id (lexer, "CURVE"))
151             {
152               roc.curve = true;
153               if (lex_match (lexer, '('))
154                 {
155                   roc.reference = true;
156                   lex_force_match_id (lexer, "REFERENCE");
157                   lex_force_match (lexer, ')');
158                 }
159             }
160           else if (lex_match_id (lexer, "NONE"))
161             {
162               roc.curve = false;
163             }
164           else
165             {
166               lex_error (lexer, NULL);
167               goto error;;
168             }
169         }
170       else if (lex_match_id (lexer, "PRINT"))
171         {
172           lex_match (lexer, '=');
173           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
174             {
175               if (lex_match_id (lexer, "SE"))
176                 {
177                   roc.print_se = true;
178                 }
179               else if (lex_match_id (lexer, "COORDINATES"))
180                 {
181                   roc.print_coords = true;
182                 }
183               else
184                 {
185                   lex_error (lexer, NULL);
186                   goto error;;
187                 }
188             }
189         }
190       else if (lex_match_id (lexer, "CRITERIA"))
191         {
192           lex_match (lexer, '=');
193           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
194             {
195               if (lex_match_id (lexer, "CUTOFF"))
196                 {
197                   lex_force_match (lexer, '(');
198                   if (lex_match_id (lexer, "INCLUDE"))
199                     {
200                       roc.exclude = MV_SYSTEM;
201                     }
202                   else if (lex_match_id (lexer, "EXCLUDE"))
203                     {
204                       roc.exclude = MV_USER | MV_SYSTEM;
205                     }
206                   else
207                     {
208                       lex_error (lexer, NULL);
209                       goto error;;
210                     }
211                   lex_force_match (lexer, ')');
212                 }
213               else if (lex_match_id (lexer, "TESTPOS"))
214                 {
215                   lex_force_match (lexer, '(');
216                   if (lex_match_id (lexer, "LARGE"))
217                     {
218                       roc.invert = false;
219                     }
220                   else if (lex_match_id (lexer, "SMALL"))
221                     {
222                       roc.invert = true;
223                     }
224                   else
225                     {
226                       lex_error (lexer, NULL);
227                       goto error;;
228                     }
229                   lex_force_match (lexer, ')');
230                 }
231               else if (lex_match_id (lexer, "CI"))
232                 {
233                   lex_force_match (lexer, '(');
234                   lex_force_num (lexer);
235                   roc.ci = lex_number (lexer);
236                   lex_get (lexer);
237                   lex_force_match (lexer, ')');
238                 }
239               else if (lex_match_id (lexer, "DISTRIBUTION"))
240                 {
241                   lex_force_match (lexer, '(');
242                   if (lex_match_id (lexer, "FREE"))
243                     {
244                       roc.bi_neg_exp = false;
245                     }
246                   else if (lex_match_id (lexer, "NEGEXPO"))
247                     {
248                       roc.bi_neg_exp = true;
249                     }
250                   else
251                     {
252                       lex_error (lexer, NULL);
253                       goto error;;
254                     }
255                   lex_force_match (lexer, ')');
256                 }
257               else
258                 {
259                   lex_error (lexer, NULL);
260                   goto error;;
261                 }
262             }
263         }
264       else
265         {
266           lex_error (lexer, NULL);
267           break;
268         }
269     }
270
271   if ( ! run_roc (ds, &roc)) 
272     goto error;;
273
274   return CMD_SUCCESS;
275
276  error:
277   free (roc.vars);
278   return CMD_FAILURE;
279 }
280
281
282
283
284 static void
285 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
286
287
288 static int
289 run_roc (struct dataset *ds, struct cmd_roc *roc)
290 {
291   struct dictionary *dict = dataset_dict (ds);
292   bool ok;
293   struct casereader *group;
294
295   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
296   while (casegrouper_get_next_group (grouper, &group))
297     {
298       do_roc (roc, group, dataset_dict (ds));
299     }
300   ok = casegrouper_destroy (grouper);
301   ok = proc_commit (ds) && ok;
302
303   return ok;
304 }
305
306 #if 0
307 static void
308 dump_casereader (struct casereader *reader)
309 {
310   struct ccase *c;
311   struct casereader *r = casereader_clone (reader);
312
313   for ( ; (c = casereader_read (r) ); case_unref (c))
314     {
315       int i;
316       for (i = 0 ; i < case_get_value_cnt (c); ++i)
317         {
318           printf ("%g ", case_data_idx (c, i)->f);
319         }
320       printf ("\n");
321     }
322
323   casereader_destroy (r);
324 }
325 #endif
326
327
328 /* 
329    Return true iff the state variable indicates that C has positive actual state.
330
331    As a side effect, this function also accumulates the roc->{pos,neg} and 
332    roc->{pos,neg}_weighted counts.
333  */
334 static bool
335 match_positives (const struct ccase *c, void *aux)
336 {
337   struct cmd_roc *roc = aux;
338   const struct variable *wv = dict_get_weight (roc->dict);
339   const double weight = wv ? case_data (c, wv)->f : 1.0;
340
341   const bool positive =
342   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
343     var_get_width (roc->state_var)));
344
345   if ( positive )
346     {
347       roc->pos++;
348       roc->pos_weighted += weight;
349     }
350   else
351     {
352       roc->neg++;
353       roc->neg_weighted += weight;
354     }
355
356   return positive;
357 }
358
359
360 #define VALUE  0
361 #define N_EQ   1
362 #define N_PRED 2
363
364 /* Some intermediate state for calculating the cutpoints and the 
365    standard error values */
366 struct roc_state
367 {
368   double auc;  /* Area under the curve */
369
370   double n1;  /* total weight of positives */
371   double n2;  /* total weight of negatives */
372
373   /* intermediates for standard error */
374   double q1hat; 
375   double q2hat;
376
377   /* intermediates for cutpoints */
378   struct casewriter *cutpoint_wtr;
379   struct casereader *cutpoint_rdr;
380   double prev_result;
381   double min;
382   double max;
383 };
384
385 #define CUTPOINT 0
386 #define TP 1
387 #define FN 2
388 #define TN 3
389 #define FP 4
390
391
392 /* 
393    Return a new casereader based upon CUTPOINT_RDR.
394    The number of "positive" cases are placed into
395    the position TRUE_INDEX, and the number of "negative" cases
396    into FALSE_INDEX.
397    POS_COND and RESULT determine the semantics of what is 
398    "positive".
399    WEIGHT is the value of a single count.
400  */
401 static struct casereader *
402 accumulate_counts (struct casereader *cutpoint_rdr, 
403                    double result, double weight, 
404                    bool (*pos_cond) (double, double),
405                    int true_index, int false_index)
406 {
407   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
408   struct casewriter *w =
409     autopaging_writer_create (proto);
410   struct casereader *r = casereader_clone (cutpoint_rdr);
411   struct ccase *cpc;
412   double prev_cp = SYSMIS;
413
414   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
415     {
416       struct ccase *new_case;
417       const double cp = case_data_idx (cpc, CUTPOINT)->f;
418
419       assert (cp != SYSMIS);
420
421       /* We don't want duplicates here */
422       if ( cp == prev_cp )
423         continue;
424
425       new_case = case_clone (cpc);
426
427       if ( pos_cond (result, cp))
428         case_data_rw_idx (new_case, true_index)->f += weight;
429       else
430         case_data_rw_idx (new_case, false_index)->f += weight;
431
432       prev_cp = cp;
433
434       casewriter_write (w, new_case);
435     }
436   casereader_destroy (r);
437
438   return casewriter_make_reader (w);
439 }
440
441
442
443 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
444
445 /*
446   This function does 3 things:
447
448   1. Counts the number of cases which are equal to every other case in READER,
449   and those cases for which the relationship between it and every other case
450   satifies PRED (normally either > or <).  VAR is variable defining a case's value
451   for this purpose.
452
453   2. Counts the number of true and false cases in reader, and populates
454   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
455   which receive these values.  POS_COND is the condition defining true
456   and false.
457   
458   3. CC is filled with the cumulative weight of all cases of READER.
459 */
460 static struct casereader *
461 process_group (const struct variable *var, struct casereader *reader,
462                bool (*pred) (double, double),
463                const struct dictionary *dict,
464                double *cc,
465                struct casereader **cutpoint_rdr, 
466                bool (*pos_cond) (double, double),
467                int true_index,
468                int false_index)
469 {
470   const struct variable *w = dict_get_weight (dict);
471
472   struct casereader *r1 =
473     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
474
475   const int weight_idx  = w ? var_get_case_index (w) :
476     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
477   
478   struct ccase *c1;
479
480   struct casereader *rclone = casereader_clone (r1);
481   struct casewriter *wtr;
482   struct caseproto *proto = caseproto_create ();
483
484   proto = caseproto_add_width (proto, 0);
485   proto = caseproto_add_width (proto, 0);
486   proto = caseproto_add_width (proto, 0);
487
488   wtr = autopaging_writer_create (proto);  
489
490   *cc = 0;
491
492   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
493     {
494       struct ccase *new_case = case_create (proto);
495       struct ccase *c2;
496       struct casereader *r2 = casereader_clone (rclone);
497
498       const double weight1 = case_data_idx (c1, weight_idx)->f;
499       const double d1 = case_data (c1, var)->f;
500       double n_eq = 0.0;
501       double n_pred = 0.0;
502
503       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
504                                          pos_cond,
505                                          true_index, false_index);
506
507       *cc += weight1;
508
509       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
510         {
511           const double d2 = case_data (c2, var)->f;
512           const double weight2 = case_data_idx (c2, weight_idx)->f;
513
514           if ( d1 == d2 )
515             {
516               n_eq += weight2;
517               continue;
518             }
519           else  if ( pred (d2, d1))
520             {
521               n_pred += weight2;
522             }
523         }
524
525       case_data_rw_idx (new_case, VALUE)->f = d1;
526       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
527       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
528
529       casewriter_write (wtr, new_case);
530
531       casereader_destroy (r2);
532     }
533
534   casereader_destroy (r1);
535   casereader_destroy (rclone);
536
537   return casewriter_make_reader (wtr);
538 }
539
540 /* Some more indeces into case data */
541 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
542 #define N_POS_GT 2  /* number of postive cases with values greater than n */
543 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
544 #define N_NEG_LT 4  /* number of negative cases with values less than n */
545
546 static bool
547 gt (double d1, double d2)
548 {
549   return d1 > d2;
550 }
551
552
553 static bool
554 ge (double d1, double d2)
555 {
556   return d1 > d2;
557 }
558
559 static bool
560 lt (double d1, double d2)
561 {
562   return d1 < d2;
563 }
564
565
566 /*
567   Return a casereader with width 3,
568   populated with cases based upon READER.
569   The cases will have the values:
570   (N, number of cases equal to N, number of cases greater than N)
571   As a side effect, update RS->n1 with the number of positive cases.
572 */
573 static struct casereader *
574 process_positive_group (const struct variable *var, struct casereader *reader,
575                         const struct dictionary *dict,
576                         struct roc_state *rs)
577 {
578   return process_group (var, reader, gt, dict, &rs->n1,
579                         &rs->cutpoint_rdr,
580                         ge,
581                         TP, FN);
582 }
583
584 /*
585   Return a casereader with width 3,
586   populated with cases based upon READER.
587   The cases will have the values:
588   (N, number of cases equal to N, number of cases less than N)
589   As a side effect, update RS->n2 with the number of negative cases.
590 */
591 static struct casereader *
592 process_negative_group (const struct variable *var, struct casereader *reader,
593                         const struct dictionary *dict,
594                         struct roc_state *rs)
595 {
596   return process_group (var, reader, lt, dict, &rs->n2,
597                         &rs->cutpoint_rdr,
598                         lt,
599                         TN, FP);
600 }
601
602
603
604
605 static void
606 append_cutpoint (struct casewriter *writer, double cutpoint)
607 {
608   struct ccase *cc = case_create (casewriter_get_proto (writer));
609
610   case_data_rw_idx (cc, CUTPOINT)->f = cutpoint;
611   case_data_rw_idx (cc, TP)->f = 0;
612   case_data_rw_idx (cc, FN)->f = 0;
613   case_data_rw_idx (cc, TN)->f = 0;
614   case_data_rw_idx (cc, FP)->f = 0;
615
616   casewriter_write (writer, cc);
617 }
618
619
620 /* 
621    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
622    be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the
623    reader will be populated with its final number of cases.
624    However on exit from this function, only CUTPOINT entries will be set to their final
625    value.  The other entries will be initialised to zero.
626 */
627 static void
628 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
629 {
630   int i;
631   struct casereader *r = casereader_clone (input);
632   struct ccase *c;
633   struct caseproto *proto = caseproto_create ();
634
635   struct subcase ordering;
636   subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND);
637
638   proto = caseproto_add_width (proto, 0); /* cutpoint */
639   proto = caseproto_add_width (proto, 0); /* TP */
640   proto = caseproto_add_width (proto, 0); /* FN */
641   proto = caseproto_add_width (proto, 0); /* TN */
642   proto = caseproto_add_width (proto, 0); /* FP */
643
644   for (i = 0 ; i < roc->n_vars; ++i)
645     {
646       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
647       rs[i].prev_result = SYSMIS;
648       rs[i].max = -DBL_MAX;
649       rs[i].min = DBL_MAX;
650     }
651
652   for (; (c = casereader_read (r)) != NULL; case_unref (c))
653     {
654       for (i = 0 ; i < roc->n_vars; ++i)
655         {
656           const union value *v = case_data (c, roc->vars[i]); 
657           const double result = v->f;
658
659           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
660             continue;
661
662           minimize (&rs[i].min, result);
663           maximize (&rs[i].max, result);
664
665           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
666             {
667               const double mean = (result + rs[i].prev_result ) / 2.0;
668               append_cutpoint (rs[i].cutpoint_wtr, mean);
669             }
670
671           rs[i].prev_result = result;
672         }
673     }
674   casereader_destroy (r);
675
676
677   /* Append the min and max cutpoints */
678   for (i = 0 ; i < roc->n_vars; ++i)
679     {
680       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
681       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
682
683       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
684     }
685 }
686
687 static void
688 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
689 {
690   int i;
691
692   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
693
694   struct casereader *negatives = NULL;
695   struct casereader *positives = NULL;
696
697   struct caseproto *n_proto = caseproto_create ();
698
699   struct subcase up_ordering;
700   struct subcase down_ordering;
701
702   struct casewriter *neg_wtr = NULL;
703
704   struct casereader *input = casereader_create_filter_missing (reader,
705                                                                roc->vars, roc->n_vars,
706                                                                roc->exclude,
707                                                                NULL,
708                                                                NULL);
709
710   input = casereader_create_filter_missing (input,
711                                             &roc->state_var, 1,
712                                             roc->exclude,
713                                             NULL,
714                                             NULL);
715
716   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
717
718   prepare_cutpoints (roc, rs, input);
719
720
721   /* Separate the positive actual state cases from the negative ones */
722   positives = 
723     casereader_create_filter_func (input,
724                                    match_positives,
725                                    NULL,
726                                    roc,
727                                    neg_wtr);
728
729   n_proto = caseproto_create ();
730       
731   n_proto = caseproto_add_width (n_proto, 0);
732   n_proto = caseproto_add_width (n_proto, 0);
733   n_proto = caseproto_add_width (n_proto, 0);
734   n_proto = caseproto_add_width (n_proto, 0);
735   n_proto = caseproto_add_width (n_proto, 0);
736
737   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
738   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
739
740   for (i = 0 ; i < roc->n_vars; ++i)
741     {
742       struct casewriter *w = NULL;
743       struct casereader *r = NULL;
744
745       struct ccase *c;
746
747       struct ccase *cpos;
748       struct casereader *n_neg ;
749       const struct variable *var = roc->vars[i];
750
751       struct casereader *neg ;
752       struct casereader *pos = casereader_clone (positives);
753
754
755       struct casereader *n_pos =
756         process_positive_group (var, pos, dict, &rs[i]);
757
758       if ( negatives == NULL)
759         {
760           negatives = casewriter_make_reader (neg_wtr);
761         }
762
763       neg = casereader_clone (negatives);
764
765       n_neg = process_negative_group (var, neg, dict, &rs[i]);
766
767
768       /* Merge the n_pos and n_neg casereaders */
769       w = sort_create_writer (&up_ordering, n_proto);
770       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
771         {
772           struct ccase *pos_case = case_create (n_proto);
773           struct ccase *cneg;
774           const double jpos = case_data_idx (cpos, VALUE)->f;
775
776           while ((cneg = casereader_read (n_neg)))
777             {
778               struct ccase *nc = case_create (n_proto);
779
780               const double jneg = case_data_idx (cneg, VALUE)->f;
781
782               case_data_rw_idx (nc, VALUE)->f = jneg;
783               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
784
785               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
786
787               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
788               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
789
790               casewriter_write (w, nc);
791
792               case_unref (cneg);
793               if ( jneg > jpos)
794                 break;
795             }
796
797           case_data_rw_idx (pos_case, VALUE)->f = jpos;
798           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
799           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
800           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
801           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
802
803           casewriter_write (w, pos_case);
804         }
805
806 /* These aren't used anymore */
807 #undef N_EQ
808 #undef N_PRED
809
810       r = casewriter_make_reader (w);
811
812       /* Propagate the N_POS_GT values from the positive cases
813          to the negative ones */
814       {
815         double prev_pos_gt = rs[i].n1;
816         w = sort_create_writer (&down_ordering, n_proto);
817
818         for ( ; (c = casereader_read (r) ); case_unref (c))
819           {
820             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
821             struct ccase *nc = case_clone (c);
822
823             if ( n_pos_gt == SYSMIS)
824               {
825                 n_pos_gt = prev_pos_gt;
826                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
827               }
828             
829             casewriter_write (w, nc);
830             prev_pos_gt = n_pos_gt;
831           }
832
833         r = casewriter_make_reader (w);
834       }
835
836       /* Propagate the N_NEG_LT values from the negative cases
837          to the positive ones */
838       {
839         double prev_neg_lt = rs[i].n2;
840         w = sort_create_writer (&up_ordering, n_proto);
841
842         for ( ; (c = casereader_read (r) ); case_unref (c))
843           {
844             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
845             struct ccase *nc = case_clone (c);
846
847             if ( n_neg_lt == SYSMIS)
848               {
849                 n_neg_lt = prev_neg_lt;
850                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
851               }
852             
853             casewriter_write (w, nc);
854             prev_neg_lt = n_neg_lt;
855           }
856
857         r = casewriter_make_reader (w);
858       }
859
860       {
861         struct ccase *prev_case = NULL;
862         for ( ; (c = casereader_read (r) ); case_unref (c))
863           {
864             const struct ccase *next_case = casereader_peek (r, 0);
865
866             const double j = case_data_idx (c, VALUE)->f;
867             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
868             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
869             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
870             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
871
872             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
873               {
874                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
875                   {
876                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
877                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
878                   }
879
880                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
881                   {
882                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
883                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
884                   }
885               }
886
887             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
888               {
889                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
890
891                 rs[i].q1hat +=
892                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
893                 rs[i].q2hat +=
894                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
895
896               }
897
898             case_unref (prev_case);
899             prev_case = case_clone (c);
900           }
901
902         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
903         if ( roc->invert ) 
904           rs[i].auc = 1 - rs[i].auc;
905
906         if ( roc->bi_neg_exp )
907           {
908             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
909             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
910           }
911         else
912           {
913             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
914             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
915           }
916       }
917     }
918
919   casereader_destroy (positives);
920   casereader_destroy (negatives);
921
922   output_roc (rs, roc);
923
924   free (rs);
925 }
926
927 static void
928 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
929 {
930   int i;
931   const int n_fields = roc->print_se ? 5 : 1;
932   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
933   const int n_rows = 2 + roc->n_vars;
934   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
935
936   if ( roc->n_vars > 1)
937     tab_title (tbl, _("Area Under the Curve"));
938   else
939     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
940
941   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
942
943   tab_dim (tbl, tab_natural_dimensions, NULL);
944
945   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
946
947   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
948
949   tab_box (tbl,
950            TAL_2, TAL_2,
951            -1, TAL_1,
952            0, 0,
953            n_cols - 1,
954            n_rows - 1);
955
956   if ( roc->print_se )
957     {
958       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
959       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
960
961       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
962       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
963
964       tab_joint_text (tbl, n_cols - 2, 0, 4, 0,
965                       TAT_TITLE | TAB_CENTER | TAT_PRINTF,
966                       _("Asymp. %g%% Confidence Interval"), roc->ci);
967       tab_vline (tbl, 0, n_cols - 1, 0, 0);
968       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
969     }
970
971   if ( roc->n_vars > 1)
972     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
973
974   if ( roc->n_vars > 1)
975     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
976
977
978   for ( i = 0 ; i < roc->n_vars ; ++i )
979     {
980       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
981
982       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
983
984       if ( roc->print_se )
985         {
986           double se ;
987           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
988                                       (12 * rs[i].n1 * rs[i].n2));
989           double ci ;
990           double yy ;
991
992           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
993             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
994
995           se /= rs[i].n1 * rs[i].n2;
996
997           se = sqrt (se);
998
999           tab_double (tbl, n_cols - 4, 2 + i, 0,
1000                       se,
1001                       NULL);
1002
1003           ci = 1 - roc->ci / 100.0;
1004           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1005
1006           tab_double (tbl, n_cols - 2, 2 + i, 0,
1007                       rs[i].auc - yy,
1008                       NULL);
1009
1010           tab_double (tbl, n_cols - 1, 2 + i, 0,
1011                       rs[i].auc + yy,
1012                       NULL);
1013
1014           tab_double (tbl, n_cols - 3, 2 + i, 0,
1015                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1016                       NULL);
1017         }
1018     }
1019
1020   tab_submit (tbl);
1021 }
1022
1023
1024 static void
1025 show_summary (const struct cmd_roc *roc)
1026 {
1027   const int n_cols = 3;
1028   const int n_rows = 4;
1029   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
1030
1031   tab_title (tbl, _("Case Summary"));
1032
1033   tab_headers (tbl, 1, 0, 2, 0);
1034
1035   tab_dim (tbl, tab_natural_dimensions, NULL);
1036
1037   tab_box (tbl,
1038            TAL_2, TAL_2,
1039            -1, -1,
1040            0, 0,
1041            n_cols - 1,
1042            n_rows - 1);
1043
1044   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1045   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1046
1047
1048   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1049   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1050
1051
1052   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1053   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1054   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1055
1056   tab_joint_text (tbl, 1, 0, 2, 0,
1057                   TAT_TITLE | TAB_CENTER,
1058                   _("Valid N (listwise)"));
1059
1060
1061   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1062   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1063
1064
1065   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1066   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1067
1068   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1069   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1070
1071   tab_submit (tbl);
1072 }
1073
1074
1075 static void
1076 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1077 {
1078   int x = 1;
1079   int i;
1080   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1081   int n_rows = 1;
1082   struct tab_table *tbl ;
1083
1084   for (i = 0; i < roc->n_vars; ++i)
1085     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1086
1087   tbl = tab_create (n_cols, n_rows, 0);
1088
1089   if ( roc->n_vars > 1)
1090     tab_title (tbl, _("Coordinates of the Curve"));
1091   else
1092     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1093
1094
1095   tab_headers (tbl, 1, 0, 1, 0);
1096
1097   tab_dim (tbl, tab_natural_dimensions, NULL);
1098
1099   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1100
1101   if ( roc->n_vars > 1)
1102     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1103
1104   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1105   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1106   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1107
1108   tab_box (tbl,
1109            TAL_2, TAL_2,
1110            -1, TAL_1,
1111            0, 0,
1112            n_cols - 1,
1113            n_rows - 1);
1114
1115   if ( roc->n_vars > 1)
1116     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1117
1118   for (i = 0; i < roc->n_vars; ++i)
1119     {
1120       struct ccase *cc;
1121       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1122
1123       if ( roc->n_vars > 1)
1124         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1125
1126       if ( i > 0)
1127         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1128
1129
1130       for (; (cc = casereader_read (r)) != NULL;
1131            case_unref (cc), x++)
1132         {
1133           const double se = case_data_idx (cc, TP)->f /
1134             (
1135              case_data_idx (cc, TP)->f
1136              +
1137              case_data_idx (cc, FN)->f
1138              );
1139
1140           const double sp = case_data_idx (cc, TN)->f /
1141             (
1142              case_data_idx (cc, TN)->f
1143              +
1144              case_data_idx (cc, FP)->f
1145              );
1146
1147           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f,
1148                       var_get_print_format (roc->vars[i]));
1149
1150           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1151           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1152         }
1153
1154       casereader_destroy (r);
1155     }
1156
1157   tab_submit (tbl);
1158 }
1159
1160
1161 static void
1162 draw_roc (struct roc_state *rs, const struct cmd_roc *roc)
1163 {
1164   int i;
1165
1166   struct chart *roc_chart = chart_create ();
1167
1168   chart_write_title (roc_chart, _("ROC Curve"));
1169   chart_write_xlabel (roc_chart, _("1 - Specificity"));
1170   chart_write_ylabel (roc_chart, _("Sensitivity"));
1171
1172   chart_write_xscale (roc_chart, 0, 1, 5);
1173   chart_write_yscale (roc_chart, 0, 1, 5);
1174
1175   if ( roc->reference )
1176     {
1177       chart_line (roc_chart, 1.0, 0,
1178                   0.0, 1.0,
1179                   CHART_DIM_X);
1180     }
1181
1182   for (i = 0; i < roc->n_vars; ++i)
1183     {
1184       struct ccase *cc;
1185       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1186
1187       chart_vector_start (roc_chart, var_get_name (roc->vars[i]));
1188       for (; (cc = casereader_read (r)) != NULL;
1189            case_unref (cc))
1190         {
1191           double se = case_data_idx (cc, TP)->f;
1192           double sp = case_data_idx (cc, TN)->f;
1193
1194           se /= case_data_idx (cc, FN)->f +
1195             case_data_idx (cc, TP)->f ;
1196
1197           sp /= case_data_idx (cc, TN)->f +
1198             case_data_idx (cc, FP)->f ;
1199
1200           chart_vector (roc_chart, 1 - sp, se);
1201         }
1202       chart_vector_end (roc_chart);
1203       casereader_destroy (r);
1204     }
1205
1206   chart_write_legend (roc_chart);
1207
1208   chart_submit (roc_chart);
1209 }
1210
1211
1212 static void
1213 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1214 {
1215   show_summary (roc);
1216
1217   if ( roc->curve )
1218     draw_roc (rs, roc);
1219
1220   show_auc (rs, roc);
1221
1222
1223   if ( roc->print_coords )
1224     show_coords (rs, roc);
1225 }
1226