Add some comments and macros to make the code more readable
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/procedure.h>
20 #include <language/lexer/variable-parser.h>
21 #include <language/lexer/value-parser.h>
22 #include <language/command.h>
23 #include <language/lexer/lexer.h>
24
25 #include <data/casegrouper.h>
26 #include <data/casereader.h>
27 #include <data/casewriter.h>
28 #include <data/dictionary.h>
29 #include <data/format.h>
30 #include <math/sort.h>
31 #include <data/subcase.h>
32
33
34 #include <libpspp/misc.h>
35
36 #include <gsl/gsl_cdf.h>
37 #include <output/table.h>
38
39 #include <output/charts/plot-chart.h>
40 #include <output/charts/cartesian.h>
41
42 #include "gettext.h"
43 #define _(msgid) gettext (msgid)
44 #define N_(msgid) msgid
45
46 struct cmd_roc
47 {
48   size_t n_vars;
49   const struct variable **vars;
50   const struct dictionary *dict;
51
52   const struct variable *state_var ;
53   union value state_value;
54
55   /* Plot the roc curve */
56   bool curve;
57   /* Plot the reference line */
58   bool reference;
59
60   double ci;
61
62   bool print_coords;
63   bool print_se;
64   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
65                       should be used */
66   enum mv_class exclude;
67
68   bool invert ; /* True iff a smaller test result variable indicates
69                    a positive result */
70
71   double pos;
72   double neg;
73   double pos_weighted;
74   double neg_weighted;
75 };
76
77 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
78
79 int
80 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 {
82   struct cmd_roc roc ;
83   const struct dictionary *dict = dataset_dict (ds);
84
85   roc.vars = NULL;
86   roc.n_vars = 0;
87   roc.print_se = false;
88   roc.print_coords = false;
89   roc.exclude = MV_ANY;
90   roc.curve = true;
91   roc.reference = false;
92   roc.ci = 95;
93   roc.bi_neg_exp = false;
94   roc.invert = false;
95   roc.pos = roc.pos_weighted = 0;
96   roc.neg = roc.neg_weighted = 0;
97   roc.dict = dataset_dict (ds);
98
99   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
101     goto error;;
102
103   if ( ! lex_force_match (lexer, T_BY))
104     {
105       goto error;;
106     }
107
108   roc.state_var = parse_variable (lexer, dict);
109
110   if ( !lex_force_match (lexer, '('))
111     {
112       goto error;;
113     }
114
115   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
116
117
118   if ( !lex_force_match (lexer, ')'))
119     {
120       goto error;;
121     }
122
123
124   while (lex_token (lexer) != '.')
125     {
126       lex_match (lexer, '/');
127       if (lex_match_id (lexer, "MISSING"))
128         {
129           lex_match (lexer, '=');
130           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
131             {
132               if (lex_match_id (lexer, "INCLUDE"))
133                 {
134                   roc.exclude = MV_SYSTEM;
135                 }
136               else if (lex_match_id (lexer, "EXCLUDE"))
137                 {
138                   roc.exclude = MV_ANY;
139                 }
140               else
141                 {
142                   lex_error (lexer, NULL);
143                   goto error;;
144                 }
145             }
146         }
147       else if (lex_match_id (lexer, "PLOT"))
148         {
149           lex_match (lexer, '=');
150           if (lex_match_id (lexer, "CURVE"))
151             {
152               roc.curve = true;
153               if (lex_match (lexer, '('))
154                 {
155                   roc.reference = true;
156                   lex_force_match_id (lexer, "REFERENCE");
157                   lex_force_match (lexer, ')');
158                 }
159             }
160           else if (lex_match_id (lexer, "NONE"))
161             {
162               roc.curve = false;
163             }
164           else
165             {
166               lex_error (lexer, NULL);
167               goto error;;
168             }
169         }
170       else if (lex_match_id (lexer, "PRINT"))
171         {
172           lex_match (lexer, '=');
173           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
174             {
175               if (lex_match_id (lexer, "SE"))
176                 {
177                   roc.print_se = true;
178                 }
179               else if (lex_match_id (lexer, "COORDINATES"))
180                 {
181                   roc.print_coords = true;
182                 }
183               else
184                 {
185                   lex_error (lexer, NULL);
186                   goto error;;
187                 }
188             }
189         }
190       else if (lex_match_id (lexer, "CRITERIA"))
191         {
192           lex_match (lexer, '=');
193           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
194             {
195               if (lex_match_id (lexer, "CUTOFF"))
196                 {
197                   lex_force_match (lexer, '(');
198                   if (lex_match_id (lexer, "INCLUDE"))
199                     {
200                       roc.exclude = MV_SYSTEM;
201                     }
202                   else if (lex_match_id (lexer, "EXCLUDE"))
203                     {
204                       roc.exclude = MV_USER | MV_SYSTEM;
205                     }
206                   else
207                     {
208                       lex_error (lexer, NULL);
209                       goto error;;
210                     }
211                   lex_force_match (lexer, ')');
212                 }
213               else if (lex_match_id (lexer, "TESTPOS"))
214                 {
215                   lex_force_match (lexer, '(');
216                   if (lex_match_id (lexer, "LARGE"))
217                     {
218                       roc.invert = false;
219                     }
220                   else if (lex_match_id (lexer, "SMALL"))
221                     {
222                       roc.invert = true;
223                     }
224                   else
225                     {
226                       lex_error (lexer, NULL);
227                       goto error;;
228                     }
229                   lex_force_match (lexer, ')');
230                 }
231               else if (lex_match_id (lexer, "CI"))
232                 {
233                   lex_force_match (lexer, '(');
234                   lex_force_num (lexer);
235                   roc.ci = lex_number (lexer);
236                   lex_get (lexer);
237                   lex_force_match (lexer, ')');
238                 }
239               else if (lex_match_id (lexer, "DISTRIBUTION"))
240                 {
241                   lex_force_match (lexer, '(');
242                   if (lex_match_id (lexer, "FREE"))
243                     {
244                       roc.bi_neg_exp = false;
245                     }
246                   else if (lex_match_id (lexer, "NEGEXPO"))
247                     {
248                       roc.bi_neg_exp = true;
249                     }
250                   else
251                     {
252                       lex_error (lexer, NULL);
253                       goto error;;
254                     }
255                   lex_force_match (lexer, ')');
256                 }
257               else
258                 {
259                   lex_error (lexer, NULL);
260                   goto error;;
261                 }
262             }
263         }
264       else
265         {
266           lex_error (lexer, NULL);
267           break;
268         }
269     }
270
271   if ( ! run_roc (ds, &roc)) 
272     goto error;;
273
274   return CMD_SUCCESS;
275
276  error:
277   free (roc.vars);
278   return CMD_FAILURE;
279 }
280
281
282
283
284 static void
285 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
286
287
288 static int
289 run_roc (struct dataset *ds, struct cmd_roc *roc)
290 {
291   struct dictionary *dict = dataset_dict (ds);
292   bool ok;
293   struct casereader *group;
294
295   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
296   while (casegrouper_get_next_group (grouper, &group))
297     {
298       do_roc (roc, group, dataset_dict (ds));
299     }
300   ok = casegrouper_destroy (grouper);
301   ok = proc_commit (ds) && ok;
302
303   return ok;
304 }
305
306 #if 0
307 static void
308 dump_casereader (struct casereader *reader)
309 {
310   struct ccase *c;
311   struct casereader *r = casereader_clone (reader);
312
313   for ( ; (c = casereader_read (r) ); case_unref (c))
314     {
315       int i;
316       for (i = 0 ; i < case_get_value_cnt (c); ++i)
317         {
318           printf ("%g ", case_data_idx (c, i)->f);
319         }
320       printf ("\n");
321     }
322
323   casereader_destroy (r);
324 }
325 #endif
326
327
328 /* 
329    Return true iff the state variable indicates that C has positive actual state.
330
331    As a side effect, this function also accumulates the roc->{pos,neg} and 
332    roc->{pos,neg}_weighted counts.
333  */
334 static bool
335 match_positives (const struct ccase *c, void *aux)
336 {
337   struct cmd_roc *roc = aux;
338   const struct variable *wv = dict_get_weight (roc->dict);
339   const double weight = wv ? case_data (c, wv)->f : 1.0;
340
341   const bool positive =
342   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
343     var_get_width (roc->state_var)));
344
345   if ( positive )
346     {
347       roc->pos++;
348       roc->pos_weighted += weight;
349     }
350   else
351     {
352       roc->neg++;
353       roc->neg_weighted += weight;
354     }
355
356   return positive;
357 }
358
359
360 #define VALUE  0
361 #define N_EQ   1
362 #define N_PRED 2
363
364 /* Some intermediate state for calculating the cutpoints and the 
365    standard error values */
366 struct roc_state
367 {
368   double auc;
369
370   double n1;
371   double n2;
372
373   double q1hat;
374   double q2hat;
375
376   struct casewriter *cutpoint_wtr;
377   struct casereader *cutpoint_rdr;
378   double prev_result;
379   double min;
380   double max;
381 };
382
383 #define CUTPOINT 0
384 #define TP 1
385 #define FN 2
386 #define TN 3
387 #define FP 4
388
389
390 /* 
391    Return a new casereader based upon CUTPOINT_RDR.
392    The number of "positive" cases are placed into
393    the position TRUE_INDEX, and the number of "negative" cases
394    into FALSE_INDEX.
395    POS_COND and RESULT determine the semantics of what is 
396    "positive".
397    WEIGHT is the value of a single count.
398  */
399 static struct casereader *
400 accumulate_counts (struct casereader *cutpoint_rdr, 
401                    double result, double weight, 
402                    bool (*pos_cond) (double, double),
403                    int true_index, int false_index)
404 {
405   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
406   struct casewriter *w =
407     autopaging_writer_create (proto);
408   struct casereader *r = casereader_clone (cutpoint_rdr);
409   struct ccase *cpc;
410   double prev_cp = SYSMIS;
411
412   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
413     {
414       struct ccase *new_case;
415       const double cp = case_data_idx (cpc, CUTPOINT)->f;
416
417       assert (cp != SYSMIS);
418
419       /* We don't want duplicates here */
420       if ( cp == prev_cp )
421         continue;
422
423       new_case = case_clone (cpc);
424
425       if ( pos_cond (result, cp))
426         case_data_rw_idx (new_case, true_index)->f += weight;
427       else
428         case_data_rw_idx (new_case, false_index)->f += weight;
429
430       prev_cp = cp;
431
432       casewriter_write (w, new_case);
433     }
434   casereader_destroy (r);
435
436   return casewriter_make_reader (w);
437 }
438
439
440
441 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
442
443
444 static struct casereader *
445 process_group (const struct variable *var, struct casereader *reader,
446                bool (*pred) (double, double),
447                const struct dictionary *dict,
448                double *cc,
449                struct casereader **cutpoint_rdr, 
450                bool (*pos_cond) (double, double),
451                int true_index,
452                int false_index)
453 {
454   const struct variable *w = dict_get_weight (dict);
455
456   struct casereader *r1 =
457     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
458
459   const int weight_idx  = w ? var_get_case_index (w) :
460     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
461   
462   struct ccase *c1;
463
464   struct casereader *rclone = casereader_clone (r1);
465   struct casewriter *wtr;
466   struct caseproto *proto = caseproto_create ();
467
468   proto = caseproto_add_width (proto, 0);
469   proto = caseproto_add_width (proto, 0);
470   proto = caseproto_add_width (proto, 0);
471
472   wtr = autopaging_writer_create (proto);  
473
474   *cc = 0;
475
476   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
477     {
478       struct ccase *new_case = case_create (proto);
479       struct ccase *c2;
480       struct casereader *r2 = casereader_clone (rclone);
481
482       const double weight1 = case_data_idx (c1, weight_idx)->f;
483       const double d1 = case_data (c1, var)->f;
484       double n_eq = 0.0;
485       double n_pred = 0.0;
486
487       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
488                                          pos_cond,
489                                          true_index, false_index);
490
491       *cc += weight1;
492
493       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
494         {
495           const double d2 = case_data (c2, var)->f;
496           const double weight2 = case_data_idx (c2, weight_idx)->f;
497
498           if ( d1 == d2 )
499             {
500               n_eq += weight2;
501               continue;
502             }
503           else  if ( pred (d2, d1))
504             {
505               n_pred += weight2;
506             }
507         }
508
509       case_data_rw_idx (new_case, VALUE)->f = d1;
510       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
511       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
512
513       casewriter_write (wtr, new_case);
514
515       casereader_destroy (r2);
516     }
517
518   casereader_destroy (r1);
519   casereader_destroy (rclone);
520
521   return casewriter_make_reader (wtr);
522 }
523
524 /* Some more indeces into case data */
525 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
526 #define N_POS_GT 2  /* number of postive cases with values greater than n */
527 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
528 #define N_NEG_LT 4  /* number of negative cases with values less than n */
529
530 static bool
531 gt (double d1, double d2)
532 {
533   return d1 > d2;
534 }
535
536
537 static bool
538 ge (double d1, double d2)
539 {
540   return d1 > d2;
541 }
542
543 static bool
544 lt (double d1, double d2)
545 {
546   return d1 < d2;
547 }
548
549
550 /*
551   Return a casereader with width 3,
552   populated with cases based upon READER.
553   The cases will have the values:
554   (N, number of cases equal to N, number of cases greater than N)
555   As a side effect, update RS->n1 with the number of positive cases.
556 */
557 static struct casereader *
558 process_positive_group (const struct variable *var, struct casereader *reader,
559                         const struct dictionary *dict,
560                         struct roc_state *rs)
561 {
562   return process_group (var, reader, gt, dict, &rs->n1,
563                         &rs->cutpoint_rdr,
564                         ge,
565                         TP, FN);
566 }
567
568 /*
569   Return a casereader with width 3,
570   populated with cases based upon READER.
571   The cases will have the values:
572   (N, number of cases equal to N, number of cases less than N)
573   As a side effect, update RS->n2 with the number of negative cases.
574 */
575 static struct casereader *
576 process_negative_group (const struct variable *var, struct casereader *reader,
577                         const struct dictionary *dict,
578                         struct roc_state *rs)
579 {
580   return process_group (var, reader, lt, dict, &rs->n2,
581                         &rs->cutpoint_rdr,
582                         lt,
583                         TN, FP);
584 }
585
586
587
588
589 static void
590 append_cutpoint (struct casewriter *writer, double cutpoint)
591 {
592   struct ccase *cc = case_create (casewriter_get_proto (writer));
593
594   case_data_rw_idx (cc, CUTPOINT)->f = cutpoint;
595   case_data_rw_idx (cc, TP)->f = 0;
596   case_data_rw_idx (cc, FN)->f = 0;
597   case_data_rw_idx (cc, TN)->f = 0;
598   case_data_rw_idx (cc, FP)->f = 0;
599
600   casewriter_write (writer, cc);
601 }
602
603
604 /* 
605    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
606    be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the
607    reader will be populated with its final number of cases.
608    However on exit from this function, only CUTPOINT entries will be set to their final
609    value.  The other entries will be initialised to zero.
610 */
611 static void
612 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
613 {
614   int i;
615   struct casereader *r = casereader_clone (input);
616   struct ccase *c;
617   struct caseproto *proto = caseproto_create ();
618
619   struct subcase ordering;
620   subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND);
621
622   proto = caseproto_add_width (proto, 0); /* cutpoint */
623   proto = caseproto_add_width (proto, 0); /* TP */
624   proto = caseproto_add_width (proto, 0); /* FN */
625   proto = caseproto_add_width (proto, 0); /* TN */
626   proto = caseproto_add_width (proto, 0); /* FP */
627
628   for (i = 0 ; i < roc->n_vars; ++i)
629     {
630       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
631       rs[i].prev_result = SYSMIS;
632       rs[i].max = -DBL_MAX;
633       rs[i].min = DBL_MAX;
634     }
635
636   for (; (c = casereader_read (r)) != NULL; case_unref (c))
637     {
638       for (i = 0 ; i < roc->n_vars; ++i)
639         {
640           const union value *v = case_data (c, roc->vars[i]); 
641           const double result = v->f;
642
643           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
644             continue;
645
646           minimize (&rs[i].min, result);
647           maximize (&rs[i].max, result);
648
649           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
650             {
651               const double mean = (result + rs[i].prev_result ) / 2.0;
652               append_cutpoint (rs[i].cutpoint_wtr, mean);
653             }
654
655           rs[i].prev_result = result;
656         }
657     }
658   casereader_destroy (r);
659
660
661   /* Append the min and max cutpoints */
662   for (i = 0 ; i < roc->n_vars; ++i)
663     {
664       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
665       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
666
667       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
668     }
669 }
670
671 static void
672 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
673 {
674   int i;
675
676   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
677
678   struct casereader *negatives = NULL;
679   struct casereader *positives = NULL;
680
681   struct caseproto *n_proto = caseproto_create ();
682
683   struct subcase up_ordering;
684   struct subcase down_ordering;
685
686   struct casewriter *neg_wtr = NULL;
687
688   struct casereader *input = casereader_create_filter_missing (reader,
689                                                                roc->vars, roc->n_vars,
690                                                                roc->exclude,
691                                                                NULL,
692                                                                NULL);
693
694   input = casereader_create_filter_missing (input,
695                                             &roc->state_var, 1,
696                                             roc->exclude,
697                                             NULL,
698                                             NULL);
699
700   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
701
702   prepare_cutpoints (roc, rs, input);
703
704
705   /* Separate the positive actual state cases from the negative ones */
706   positives = 
707     casereader_create_filter_func (input,
708                                    match_positives,
709                                    NULL,
710                                    roc,
711                                    neg_wtr);
712
713   n_proto = caseproto_create ();
714       
715   n_proto = caseproto_add_width (n_proto, 0);
716   n_proto = caseproto_add_width (n_proto, 0);
717   n_proto = caseproto_add_width (n_proto, 0);
718   n_proto = caseproto_add_width (n_proto, 0);
719   n_proto = caseproto_add_width (n_proto, 0);
720
721   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
722   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
723
724   for (i = 0 ; i < roc->n_vars; ++i)
725     {
726       struct casewriter *w = NULL;
727       struct casereader *r = NULL;
728
729       struct ccase *c;
730
731       struct ccase *cpos;
732       struct casereader *n_neg ;
733       const struct variable *var = roc->vars[i];
734
735       struct casereader *neg ;
736       struct casereader *pos = casereader_clone (positives);
737
738
739       struct casereader *n_pos =
740         process_positive_group (var, pos, dict, &rs[i]);
741
742       if ( negatives == NULL)
743         {
744           negatives = casewriter_make_reader (neg_wtr);
745         }
746
747       neg = casereader_clone (negatives);
748
749       n_neg = process_negative_group (var, neg, dict, &rs[i]);
750
751
752       /* Merge the n_pos and n_neg casereaders */
753       w = sort_create_writer (&up_ordering, n_proto);
754       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
755         {
756           struct ccase *pos_case = case_create (n_proto);
757           struct ccase *cneg;
758           const double jpos = case_data_idx (cpos, VALUE)->f;
759
760           while ((cneg = casereader_read (n_neg)))
761             {
762               struct ccase *nc = case_create (n_proto);
763
764               const double jneg = case_data_idx (cneg, VALUE)->f;
765
766               case_data_rw_idx (nc, VALUE)->f = jneg;
767               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
768
769               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
770
771               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
772               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
773
774               casewriter_write (w, nc);
775
776               case_unref (cneg);
777               if ( jneg > jpos)
778                 break;
779             }
780
781           case_data_rw_idx (pos_case, VALUE)->f = jpos;
782           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
783           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
784           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
785           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
786
787           casewriter_write (w, pos_case);
788         }
789
790 /* These aren't used anymore */
791 #undef N_EQ
792 #undef N_PRED
793
794       r = casewriter_make_reader (w);
795
796       /* Propagate the N_POS_GT values from the positive cases
797          to the negative ones */
798       {
799         double prev_pos_gt = rs[i].n1;
800         w = sort_create_writer (&down_ordering, n_proto);
801
802         for ( ; (c = casereader_read (r) ); case_unref (c))
803           {
804             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
805             struct ccase *nc = case_clone (c);
806
807             if ( n_pos_gt == SYSMIS)
808               {
809                 n_pos_gt = prev_pos_gt;
810                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
811               }
812             
813             casewriter_write (w, nc);
814             prev_pos_gt = n_pos_gt;
815           }
816
817         r = casewriter_make_reader (w);
818       }
819
820       /* Propagate the N_NEG_LT values from the negative cases
821          to the positive ones */
822       {
823         double prev_neg_lt = rs[i].n2;
824         w = sort_create_writer (&up_ordering, n_proto);
825
826         for ( ; (c = casereader_read (r) ); case_unref (c))
827           {
828             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
829             struct ccase *nc = case_clone (c);
830
831             if ( n_neg_lt == SYSMIS)
832               {
833                 n_neg_lt = prev_neg_lt;
834                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
835               }
836             
837             casewriter_write (w, nc);
838             prev_neg_lt = n_neg_lt;
839           }
840
841         r = casewriter_make_reader (w);
842       }
843
844       {
845         struct ccase *prev_case = NULL;
846         for ( ; (c = casereader_read (r) ); case_unref (c))
847           {
848             const struct ccase *next_case = casereader_peek (r, 0);
849
850             const double j = case_data_idx (c, VALUE)->f;
851             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
852             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
853             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
854             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
855
856             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
857               {
858                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
859                   {
860                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
861                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
862                   }
863
864                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
865                   {
866                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
867                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
868                   }
869               }
870
871             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
872               {
873                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
874
875                 rs[i].q1hat +=
876                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
877                 rs[i].q2hat +=
878                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
879
880               }
881
882             case_unref (prev_case);
883             prev_case = case_clone (c);
884           }
885
886         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
887         if ( roc->invert ) 
888           rs[i].auc = 1 - rs[i].auc;
889
890         if ( roc->bi_neg_exp )
891           {
892             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
893             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
894           }
895         else
896           {
897             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
898             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
899           }
900       }
901     }
902
903   casereader_destroy (positives);
904   casereader_destroy (negatives);
905
906   output_roc (rs, roc);
907
908   free (rs);
909 }
910
911 static void
912 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
913 {
914   int i;
915   const int n_fields = roc->print_se ? 5 : 1;
916   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
917   const int n_rows = 2 + roc->n_vars;
918   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
919
920   if ( roc->n_vars > 1)
921     tab_title (tbl, _("Area Under the Curve"));
922   else
923     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
924
925   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
926
927   tab_dim (tbl, tab_natural_dimensions, NULL);
928
929   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
930
931   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
932
933   tab_box (tbl,
934            TAL_2, TAL_2,
935            -1, TAL_1,
936            0, 0,
937            n_cols - 1,
938            n_rows - 1);
939
940   if ( roc->print_se )
941     {
942       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
943       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
944
945       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
946       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
947
948       tab_joint_text (tbl, n_cols - 2, 0, 4, 0,
949                       TAT_TITLE | TAB_CENTER | TAT_PRINTF,
950                       _("Asymp. %g%% Confidence Interval"), roc->ci);
951       tab_vline (tbl, 0, n_cols - 1, 0, 0);
952       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
953     }
954
955   if ( roc->n_vars > 1)
956     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
957
958   if ( roc->n_vars > 1)
959     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
960
961
962   for ( i = 0 ; i < roc->n_vars ; ++i )
963     {
964       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
965
966       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
967
968       if ( roc->print_se )
969         {
970           double se ;
971           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
972                                       (12 * rs[i].n1 * rs[i].n2));
973           double ci ;
974           double yy ;
975
976           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
977             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
978
979           se /= rs[i].n1 * rs[i].n2;
980
981           se = sqrt (se);
982
983           tab_double (tbl, n_cols - 4, 2 + i, 0,
984                       se,
985                       NULL);
986
987           ci = 1 - roc->ci / 100.0;
988           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
989
990           tab_double (tbl, n_cols - 2, 2 + i, 0,
991                       rs[i].auc - yy,
992                       NULL);
993
994           tab_double (tbl, n_cols - 1, 2 + i, 0,
995                       rs[i].auc + yy,
996                       NULL);
997
998           tab_double (tbl, n_cols - 3, 2 + i, 0,
999                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1000                       NULL);
1001         }
1002     }
1003
1004   tab_submit (tbl);
1005 }
1006
1007
1008 static void
1009 show_summary (const struct cmd_roc *roc)
1010 {
1011   const int n_cols = 3;
1012   const int n_rows = 4;
1013   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
1014
1015   tab_title (tbl, _("Case Summary"));
1016
1017   tab_headers (tbl, 1, 0, 2, 0);
1018
1019   tab_dim (tbl, tab_natural_dimensions, NULL);
1020
1021   tab_box (tbl,
1022            TAL_2, TAL_2,
1023            -1, -1,
1024            0, 0,
1025            n_cols - 1,
1026            n_rows - 1);
1027
1028   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1029   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1030
1031
1032   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1033   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1034
1035
1036   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1037   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1038   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1039
1040   tab_joint_text (tbl, 1, 0, 2, 0,
1041                   TAT_TITLE | TAB_CENTER,
1042                   _("Valid N (listwise)"));
1043
1044
1045   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1046   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1047
1048
1049   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1050   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1051
1052   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1053   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1054
1055   tab_submit (tbl);
1056 }
1057
1058
1059 static void
1060 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1061 {
1062   int x = 1;
1063   int i;
1064   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1065   int n_rows = 1;
1066   struct tab_table *tbl ;
1067
1068   for (i = 0; i < roc->n_vars; ++i)
1069     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1070
1071   tbl = tab_create (n_cols, n_rows, 0);
1072
1073   if ( roc->n_vars > 1)
1074     tab_title (tbl, _("Coordinates of the Curve"));
1075   else
1076     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1077
1078
1079   tab_headers (tbl, 1, 0, 1, 0);
1080
1081   tab_dim (tbl, tab_natural_dimensions, NULL);
1082
1083   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1084
1085   if ( roc->n_vars > 1)
1086     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1087
1088   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1089   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1090   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1091
1092   tab_box (tbl,
1093            TAL_2, TAL_2,
1094            -1, TAL_1,
1095            0, 0,
1096            n_cols - 1,
1097            n_rows - 1);
1098
1099   if ( roc->n_vars > 1)
1100     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1101
1102   for (i = 0; i < roc->n_vars; ++i)
1103     {
1104       struct ccase *cc;
1105       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1106
1107       if ( roc->n_vars > 1)
1108         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1109
1110       if ( i > 0)
1111         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1112
1113
1114       for (; (cc = casereader_read (r)) != NULL;
1115            case_unref (cc), x++)
1116         {
1117           const double se = case_data_idx (cc, TP)->f /
1118             (
1119              case_data_idx (cc, TP)->f
1120              +
1121              case_data_idx (cc, FN)->f
1122              );
1123
1124           const double sp = case_data_idx (cc, TN)->f /
1125             (
1126              case_data_idx (cc, TN)->f
1127              +
1128              case_data_idx (cc, FP)->f
1129              );
1130
1131           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f,
1132                       var_get_print_format (roc->vars[i]));
1133
1134           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1135           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1136         }
1137
1138       casereader_destroy (r);
1139     }
1140
1141   tab_submit (tbl);
1142 }
1143
1144
1145 static void
1146 draw_roc (struct roc_state *rs, const struct cmd_roc *roc)
1147 {
1148   int i;
1149
1150   struct chart *roc_chart = chart_create ();
1151
1152   chart_write_title (roc_chart, _("ROC Curve"));
1153   chart_write_xlabel (roc_chart, _("1 - Specificity"));
1154   chart_write_ylabel (roc_chart, _("Sensitivity"));
1155
1156   chart_write_xscale (roc_chart, 0, 1, 5);
1157   chart_write_yscale (roc_chart, 0, 1, 5);
1158
1159   if ( roc->reference )
1160     {
1161       chart_line (roc_chart, 1.0, 0,
1162                   0.0, 1.0,
1163                   CHART_DIM_X);
1164     }
1165
1166   for (i = 0; i < roc->n_vars; ++i)
1167     {
1168       struct ccase *cc;
1169       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1170
1171       chart_vector_start (roc_chart, var_get_name (roc->vars[i]));
1172       for (; (cc = casereader_read (r)) != NULL;
1173            case_unref (cc))
1174         {
1175           double se = case_data_idx (cc, TP)->f;
1176           double sp = case_data_idx (cc, TN)->f;
1177
1178           se /= case_data_idx (cc, FN)->f +
1179             case_data_idx (cc, TP)->f ;
1180
1181           sp /= case_data_idx (cc, TN)->f +
1182             case_data_idx (cc, FP)->f ;
1183
1184           chart_vector (roc_chart, 1 - sp, se);
1185         }
1186       chart_vector_end (roc_chart);
1187       casereader_destroy (r);
1188     }
1189
1190   chart_write_legend (roc_chart);
1191
1192   chart_submit (roc_chart);
1193 }
1194
1195
1196 static void
1197 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1198 {
1199   show_summary (roc);
1200
1201   if ( roc->curve )
1202     draw_roc (rs, roc);
1203
1204   show_auc (rs, roc);
1205
1206
1207   if ( roc->print_coords )
1208     show_coords (rs, roc);
1209 }
1210