session: Fix two memory leaks.
[pspp] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/roc.h"
20
21 #include <gsl/gsl_cdf.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/chart-item.h"
37 #include "output/charts/roc-chart.h"
38 #include "output/tab.h"
39
40 #include "gettext.h"
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
43
44 struct cmd_roc
45 {
46   size_t n_vars;
47   const struct variable **vars;
48   const struct dictionary *dict;
49
50   const struct variable *state_var;
51   union value state_value;
52   size_t state_var_width;
53
54   /* Plot the roc curve */
55   bool curve;
56   /* Plot the reference line */
57   bool reference;
58
59   double ci;
60
61   bool print_coords;
62   bool print_se;
63   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
64                       should be used */
65   enum mv_class exclude;
66
67   bool invert ; /* True iff a smaller test result variable indicates
68                    a positive result */
69
70   double pos;
71   double neg;
72   double pos_weighted;
73   double neg_weighted;
74 };
75
76 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
77
78 int
79 cmd_roc (struct lexer *lexer, struct dataset *ds)
80 {
81   struct cmd_roc roc ;
82   const struct dictionary *dict = dataset_dict (ds);
83
84   roc.vars = NULL;
85   roc.n_vars = 0;
86   roc.print_se = false;
87   roc.print_coords = false;
88   roc.exclude = MV_ANY;
89   roc.curve = true;
90   roc.reference = false;
91   roc.ci = 95;
92   roc.bi_neg_exp = false;
93   roc.invert = false;
94   roc.pos = roc.pos_weighted = 0;
95   roc.neg = roc.neg_weighted = 0;
96   roc.dict = dataset_dict (ds);
97   roc.state_var = NULL;
98   roc.state_var_width = -1;
99
100   lex_match (lexer, T_SLASH);
101   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103     goto error;
104
105   if ( ! lex_force_match (lexer, T_BY))
106     {
107       goto error;
108     }
109
110   roc.state_var = parse_variable (lexer, dict);
111
112   if ( !lex_force_match (lexer, T_LPAREN))
113     {
114       goto error;
115     }
116
117   roc.state_var_width = var_get_width (roc.state_var);
118   value_init (&roc.state_value, roc.state_var_width);
119   parse_value (lexer, &roc.state_value, roc.state_var);
120
121
122   if ( !lex_force_match (lexer, T_RPAREN))
123     {
124       goto error;
125     }
126
127   while (lex_token (lexer) != T_ENDCMD)
128     {
129       lex_match (lexer, T_SLASH);
130       if (lex_match_id (lexer, "MISSING"))
131         {
132           lex_match (lexer, T_EQUALS);
133           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
134             {
135               if (lex_match_id (lexer, "INCLUDE"))
136                 {
137                   roc.exclude = MV_SYSTEM;
138                 }
139               else if (lex_match_id (lexer, "EXCLUDE"))
140                 {
141                   roc.exclude = MV_ANY;
142                 }
143               else
144                 {
145                   lex_error (lexer, NULL);
146                   goto error;
147                 }
148             }
149         }
150       else if (lex_match_id (lexer, "PLOT"))
151         {
152           lex_match (lexer, T_EQUALS);
153           if (lex_match_id (lexer, "CURVE"))
154             {
155               roc.curve = true;
156               if (lex_match (lexer, T_LPAREN))
157                 {
158                   roc.reference = true;
159                   lex_force_match_id (lexer, "REFERENCE");
160                   lex_force_match (lexer, T_RPAREN);
161                 }
162             }
163           else if (lex_match_id (lexer, "NONE"))
164             {
165               roc.curve = false;
166             }
167           else
168             {
169               lex_error (lexer, NULL);
170               goto error;
171             }
172         }
173       else if (lex_match_id (lexer, "PRINT"))
174         {
175           lex_match (lexer, T_EQUALS);
176           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
177             {
178               if (lex_match_id (lexer, "SE"))
179                 {
180                   roc.print_se = true;
181                 }
182               else if (lex_match_id (lexer, "COORDINATES"))
183                 {
184                   roc.print_coords = true;
185                 }
186               else
187                 {
188                   lex_error (lexer, NULL);
189                   goto error;
190                 }
191             }
192         }
193       else if (lex_match_id (lexer, "CRITERIA"))
194         {
195           lex_match (lexer, T_EQUALS);
196           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
197             {
198               if (lex_match_id (lexer, "CUTOFF"))
199                 {
200                   lex_force_match (lexer, T_LPAREN);
201                   if (lex_match_id (lexer, "INCLUDE"))
202                     {
203                       roc.exclude = MV_SYSTEM;
204                     }
205                   else if (lex_match_id (lexer, "EXCLUDE"))
206                     {
207                       roc.exclude = MV_USER | MV_SYSTEM;
208                     }
209                   else
210                     {
211                       lex_error (lexer, NULL);
212                       goto error;
213                     }
214                   lex_force_match (lexer, T_RPAREN);
215                 }
216               else if (lex_match_id (lexer, "TESTPOS"))
217                 {
218                   lex_force_match (lexer, T_LPAREN);
219                   if (lex_match_id (lexer, "LARGE"))
220                     {
221                       roc.invert = false;
222                     }
223                   else if (lex_match_id (lexer, "SMALL"))
224                     {
225                       roc.invert = true;
226                     }
227                   else
228                     {
229                       lex_error (lexer, NULL);
230                       goto error;
231                     }
232                   lex_force_match (lexer, T_RPAREN);
233                 }
234               else if (lex_match_id (lexer, "CI"))
235                 {
236                   lex_force_match (lexer, T_LPAREN);
237                   lex_force_num (lexer);
238                   roc.ci = lex_number (lexer);
239                   lex_get (lexer);
240                   lex_force_match (lexer, T_RPAREN);
241                 }
242               else if (lex_match_id (lexer, "DISTRIBUTION"))
243                 {
244                   lex_force_match (lexer, T_LPAREN);
245                   if (lex_match_id (lexer, "FREE"))
246                     {
247                       roc.bi_neg_exp = false;
248                     }
249                   else if (lex_match_id (lexer, "NEGEXPO"))
250                     {
251                       roc.bi_neg_exp = true;
252                     }
253                   else
254                     {
255                       lex_error (lexer, NULL);
256                       goto error;
257                     }
258                   lex_force_match (lexer, T_RPAREN);
259                 }
260               else
261                 {
262                   lex_error (lexer, NULL);
263                   goto error;
264                 }
265             }
266         }
267       else
268         {
269           lex_error (lexer, NULL);
270           break;
271         }
272     }
273
274   if ( ! run_roc (ds, &roc)) 
275     goto error;
276
277   if ( roc.state_var)
278     value_destroy (&roc.state_value, roc.state_var_width);
279   free (roc.vars);
280   return CMD_SUCCESS;
281
282  error:
283   if ( roc.state_var)
284     value_destroy (&roc.state_value, roc.state_var_width);
285   free (roc.vars);
286   return CMD_FAILURE;
287 }
288
289
290
291
292 static void
293 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
294
295
296 static int
297 run_roc (struct dataset *ds, struct cmd_roc *roc)
298 {
299   struct dictionary *dict = dataset_dict (ds);
300   bool ok;
301   struct casereader *group;
302
303   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
304   while (casegrouper_get_next_group (grouper, &group))
305     {
306       do_roc (roc, group, dataset_dict (ds));
307     }
308   ok = casegrouper_destroy (grouper);
309   ok = proc_commit (ds) && ok;
310
311   return ok;
312 }
313
314 #if 0
315 static void
316 dump_casereader (struct casereader *reader)
317 {
318   struct ccase *c;
319   struct casereader *r = casereader_clone (reader);
320
321   for ( ; (c = casereader_read (r) ); case_unref (c))
322     {
323       int i;
324       for (i = 0 ; i < case_get_value_cnt (c); ++i)
325         {
326           printf ("%g ", case_data_idx (c, i)->f);
327         }
328       printf ("\n");
329     }
330
331   casereader_destroy (r);
332 }
333 #endif
334
335
336 /* 
337    Return true iff the state variable indicates that C has positive actual state.
338
339    As a side effect, this function also accumulates the roc->{pos,neg} and 
340    roc->{pos,neg}_weighted counts.
341  */
342 static bool
343 match_positives (const struct ccase *c, void *aux)
344 {
345   struct cmd_roc *roc = aux;
346   const struct variable *wv = dict_get_weight (roc->dict);
347   const double weight = wv ? case_data (c, wv)->f : 1.0;
348
349   const bool positive =
350   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
351     var_get_width (roc->state_var)));
352
353   if ( positive )
354     {
355       roc->pos++;
356       roc->pos_weighted += weight;
357     }
358   else
359     {
360       roc->neg++;
361       roc->neg_weighted += weight;
362     }
363
364   return positive;
365 }
366
367
368 #define VALUE  0
369 #define N_EQ   1
370 #define N_PRED 2
371
372 /* Some intermediate state for calculating the cutpoints and the 
373    standard error values */
374 struct roc_state
375 {
376   double auc;  /* Area under the curve */
377
378   double n1;  /* total weight of positives */
379   double n2;  /* total weight of negatives */
380
381   /* intermediates for standard error */
382   double q1hat; 
383   double q2hat;
384
385   /* intermediates for cutpoints */
386   struct casewriter *cutpoint_wtr;
387   struct casereader *cutpoint_rdr;
388   double prev_result;
389   double min;
390   double max;
391 };
392
393 /* 
394    Return a new casereader based upon CUTPOINT_RDR.
395    The number of "positive" cases are placed into
396    the position TRUE_INDEX, and the number of "negative" cases
397    into FALSE_INDEX.
398    POS_COND and RESULT determine the semantics of what is 
399    "positive".
400    WEIGHT is the value of a single count.
401  */
402 static struct casereader *
403 accumulate_counts (struct casereader *input,
404                    double result, double weight, 
405                    bool (*pos_cond) (double, double),
406                    int true_index, int false_index)
407 {
408   const struct caseproto *proto = casereader_get_proto (input);
409   struct casewriter *w =
410     autopaging_writer_create (proto);
411   struct ccase *cpc;
412   double prev_cp = SYSMIS;
413
414   for ( ; (cpc = casereader_read (input) ); case_unref (cpc))
415     {
416       struct ccase *new_case;
417       const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
418
419       assert (cp != SYSMIS);
420
421       /* We don't want duplicates here */
422       if ( cp == prev_cp )
423         continue;
424
425       new_case = case_clone (cpc);
426
427       if ( pos_cond (result, cp))
428         case_data_rw_idx (new_case, true_index)->f += weight;
429       else
430         case_data_rw_idx (new_case, false_index)->f += weight;
431
432       prev_cp = cp;
433
434       casewriter_write (w, new_case);
435     }
436   casereader_destroy (input);
437
438   return casewriter_make_reader (w);
439 }
440
441
442
443 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
444
445 /*
446   This function does 3 things:
447
448   1. Counts the number of cases which are equal to every other case in READER,
449   and those cases for which the relationship between it and every other case
450   satifies PRED (normally either > or <).  VAR is variable defining a case's value
451   for this purpose.
452
453   2. Counts the number of true and false cases in reader, and populates
454   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
455   which receive these values.  POS_COND is the condition defining true
456   and false.
457   
458   3. CC is filled with the cumulative weight of all cases of READER.
459 */
460 static struct casereader *
461 process_group (const struct variable *var, struct casereader *reader,
462                bool (*pred) (double, double),
463                const struct dictionary *dict,
464                double *cc,
465                struct casereader **cutpoint_rdr, 
466                bool (*pos_cond) (double, double),
467                int true_index,
468                int false_index)
469 {
470   const struct variable *w = dict_get_weight (dict);
471
472   struct casereader *r1 =
473     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
474
475   const int weight_idx  = w ? var_get_case_index (w) :
476     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
477   
478   struct ccase *c1;
479
480   struct casereader *rclone = casereader_clone (r1);
481   struct casewriter *wtr;
482   struct caseproto *proto = caseproto_create ();
483
484   proto = caseproto_add_width (proto, 0);
485   proto = caseproto_add_width (proto, 0);
486   proto = caseproto_add_width (proto, 0);
487
488   wtr = autopaging_writer_create (proto);  
489
490   *cc = 0;
491
492   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
493     {
494       struct ccase *new_case = case_create (proto);
495       struct ccase *c2;
496       struct casereader *r2 = casereader_clone (rclone);
497
498       const double weight1 = case_data_idx (c1, weight_idx)->f;
499       const double d1 = case_data (c1, var)->f;
500       double n_eq = 0.0;
501       double n_pred = 0.0;
502
503       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
504                                          pos_cond,
505                                          true_index, false_index);
506
507       *cc += weight1;
508
509       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
510         {
511           const double d2 = case_data (c2, var)->f;
512           const double weight2 = case_data_idx (c2, weight_idx)->f;
513
514           if ( d1 == d2 )
515             {
516               n_eq += weight2;
517               continue;
518             }
519           else  if ( pred (d2, d1))
520             {
521               n_pred += weight2;
522             }
523         }
524
525       case_data_rw_idx (new_case, VALUE)->f = d1;
526       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
527       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
528
529       casewriter_write (wtr, new_case);
530
531       casereader_destroy (r2);
532     }
533
534   
535   casereader_destroy (r1);
536   casereader_destroy (rclone);
537
538   caseproto_unref (proto);
539
540   return casewriter_make_reader (wtr);
541 }
542
543 /* Some more indeces into case data */
544 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
545 #define N_POS_GT 2  /* number of postive cases with values greater than n */
546 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
547 #define N_NEG_LT 4  /* number of negative cases with values less than n */
548
549 static bool
550 gt (double d1, double d2)
551 {
552   return d1 > d2;
553 }
554
555
556 static bool
557 ge (double d1, double d2)
558 {
559   return d1 > d2;
560 }
561
562 static bool
563 lt (double d1, double d2)
564 {
565   return d1 < d2;
566 }
567
568
569 /*
570   Return a casereader with width 3,
571   populated with cases based upon READER.
572   The cases will have the values:
573   (N, number of cases equal to N, number of cases greater than N)
574   As a side effect, update RS->n1 with the number of positive cases.
575 */
576 static struct casereader *
577 process_positive_group (const struct variable *var, struct casereader *reader,
578                         const struct dictionary *dict,
579                         struct roc_state *rs)
580 {
581   return process_group (var, reader, gt, dict, &rs->n1,
582                         &rs->cutpoint_rdr,
583                         ge,
584                         ROC_TP, ROC_FN);
585 }
586
587 /*
588   Return a casereader with width 3,
589   populated with cases based upon READER.
590   The cases will have the values:
591   (N, number of cases equal to N, number of cases less than N)
592   As a side effect, update RS->n2 with the number of negative cases.
593 */
594 static struct casereader *
595 process_negative_group (const struct variable *var, struct casereader *reader,
596                         const struct dictionary *dict,
597                         struct roc_state *rs)
598 {
599   return process_group (var, reader, lt, dict, &rs->n2,
600                         &rs->cutpoint_rdr,
601                         lt,
602                         ROC_TN, ROC_FP);
603 }
604
605
606
607
608 static void
609 append_cutpoint (struct casewriter *writer, double cutpoint)
610 {
611   struct ccase *cc = case_create (casewriter_get_proto (writer));
612
613   case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
614   case_data_rw_idx (cc, ROC_TP)->f = 0;
615   case_data_rw_idx (cc, ROC_FN)->f = 0;
616   case_data_rw_idx (cc, ROC_TN)->f = 0;
617   case_data_rw_idx (cc, ROC_FP)->f = 0;
618
619   casewriter_write (writer, cc);
620 }
621
622
623 /* 
624    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
625    be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
626    reader will be populated with its final number of cases.
627    However on exit from this function, only ROC_CUTPOINT entries will be set to their final
628    value.  The other entries will be initialised to zero.
629 */
630 static void
631 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
632 {
633   int i;
634   struct casereader *r = casereader_clone (input);
635   struct ccase *c;
636
637   {
638     struct caseproto *proto = caseproto_create ();
639     struct subcase ordering;
640     subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
641
642     proto = caseproto_add_width (proto, 0); /* cutpoint */
643     proto = caseproto_add_width (proto, 0); /* ROC_TP */
644     proto = caseproto_add_width (proto, 0); /* ROC_FN */
645     proto = caseproto_add_width (proto, 0); /* ROC_TN */
646     proto = caseproto_add_width (proto, 0); /* ROC_FP */
647
648     for (i = 0 ; i < roc->n_vars; ++i)
649       {
650         rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
651         rs[i].prev_result = SYSMIS;
652         rs[i].max = -DBL_MAX;
653         rs[i].min = DBL_MAX;
654       }
655
656     caseproto_unref (proto);
657     subcase_destroy (&ordering);
658   }
659
660   for (; (c = casereader_read (r)) != NULL; case_unref (c))
661     {
662       for (i = 0 ; i < roc->n_vars; ++i)
663         {
664           const union value *v = case_data (c, roc->vars[i]); 
665           const double result = v->f;
666
667           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
668             continue;
669
670           minimize (&rs[i].min, result);
671           maximize (&rs[i].max, result);
672
673           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
674             {
675               const double mean = (result + rs[i].prev_result ) / 2.0;
676               append_cutpoint (rs[i].cutpoint_wtr, mean);
677             }
678
679           rs[i].prev_result = result;
680         }
681     }
682   casereader_destroy (r);
683
684
685   /* Append the min and max cutpoints */
686   for (i = 0 ; i < roc->n_vars; ++i)
687     {
688       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
689       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
690
691       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
692     }
693 }
694
695 static void
696 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
697 {
698   int i;
699
700   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
701
702   struct casereader *negatives = NULL;
703   struct casereader *positives = NULL;
704
705   struct caseproto *n_proto = NULL;
706
707   struct subcase up_ordering;
708   struct subcase down_ordering;
709
710   struct casewriter *neg_wtr = NULL;
711
712   struct casereader *input = casereader_create_filter_missing (reader,
713                                                                roc->vars, roc->n_vars,
714                                                                roc->exclude,
715                                                                NULL,
716                                                                NULL);
717
718   input = casereader_create_filter_missing (input,
719                                             &roc->state_var, 1,
720                                             roc->exclude,
721                                             NULL,
722                                             NULL);
723
724   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
725
726   prepare_cutpoints (roc, rs, input);
727
728
729   /* Separate the positive actual state cases from the negative ones */
730   positives = 
731     casereader_create_filter_func (input,
732                                    match_positives,
733                                    NULL,
734                                    roc,
735                                    neg_wtr);
736
737   n_proto = caseproto_create ();
738       
739   n_proto = caseproto_add_width (n_proto, 0);
740   n_proto = caseproto_add_width (n_proto, 0);
741   n_proto = caseproto_add_width (n_proto, 0);
742   n_proto = caseproto_add_width (n_proto, 0);
743   n_proto = caseproto_add_width (n_proto, 0);
744
745   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
746   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
747
748   for (i = 0 ; i < roc->n_vars; ++i)
749     {
750       struct casewriter *w = NULL;
751       struct casereader *r = NULL;
752
753       struct ccase *c;
754
755       struct ccase *cpos;
756       struct casereader *n_neg_reader ;
757       const struct variable *var = roc->vars[i];
758
759       struct casereader *neg ;
760       struct casereader *pos = casereader_clone (positives);
761
762       struct casereader *n_pos_reader =
763         process_positive_group (var, pos, dict, &rs[i]);
764
765       if ( negatives == NULL)
766         {
767           negatives = casewriter_make_reader (neg_wtr);
768         }
769
770       neg = casereader_clone (negatives);
771
772       n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
773
774       /* Merge the n_pos and n_neg casereaders */
775       w = sort_create_writer (&up_ordering, n_proto);
776       for ( ; (cpos = casereader_read (n_pos_reader) ); case_unref (cpos))
777         {
778           struct ccase *pos_case = case_create (n_proto);
779           struct ccase *cneg;
780           const double jpos = case_data_idx (cpos, VALUE)->f;
781
782           while ((cneg = casereader_read (n_neg_reader)))
783             {
784               struct ccase *nc = case_create (n_proto);
785
786               const double jneg = case_data_idx (cneg, VALUE)->f;
787
788               case_data_rw_idx (nc, VALUE)->f = jneg;
789               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
790
791               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
792
793               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
794               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
795
796               casewriter_write (w, nc);
797
798               case_unref (cneg);
799               if ( jneg > jpos)
800                 break;
801             }
802
803           case_data_rw_idx (pos_case, VALUE)->f = jpos;
804           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
805           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
806           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
807           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
808
809           casewriter_write (w, pos_case);
810         }
811
812       casereader_destroy (n_pos_reader);
813       casereader_destroy (n_neg_reader);
814
815 /* These aren't used anymore */
816 #undef N_EQ
817 #undef N_PRED
818
819       r = casewriter_make_reader (w);
820
821       /* Propagate the N_POS_GT values from the positive cases
822          to the negative ones */
823       {
824         double prev_pos_gt = rs[i].n1;
825         w = sort_create_writer (&down_ordering, n_proto);
826
827         for ( ; (c = casereader_read (r) ); case_unref (c))
828           {
829             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
830             struct ccase *nc = case_clone (c);
831
832             if ( n_pos_gt == SYSMIS)
833               {
834                 n_pos_gt = prev_pos_gt;
835                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
836               }
837             
838             casewriter_write (w, nc);
839             prev_pos_gt = n_pos_gt;
840           }
841
842         casereader_destroy (r);
843         r = casewriter_make_reader (w);
844       }
845
846       /* Propagate the N_NEG_LT values from the negative cases
847          to the positive ones */
848       {
849         double prev_neg_lt = rs[i].n2;
850         w = sort_create_writer (&up_ordering, n_proto);
851
852         for ( ; (c = casereader_read (r) ); case_unref (c))
853           {
854             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
855             struct ccase *nc = case_clone (c);
856
857             if ( n_neg_lt == SYSMIS)
858               {
859                 n_neg_lt = prev_neg_lt;
860                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
861               }
862             
863             casewriter_write (w, nc);
864             prev_neg_lt = n_neg_lt;
865           }
866
867         casereader_destroy (r);
868         r = casewriter_make_reader (w);
869       }
870
871       {
872         struct ccase *prev_case = NULL;
873         for ( ; (c = casereader_read (r) ); case_unref (c))
874           {
875             struct ccase *next_case = casereader_peek (r, 0);
876
877             const double j = case_data_idx (c, VALUE)->f;
878             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
879             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
880             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
881             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
882
883             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
884               {
885                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
886                   {
887                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
888                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
889                   }
890
891                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
892                   {
893                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
894                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
895                   }
896               }
897
898             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
899               {
900                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
901
902                 rs[i].q1hat +=
903                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
904                 rs[i].q2hat +=
905                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
906
907               }
908
909             case_unref (next_case);
910             case_unref (prev_case);
911             prev_case = case_clone (c);
912           }
913         casereader_destroy (r);
914         case_unref (prev_case);
915
916         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
917         if ( roc->invert ) 
918           rs[i].auc = 1 - rs[i].auc;
919
920         if ( roc->bi_neg_exp )
921           {
922             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
923             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
924           }
925         else
926           {
927             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
928             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
929           }
930       }
931     }
932
933   casereader_destroy (positives);
934   casereader_destroy (negatives);
935
936   caseproto_unref (n_proto);
937   subcase_destroy (&up_ordering);
938   subcase_destroy (&down_ordering);
939
940   output_roc (rs, roc);
941  
942   for (i = 0 ; i < roc->n_vars; ++i)
943     casereader_destroy (rs[i].cutpoint_rdr);
944
945   free (rs);
946 }
947
948 static void
949 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
950 {
951   int i;
952   const int n_fields = roc->print_se ? 5 : 1;
953   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
954   const int n_rows = 2 + roc->n_vars;
955   struct tab_table *tbl = tab_create (n_cols, n_rows);
956
957   if ( roc->n_vars > 1)
958     tab_title (tbl, _("Area Under the Curve"));
959   else
960     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
961
962   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
963
964
965   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
966
967   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
968
969   tab_box (tbl,
970            TAL_2, TAL_2,
971            -1, TAL_1,
972            0, 0,
973            n_cols - 1,
974            n_rows - 1);
975
976   if ( roc->print_se )
977     {
978       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
979       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
980
981       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
982       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
983
984       tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
985                              TAT_TITLE | TAB_CENTER,
986                              _("Asymp. %g%% Confidence Interval"), roc->ci);
987       tab_vline (tbl, 0, n_cols - 1, 0, 0);
988       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
989     }
990
991   if ( roc->n_vars > 1)
992     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
993
994   if ( roc->n_vars > 1)
995     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
996
997
998   for ( i = 0 ; i < roc->n_vars ; ++i )
999     {
1000       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
1001
1002       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
1003
1004       if ( roc->print_se )
1005         {
1006           double se ;
1007           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1008                                       (12 * rs[i].n1 * rs[i].n2));
1009           double ci ;
1010           double yy ;
1011
1012           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
1013             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
1014
1015           se /= rs[i].n1 * rs[i].n2;
1016
1017           se = sqrt (se);
1018
1019           tab_double (tbl, n_cols - 4, 2 + i, 0,
1020                       se,
1021                       NULL);
1022
1023           ci = 1 - roc->ci / 100.0;
1024           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1025
1026           tab_double (tbl, n_cols - 2, 2 + i, 0,
1027                       rs[i].auc - yy,
1028                       NULL);
1029
1030           tab_double (tbl, n_cols - 1, 2 + i, 0,
1031                       rs[i].auc + yy,
1032                       NULL);
1033
1034           tab_double (tbl, n_cols - 3, 2 + i, 0,
1035                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1036                       NULL);
1037         }
1038     }
1039
1040   tab_submit (tbl);
1041 }
1042
1043
1044 static void
1045 show_summary (const struct cmd_roc *roc)
1046 {
1047   const int n_cols = 3;
1048   const int n_rows = 4;
1049   struct tab_table *tbl = tab_create (n_cols, n_rows);
1050
1051   tab_title (tbl, _("Case Summary"));
1052
1053   tab_headers (tbl, 1, 0, 2, 0);
1054
1055   tab_box (tbl,
1056            TAL_2, TAL_2,
1057            -1, -1,
1058            0, 0,
1059            n_cols - 1,
1060            n_rows - 1);
1061
1062   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1063   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1064
1065
1066   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1067   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1068
1069
1070   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1071   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1072   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1073
1074   tab_joint_text (tbl, 1, 0, 2, 0,
1075                   TAT_TITLE | TAB_CENTER,
1076                   _("Valid N (listwise)"));
1077
1078
1079   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1080   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1081
1082
1083   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1084   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1085
1086   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1087   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1088
1089   tab_submit (tbl);
1090 }
1091
1092
1093 static void
1094 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1095 {
1096   int x = 1;
1097   int i;
1098   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1099   int n_rows = 1;
1100   struct tab_table *tbl ;
1101
1102   for (i = 0; i < roc->n_vars; ++i)
1103     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1104
1105   tbl = tab_create (n_cols, n_rows);
1106
1107   if ( roc->n_vars > 1)
1108     tab_title (tbl, _("Coordinates of the Curve"));
1109   else
1110     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1111
1112
1113   tab_headers (tbl, 1, 0, 1, 0);
1114
1115   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1116
1117   if ( roc->n_vars > 1)
1118     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1119
1120   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1121   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1122   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1123
1124   tab_box (tbl,
1125            TAL_2, TAL_2,
1126            -1, TAL_1,
1127            0, 0,
1128            n_cols - 1,
1129            n_rows - 1);
1130
1131   if ( roc->n_vars > 1)
1132     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1133
1134   for (i = 0; i < roc->n_vars; ++i)
1135     {
1136       struct ccase *cc;
1137       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1138
1139       if ( roc->n_vars > 1)
1140         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1141
1142       if ( i > 0)
1143         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1144
1145
1146       for (; (cc = casereader_read (r)) != NULL;
1147            case_unref (cc), x++)
1148         {
1149           const double se = case_data_idx (cc, ROC_TP)->f /
1150             (
1151              case_data_idx (cc, ROC_TP)->f
1152              +
1153              case_data_idx (cc, ROC_FN)->f
1154              );
1155
1156           const double sp = case_data_idx (cc, ROC_TN)->f /
1157             (
1158              case_data_idx (cc, ROC_TN)->f
1159              +
1160              case_data_idx (cc, ROC_FP)->f
1161              );
1162
1163           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1164                       var_get_print_format (roc->vars[i]));
1165
1166           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1167           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1168         }
1169
1170       casereader_destroy (r);
1171     }
1172
1173   tab_submit (tbl);
1174 }
1175
1176
1177 static void
1178 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1179 {
1180   show_summary (roc);
1181
1182   if ( roc->curve )
1183     {
1184       struct roc_chart *rc;
1185       size_t i;
1186
1187       rc = roc_chart_create (roc->reference);
1188       for (i = 0; i < roc->n_vars; i++)
1189         roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1190                            rs[i].cutpoint_rdr);
1191       roc_chart_submit (rc);
1192     }
1193
1194   show_auc (rs, roc);
1195
1196   if ( roc->print_coords )
1197     show_coords (rs, roc);
1198 }
1199