Cope with new location of upstream files
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/chart-item.h"
37 #include "output/charts/roc-chart.h"
38 #include "output/tab.h"
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52   size_t state_var_width;
53
54   /* Plot the roc curve */
55   bool curve;
56   /* Plot the reference line */
57   bool reference;
58
59   double ci;
60
61   bool print_coords;
62   bool print_se;
63   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
64                       should be used */
65   enum mv_class exclude;
66
67   bool invert ; /* True iff a smaller test result variable indicates
68                    a positive result */
69
70   double pos;
71   double neg;
72   double pos_weighted;
73   double neg_weighted;
74 };
75
76 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
77
78 int
79 cmd_roc (struct lexer *lexer, struct dataset *ds)
80 {
81   struct cmd_roc roc ;
82   const struct dictionary *dict = dataset_dict (ds);
83
84   roc.vars = NULL;
85   roc.n_vars = 0;
86   roc.print_se = false;
87   roc.print_coords = false;
88   roc.exclude = MV_ANY;
89   roc.curve = true;
90   roc.reference = false;
91   roc.ci = 95;
92   roc.bi_neg_exp = false;
93   roc.invert = false;
94   roc.pos = roc.pos_weighted = 0;
95   roc.neg = roc.neg_weighted = 0;
96   roc.dict = dataset_dict (ds);
97   roc.state_var = NULL;
98   roc.state_var_width = -1;
99
100   lex_match (lexer, T_SLASH);
101   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103     goto error;
104
105   if ( ! lex_force_match (lexer, T_BY))
106     {
107       goto error;
108     }
109
110   roc.state_var = parse_variable (lexer, dict);
111   if (! roc.state_var)
112     {
113       goto error;
114     }
115
116   if ( !lex_force_match (lexer, T_LPAREN))
117     {
118       goto error;
119     }
120
121   roc.state_var_width = var_get_width (roc.state_var);
122   value_init (&roc.state_value, roc.state_var_width);
123   parse_value (lexer, &roc.state_value, roc.state_var);
124
125
126   if ( !lex_force_match (lexer, T_RPAREN))
127     {
128       goto error;
129     }
130
131   while (lex_token (lexer) != T_ENDCMD)
132     {
133       lex_match (lexer, T_SLASH);
134       if (lex_match_id (lexer, "MISSING"))
135         {
136           lex_match (lexer, T_EQUALS);
137           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
138             {
139               if (lex_match_id (lexer, "INCLUDE"))
140                 {
141                   roc.exclude = MV_SYSTEM;
142                 }
143               else if (lex_match_id (lexer, "EXCLUDE"))
144                 {
145                   roc.exclude = MV_ANY;
146                 }
147               else
148                 {
149                   lex_error (lexer, NULL);
150                   goto error;
151                 }
152             }
153         }
154       else if (lex_match_id (lexer, "PLOT"))
155         {
156           lex_match (lexer, T_EQUALS);
157           if (lex_match_id (lexer, "CURVE"))
158             {
159               roc.curve = true;
160               if (lex_match (lexer, T_LPAREN))
161                 {
162                   roc.reference = true;
163                   if (! lex_force_match_id (lexer, "REFERENCE"))
164                     goto error;
165                   if (! lex_force_match (lexer, T_RPAREN))
166                     goto error;
167                 }
168             }
169           else if (lex_match_id (lexer, "NONE"))
170             {
171               roc.curve = false;
172             }
173           else
174             {
175               lex_error (lexer, NULL);
176               goto error;
177             }
178         }
179       else if (lex_match_id (lexer, "PRINT"))
180         {
181           lex_match (lexer, T_EQUALS);
182           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
183             {
184               if (lex_match_id (lexer, "SE"))
185                 {
186                   roc.print_se = true;
187                 }
188               else if (lex_match_id (lexer, "COORDINATES"))
189                 {
190                   roc.print_coords = true;
191                 }
192               else
193                 {
194                   lex_error (lexer, NULL);
195                   goto error;
196                 }
197             }
198         }
199       else if (lex_match_id (lexer, "CRITERIA"))
200         {
201           lex_match (lexer, T_EQUALS);
202           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
203             {
204               if (lex_match_id (lexer, "CUTOFF"))
205                 {
206                   if (! lex_force_match (lexer, T_LPAREN))
207                     goto error;
208                   if (lex_match_id (lexer, "INCLUDE"))
209                     {
210                       roc.exclude = MV_SYSTEM;
211                     }
212                   else if (lex_match_id (lexer, "EXCLUDE"))
213                     {
214                       roc.exclude = MV_USER | MV_SYSTEM;
215                     }
216                   else
217                     {
218                       lex_error (lexer, NULL);
219                       goto error;
220                     }
221                   if (! lex_force_match (lexer, T_RPAREN))
222                     goto error;
223                 }
224               else if (lex_match_id (lexer, "TESTPOS"))
225                 {
226                   if (! lex_force_match (lexer, T_LPAREN))
227                     goto error;
228                   if (lex_match_id (lexer, "LARGE"))
229                     {
230                       roc.invert = false;
231                     }
232                   else if (lex_match_id (lexer, "SMALL"))
233                     {
234                       roc.invert = true;
235                     }
236                   else
237                     {
238                       lex_error (lexer, NULL);
239                       goto error;
240                     }
241                   if (! lex_force_match (lexer, T_RPAREN))
242                     goto error;
243                 }
244               else if (lex_match_id (lexer, "CI"))
245                 {
246                   if (!lex_force_match (lexer, T_LPAREN))
247                     goto error;
248                   if (! lex_force_num (lexer))
249                     goto error;
250                   roc.ci = lex_number (lexer);
251                   lex_get (lexer);
252                   if (!lex_force_match (lexer, T_RPAREN))
253                     goto error;
254                 }
255               else if (lex_match_id (lexer, "DISTRIBUTION"))
256                 {
257                   if (!lex_force_match (lexer, T_LPAREN))
258                     goto error;
259                   if (lex_match_id (lexer, "FREE"))
260                     {
261                       roc.bi_neg_exp = false;
262                     }
263                   else if (lex_match_id (lexer, "NEGEXPO"))
264                     {
265                       roc.bi_neg_exp = true;
266                     }
267                   else
268                     {
269                       lex_error (lexer, NULL);
270                       goto error;
271                     }
272                   if (!lex_force_match (lexer, T_RPAREN))
273                     goto error;
274                 }
275               else
276                 {
277                   lex_error (lexer, NULL);
278                   goto error;
279                 }
280             }
281         }
282       else
283         {
284           lex_error (lexer, NULL);
285           break;
286         }
287     }
288
289   if ( ! run_roc (ds, &roc))
290     goto error;
291
292   if ( roc.state_var)
293     value_destroy (&roc.state_value, roc.state_var_width);
294   free (roc.vars);
295   return CMD_SUCCESS;
296
297  error:
298   if ( roc.state_var)
299     value_destroy (&roc.state_value, roc.state_var_width);
300   free (roc.vars);
301   return CMD_FAILURE;
302 }
303
304
305
306
307 static void
308 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
309
310
311 static int
312 run_roc (struct dataset *ds, struct cmd_roc *roc)
313 {
314   struct dictionary *dict = dataset_dict (ds);
315   bool ok;
316   struct casereader *group;
317
318   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
319   while (casegrouper_get_next_group (grouper, &group))
320     {
321       do_roc (roc, group, dataset_dict (ds));
322     }
323   ok = casegrouper_destroy (grouper);
324   ok = proc_commit (ds) && ok;
325
326   return ok;
327 }
328
329 #if 0
330 static void
331 dump_casereader (struct casereader *reader)
332 {
333   struct ccase *c;
334   struct casereader *r = casereader_clone (reader);
335
336   for ( ; (c = casereader_read (r) ); case_unref (c))
337     {
338       int i;
339       for (i = 0 ; i < case_get_value_cnt (c); ++i)
340         {
341           printf ("%g ", case_data_idx (c, i)->f);
342         }
343       printf ("\n");
344     }
345
346   casereader_destroy (r);
347 }
348 #endif
349
350
351 /*
352    Return true iff the state variable indicates that C has positive actual state.
353
354    As a side effect, this function also accumulates the roc->{pos,neg} and
355    roc->{pos,neg}_weighted counts.
356  */
357 static bool
358 match_positives (const struct ccase *c, void *aux)
359 {
360   struct cmd_roc *roc = aux;
361   const struct variable *wv = dict_get_weight (roc->dict);
362   const double weight = wv ? case_data (c, wv)->f : 1.0;
363
364   const bool positive =
365   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
366     var_get_width (roc->state_var)));
367
368   if ( positive )
369     {
370       roc->pos++;
371       roc->pos_weighted += weight;
372     }
373   else
374     {
375       roc->neg++;
376       roc->neg_weighted += weight;
377     }
378
379   return positive;
380 }
381
382
383 #define VALUE  0
384 #define N_EQ   1
385 #define N_PRED 2
386
387 /* Some intermediate state for calculating the cutpoints and the
388    standard error values */
389 struct roc_state
390 {
391   double auc;  /* Area under the curve */
392
393   double n1;  /* total weight of positives */
394   double n2;  /* total weight of negatives */
395
396   /* intermediates for standard error */
397   double q1hat;
398   double q2hat;
399
400   /* intermediates for cutpoints */
401   struct casewriter *cutpoint_wtr;
402   struct casereader *cutpoint_rdr;
403   double prev_result;
404   double min;
405   double max;
406 };
407
408 /*
409    Return a new casereader based upon CUTPOINT_RDR.
410    The number of "positive" cases are placed into
411    the position TRUE_INDEX, and the number of "negative" cases
412    into FALSE_INDEX.
413    POS_COND and RESULT determine the semantics of what is
414    "positive".
415    WEIGHT is the value of a single count.
416  */
417 static struct casereader *
418 accumulate_counts (struct casereader *input,
419                    double result, double weight,
420                    bool (*pos_cond) (double, double),
421                    int true_index, int false_index)
422 {
423   const struct caseproto *proto = casereader_get_proto (input);
424   struct casewriter *w =
425     autopaging_writer_create (proto);
426   struct ccase *cpc;
427   double prev_cp = SYSMIS;
428
429   for ( ; (cpc = casereader_read (input) ); case_unref (cpc))
430     {
431       struct ccase *new_case;
432       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
433
434       assert (cp != SYSMIS);
435
436       /* We don't want duplicates here */
437       if ( cp == prev_cp )
438         continue;
439
440       new_case = case_clone (cpc);
441
442       if ( pos_cond (result, cp))
443         case_data_rw_idx (new_case, true_index)->f += weight;
444       else
445         case_data_rw_idx (new_case, false_index)->f += weight;
446
447       prev_cp = cp;
448
449       casewriter_write (w, new_case);
450     }
451   casereader_destroy (input);
452
453   return casewriter_make_reader (w);
454 }
455
456
457
458 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
459
460 /*
461   This function does 3 things:
462
463   1. Counts the number of cases which are equal to every other case in READER,
464   and those cases for which the relationship between it and every other case
465   satifies PRED (normally either > or <).  VAR is variable defining a case's value
466   for this purpose.
467
468   2. Counts the number of true and false cases in reader, and populates
469   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
470   which receive these values.  POS_COND is the condition defining true
471   and false.
472
473   3. CC is filled with the cumulative weight of all cases of READER.
474 */
475 static struct casereader *
476 process_group (const struct variable *var, struct casereader *reader,
477                bool (*pred) (double, double),
478                const struct dictionary *dict,
479                double *cc,
480                struct casereader **cutpoint_rdr,
481                bool (*pos_cond) (double, double),
482                int true_index,
483                int false_index)
484 {
485   const struct variable *w = dict_get_weight (dict);
486
487   struct casereader *r1 =
488     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
489
490   const int weight_idx  = w ? var_get_case_index (w) :
491     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
492
493   struct ccase *c1;
494
495   struct casereader *rclone = casereader_clone (r1);
496   struct casewriter *wtr;
497   struct caseproto *proto = caseproto_create ();
498
499   proto = caseproto_add_width (proto, 0);
500   proto = caseproto_add_width (proto, 0);
501   proto = caseproto_add_width (proto, 0);
502
503   wtr = autopaging_writer_create (proto);
504
505   *cc = 0;
506
507   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
508     {
509       struct ccase *new_case = case_create (proto);
510       struct ccase *c2;
511       struct casereader *r2 = casereader_clone (rclone);
512
513       const double weight1 = case_data_idx (c1, weight_idx)->f;
514       const double d1 = case_data (c1, var)->f;
515       double n_eq = 0.0;
516       double n_pred = 0.0;
517
518       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
519                                          pos_cond,
520                                          true_index, false_index);
521
522       *cc += weight1;
523
524       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
525         {
526           const double d2 = case_data (c2, var)->f;
527           const double weight2 = case_data_idx (c2, weight_idx)->f;
528
529           if ( d1 == d2 )
530             {
531               n_eq += weight2;
532               continue;
533             }
534           else  if ( pred (d2, d1))
535             {
536               n_pred += weight2;
537             }
538         }
539
540       case_data_rw_idx (new_case, VALUE)->f = d1;
541       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
542       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
543
544       casewriter_write (wtr, new_case);
545
546       casereader_destroy (r2);
547     }
548
549
550   casereader_destroy (r1);
551   casereader_destroy (rclone);
552
553   caseproto_unref (proto);
554
555   return casewriter_make_reader (wtr);
556 }
557
558 /* Some more indeces into case data */
559 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
560 #define N_POS_GT 2  /* number of postive cases with values greater than n */
561 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
562 #define N_NEG_LT 4  /* number of negative cases with values less than n */
563
564 static bool
565 gt (double d1, double d2)
566 {
567   return d1 > d2;
568 }
569
570
571 static bool
572 ge (double d1, double d2)
573 {
574   return d1 > d2;
575 }
576
577 static bool
578 lt (double d1, double d2)
579 {
580   return d1 < d2;
581 }
582
583
584 /*
585   Return a casereader with width 3,
586   populated with cases based upon READER.
587   The cases will have the values:
588   (N, number of cases equal to N, number of cases greater than N)
589   As a side effect, update RS->n1 with the number of positive cases.
590 */
591 static struct casereader *
592 process_positive_group (const struct variable *var, struct casereader *reader,
593                         const struct dictionary *dict,
594                         struct roc_state *rs)
595 {
596   return process_group (var, reader, gt, dict, &rs->n1,
597                         &rs->cutpoint_rdr,
598                         ge,
599                         ROC_TP, ROC_FN);
600 }
601
602 /*
603   Return a casereader with width 3,
604   populated with cases based upon READER.
605   The cases will have the values:
606   (N, number of cases equal to N, number of cases less than N)
607   As a side effect, update RS->n2 with the number of negative cases.
608 */
609 static struct casereader *
610 process_negative_group (const struct variable *var, struct casereader *reader,
611                         const struct dictionary *dict,
612                         struct roc_state *rs)
613 {
614   return process_group (var, reader, lt, dict, &rs->n2,
615                         &rs->cutpoint_rdr,
616                         lt,
617                         ROC_TN, ROC_FP);
618 }
619
620
621
622
623 static void
624 append_cutpoint (struct casewriter *writer, double cutpoint)
625 {
626   struct ccase *cc = case_create (casewriter_get_proto (writer));
627
628   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
629   case_data_rw_idx (cc, ROC_TP)->f = 0;
630   case_data_rw_idx (cc, ROC_FN)->f = 0;
631   case_data_rw_idx (cc, ROC_TN)->f = 0;
632   case_data_rw_idx (cc, ROC_FP)->f = 0;
633
634   casewriter_write (writer, cc);
635 }
636
637
638 /*
639    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
640    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
641    reader will be populated with its final number of cases.
642    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
643    value.  The other entries will be initialised to zero.
644 */
645 static void
646 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
647 {
648   int i;
649   struct casereader *r = casereader_clone (input);
650   struct ccase *c;
651
652   {
653     struct caseproto *proto = caseproto_create ();
654     struct subcase ordering;
655     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
656
657     proto = caseproto_add_width (proto, 0); /* cutpoint */
658     proto = caseproto_add_width (proto, 0); /* ROC_TP */
659     proto = caseproto_add_width (proto, 0); /* ROC_FN */
660     proto = caseproto_add_width (proto, 0); /* ROC_TN */
661     proto = caseproto_add_width (proto, 0); /* ROC_FP */
662
663     for (i = 0 ; i < roc->n_vars; ++i)
664       {
665         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
666         rs[i].prev_result = SYSMIS;
667         rs[i].max = -DBL_MAX;
668         rs[i].min = DBL_MAX;
669       }
670
671     caseproto_unref (proto);
672     subcase_destroy (&ordering);
673   }
674
675   for (; (c = casereader_read (r)) != NULL; case_unref (c))
676     {
677       for (i = 0 ; i < roc->n_vars; ++i)
678         {
679           const union value *v = case_data (c, roc->vars[i]);
680           const double result = v->f;
681
682           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
683             continue;
684
685           minimize (&rs[i].min, result);
686           maximize (&rs[i].max, result);
687
688           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
689             {
690               const double mean = (result + rs[i].prev_result ) / 2.0;
691               append_cutpoint (rs[i].cutpoint_wtr, mean);
692             }
693
694           rs[i].prev_result = result;
695         }
696     }
697   casereader_destroy (r);
698
699
700   /* Append the min and max cutpoints */
701   for (i = 0 ; i < roc->n_vars; ++i)
702     {
703       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
704       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
705
706       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
707     }
708 }
709
710 static void
711 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
712 {
713   int i;
714
715   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
716
717   struct casereader *negatives = NULL;
718   struct casereader *positives = NULL;
719
720   struct caseproto *n_proto = NULL;
721
722   struct subcase up_ordering;
723   struct subcase down_ordering;
724
725   struct casewriter *neg_wtr = NULL;
726
727   struct casereader *input = casereader_create_filter_missing (reader,
728                                                                roc->vars, roc->n_vars,
729                                                                roc->exclude,
730                                                                NULL,
731                                                                NULL);
732
733   input = casereader_create_filter_missing (input,
734                                             &roc->state_var, 1,
735                                             roc->exclude,
736                                             NULL,
737                                             NULL);
738
739   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
740
741   prepare_cutpoints (roc, rs, input);
742
743
744   /* Separate the positive actual state cases from the negative ones */
745   positives =
746     casereader_create_filter_func (input,
747                                    match_positives,
748                                    NULL,
749                                    roc,
750                                    neg_wtr);
751
752   n_proto = caseproto_create ();
753
754   n_proto = caseproto_add_width (n_proto, 0);
755   n_proto = caseproto_add_width (n_proto, 0);
756   n_proto = caseproto_add_width (n_proto, 0);
757   n_proto = caseproto_add_width (n_proto, 0);
758   n_proto = caseproto_add_width (n_proto, 0);
759
760   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
761   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
762
763   for (i = 0 ; i < roc->n_vars; ++i)
764     {
765       struct casewriter *w = NULL;
766       struct casereader *r = NULL;
767
768       struct ccase *c;
769
770       struct ccase *cpos;
771       struct casereader *n_neg_reader ;
772       const struct variable *var = roc->vars[i];
773
774       struct casereader *neg ;
775       struct casereader *pos = casereader_clone (positives);
776
777       struct casereader *n_pos_reader =
778         process_positive_group (var, pos, dict, &rs[i]);
779
780       if ( negatives == NULL)
781         {
782           negatives = casewriter_make_reader (neg_wtr);
783         }
784
785       neg = casereader_clone (negatives);
786
787       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
788
789       /* Merge the n_pos and n_neg casereaders */
790       w = sort_create_writer (&up_ordering, n_proto);
791       for ( ; (cpos = casereader_read (n_pos_reader) ); case_unref (cpos))
792         {
793           struct ccase *pos_case = case_create (n_proto);
794           struct ccase *cneg;
795           const double jpos = case_data_idx (cpos, VALUE)->f;
796
797           while ((cneg = casereader_read (n_neg_reader)))
798             {
799               struct ccase *nc = case_create (n_proto);
800
801               const double jneg = case_data_idx (cneg, VALUE)->f;
802
803               case_data_rw_idx (nc, VALUE)->f = jneg;
804               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
805
806               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
807
808               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
809               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
810
811               casewriter_write (w, nc);
812
813               case_unref (cneg);
814               if ( jneg > jpos)
815                 break;
816             }
817
818           case_data_rw_idx (pos_case, VALUE)->f = jpos;
819           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
820           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
821           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
822           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
823
824           casewriter_write (w, pos_case);
825         }
826
827       casereader_destroy (n_pos_reader);
828       casereader_destroy (n_neg_reader);
829
830 /* These aren't used anymore */
831 #undef N_EQ
832 #undef N_PRED
833
834       r = casewriter_make_reader (w);
835
836       /* Propagate the N_POS_GT values from the positive cases
837          to the negative ones */
838       {
839         double prev_pos_gt = rs[i].n1;
840         w = sort_create_writer (&down_ordering, n_proto);
841
842         for ( ; (c = casereader_read (r) ); case_unref (c))
843           {
844             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
845             struct ccase *nc = case_clone (c);
846
847             if ( n_pos_gt == SYSMIS)
848               {
849                 n_pos_gt = prev_pos_gt;
850                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
851               }
852
853             casewriter_write (w, nc);
854             prev_pos_gt = n_pos_gt;
855           }
856
857         casereader_destroy (r);
858         r = casewriter_make_reader (w);
859       }
860
861       /* Propagate the N_NEG_LT values from the negative cases
862          to the positive ones */
863       {
864         double prev_neg_lt = rs[i].n2;
865         w = sort_create_writer (&up_ordering, n_proto);
866
867         for ( ; (c = casereader_read (r) ); case_unref (c))
868           {
869             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
870             struct ccase *nc = case_clone (c);
871
872             if ( n_neg_lt == SYSMIS)
873               {
874                 n_neg_lt = prev_neg_lt;
875                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
876               }
877
878             casewriter_write (w, nc);
879             prev_neg_lt = n_neg_lt;
880           }
881
882         casereader_destroy (r);
883         r = casewriter_make_reader (w);
884       }
885
886       {
887         struct ccase *prev_case = NULL;
888         for ( ; (c = casereader_read (r) ); case_unref (c))
889           {
890             struct ccase *next_case = casereader_peek (r, 0);
891
892             const double j = case_data_idx (c, VALUE)->f;
893             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
894             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
895             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
896             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
897
898             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
899               {
900                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
901                   {
902                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
903                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
904                   }
905
906                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
907                   {
908                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
909                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
910                   }
911               }
912
913             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
914               {
915                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
916
917                 rs[i].q1hat +=
918                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
919                 rs[i].q2hat +=
920                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
921
922               }
923
924             case_unref (next_case);
925             case_unref (prev_case);
926             prev_case = case_clone (c);
927           }
928         casereader_destroy (r);
929         case_unref (prev_case);
930
931         rs[i].auc /=  rs[i].n1 * rs[i].n2;
932         if ( roc->invert )
933           rs[i].auc = 1 - rs[i].auc;
934
935         if ( roc->bi_neg_exp )
936           {
937             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
938             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
939           }
940         else
941           {
942             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
943             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
944           }
945       }
946     }
947
948   casereader_destroy (positives);
949   casereader_destroy (negatives);
950
951   caseproto_unref (n_proto);
952   subcase_destroy (&up_ordering);
953   subcase_destroy (&down_ordering);
954
955   output_roc (rs, roc);
956
957   for (i = 0 ; i < roc->n_vars; ++i)
958     casereader_destroy (rs[i].cutpoint_rdr);
959
960   free (rs);
961 }
962
963 static void
964 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
965 {
966   int i;
967   const int n_fields = roc->print_se ? 5 : 1;
968   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
969   const int n_rows = 2 + roc->n_vars;
970   struct tab_table *tbl = tab_create (n_cols, n_rows);
971
972   if ( roc->n_vars > 1)
973     tab_title (tbl, _("Area Under the Curve"));
974   else
975     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
976
977   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
978
979
980   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
981
982   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
983
984   tab_box (tbl,
985            TAL_2, TAL_2,
986            -1, TAL_1,
987            0, 0,
988            n_cols - 1,
989            n_rows - 1);
990
991   if ( roc->print_se )
992     {
993       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
994       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
995
996       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
997       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
998
999       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
1000                              TAT_TITLE | TAB_CENTER,
1001                              _("Asymp. %g%% Confidence Interval"), roc->ci);
1002       tab_vline (tbl, 0, n_cols - 1, 0, 0);
1003       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
1004     }
1005
1006   if ( roc->n_vars > 1)
1007     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
1008
1009   if ( roc->n_vars > 1)
1010     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1011
1012
1013   for ( i = 0 ; i < roc->n_vars ; ++i )
1014     {
1015       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
1016
1017       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL, RC_OTHER);
1018
1019       if ( roc->print_se )
1020         {
1021           double se ;
1022           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1023                                       (12 * rs[i].n1 * rs[i].n2));
1024           double ci ;
1025           double yy ;
1026
1027           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
1028             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
1029
1030           se /= rs[i].n1 * rs[i].n2;
1031
1032           se = sqrt (se);
1033
1034           tab_double (tbl, n_cols - 4, 2 + i, 0,
1035                       se,
1036                       NULL, RC_OTHER);
1037
1038           ci = 1 - roc->ci / 100.0;
1039           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1040
1041           tab_double (tbl, n_cols - 2, 2 + i, 0,
1042                       rs[i].auc - yy,
1043                       NULL, RC_OTHER);
1044
1045           tab_double (tbl, n_cols - 1, 2 + i, 0,
1046                       rs[i].auc + yy,
1047                       NULL, RC_OTHER);
1048
1049           tab_double (tbl, n_cols - 3, 2 + i, 0,
1050                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1051                       NULL, RC_PVALUE);
1052         }
1053     }
1054
1055   tab_submit (tbl);
1056 }
1057
1058
1059 static void
1060 show_summary (const struct cmd_roc *roc)
1061 {
1062   const int n_cols = 3;
1063   const int n_rows = 4;
1064   struct tab_table *tbl = tab_create (n_cols, n_rows);
1065
1066   tab_title (tbl, _("Case Summary"));
1067
1068   tab_headers (tbl, 1, 0, 2, 0);
1069
1070   tab_box (tbl,
1071            TAL_2, TAL_2,
1072            -1, -1,
1073            0, 0,
1074            n_cols - 1,
1075            n_rows - 1);
1076
1077   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1078   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1079
1080
1081   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1082   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1083
1084
1085   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1086   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1087   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1088
1089   tab_joint_text (tbl, 1, 0, 2, 0,
1090                   TAT_TITLE | TAB_CENTER,
1091                   _("Valid N (listwise)"));
1092
1093
1094   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1095   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1096
1097
1098   tab_double (tbl, 1, 2, 0, roc->pos, NULL, RC_INTEGER);
1099   tab_double (tbl, 1, 3, 0, roc->neg, NULL, RC_INTEGER);
1100
1101   tab_double (tbl, 2, 2, 0, roc->pos_weighted, NULL, RC_OTHER);
1102   tab_double (tbl, 2, 3, 0, roc->neg_weighted, NULL, RC_OTHER);
1103
1104   tab_submit (tbl);
1105 }
1106
1107
1108 static void
1109 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1110 {
1111   int x = 1;
1112   int i;
1113   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1114   int n_rows = 1;
1115   struct tab_table *tbl ;
1116
1117   for (i = 0; i < roc->n_vars; ++i)
1118     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1119
1120   tbl = tab_create (n_cols, n_rows);
1121
1122   if ( roc->n_vars > 1)
1123     tab_title (tbl, _("Coordinates of the Curve"));
1124   else
1125     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1126
1127
1128   tab_headers (tbl, 1, 0, 1, 0);
1129
1130   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1131
1132   if ( roc->n_vars > 1)
1133     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1134
1135   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1136   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1137   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1138
1139   tab_box (tbl,
1140            TAL_2, TAL_2,
1141            -1, TAL_1,
1142            0, 0,
1143            n_cols - 1,
1144            n_rows - 1);
1145
1146   if ( roc->n_vars > 1)
1147     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1148
1149   for (i = 0; i < roc->n_vars; ++i)
1150     {
1151       struct ccase *cc;
1152       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1153
1154       if ( roc->n_vars > 1)
1155         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1156
1157       if ( i > 0)
1158         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1159
1160
1161       for (; (cc = casereader_read (r)) != NULL;
1162            case_unref (cc), x++)
1163         {
1164           const double se = case_data_idx (cc, ROC_TP)->f /
1165             (
1166              case_data_idx (cc, ROC_TP)->f
1167              +
1168              case_data_idx (cc, ROC_FN)->f
1169              );
1170
1171           const double sp = case_data_idx (cc, ROC_TN)->f /
1172             (
1173              case_data_idx (cc, ROC_TN)->f
1174              +
1175              case_data_idx (cc, ROC_FP)->f
1176              );
1177
1178           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1179                       var_get_print_format (roc->vars[i]), RC_OTHER);
1180
1181           tab_double (tbl, n_cols - 2, x, 0, se, NULL, RC_OTHER);
1182           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL, RC_OTHER);
1183         }
1184
1185       casereader_destroy (r);
1186     }
1187
1188   tab_submit (tbl);
1189 }
1190
1191
1192 static void
1193 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1194 {
1195   show_summary (roc);
1196
1197   if ( roc->curve )
1198     {
1199       struct roc_chart *rc;
1200       size_t i;
1201
1202       rc = roc_chart_create (roc->reference);
1203       for (i = 0; i < roc->n_vars; i++)
1204         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1205                            rs[i].cutpoint_rdr);
1206       roc_chart_submit (rc);
1207     }
1208
1209   show_auc (rs, roc);
1210
1211   if ( roc->print_coords )
1212     show_coords (rs, roc);
1213 }
1214