Remove double semicolons.
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/procedure.h>
20 #include <language/lexer/variable-parser.h>
21 #include <language/lexer/value-parser.h>
22 #include <language/command.h>
23 #include <language/lexer/lexer.h>
24
25 #include <data/casegrouper.h>
26 #include <data/casereader.h>
27 #include <data/casewriter.h>
28 #include <data/dictionary.h>
29 #include <data/format.h>
30 #include <math/sort.h>
31 #include <data/subcase.h>
32
33
34 #include <libpspp/misc.h>
35
36 #include <gsl/gsl_cdf.h>
37 #include <output/table.h>
38
39 #include <output/charts/plot-chart.h>
40 #include <output/charts/cartesian.h>
41
42 #include "gettext.h"
43 #define _(msgid) gettext (msgid)
44 #define N_(msgid) msgid
45
46 struct cmd_roc
47 {
48   size_t n_vars;
49   const struct variable **vars;
50   const struct dictionary *dict;
51
52   const struct variable *state_var ;
53   union value state_value;
54
55   /* Plot the roc curve */
56   bool curve;
57   /* Plot the reference line */
58   bool reference;
59
60   double ci;
61
62   bool print_coords;
63   bool print_se;
64   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
65                       should be used */
66   enum mv_class exclude;
67
68   bool invert ; /* True iff a smaller test result variable indicates
69                    a positive result */
70
71   double pos;
72   double neg;
73   double pos_weighted;
74   double neg_weighted;
75 };
76
77 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
78
79 int
80 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 {
82   struct cmd_roc roc ;
83   const struct dictionary *dict = dataset_dict (ds);
84
85   roc.vars = NULL;
86   roc.n_vars = 0;
87   roc.print_se = false;
88   roc.print_coords = false;
89   roc.exclude = MV_ANY;
90   roc.curve = true;
91   roc.reference = false;
92   roc.ci = 95;
93   roc.bi_neg_exp = false;
94   roc.invert = false;
95   roc.pos = roc.pos_weighted = 0;
96   roc.neg = roc.neg_weighted = 0;
97   roc.dict = dataset_dict (ds);
98
99   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
101     goto error;
102
103   if ( ! lex_force_match (lexer, T_BY))
104     {
105       goto error;
106     }
107
108   roc.state_var = parse_variable (lexer, dict);
109
110   if ( !lex_force_match (lexer, '('))
111     {
112       goto error;
113     }
114
115   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
116
117
118   if ( !lex_force_match (lexer, ')'))
119     {
120       goto error;
121     }
122
123
124   while (lex_token (lexer) != '.')
125     {
126       lex_match (lexer, '/');
127       if (lex_match_id (lexer, "MISSING"))
128         {
129           lex_match (lexer, '=');
130           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
131             {
132               if (lex_match_id (lexer, "INCLUDE"))
133                 {
134                   roc.exclude = MV_SYSTEM;
135                 }
136               else if (lex_match_id (lexer, "EXCLUDE"))
137                 {
138                   roc.exclude = MV_ANY;
139                 }
140               else
141                 {
142                   lex_error (lexer, NULL);
143                   goto error;
144                 }
145             }
146         }
147       else if (lex_match_id (lexer, "PLOT"))
148         {
149           lex_match (lexer, '=');
150           if (lex_match_id (lexer, "CURVE"))
151             {
152               roc.curve = true;
153               if (lex_match (lexer, '('))
154                 {
155                   roc.reference = true;
156                   lex_force_match_id (lexer, "REFERENCE");
157                   lex_force_match (lexer, ')');
158                 }
159             }
160           else if (lex_match_id (lexer, "NONE"))
161             {
162               roc.curve = false;
163             }
164           else
165             {
166               lex_error (lexer, NULL);
167               goto error;
168             }
169         }
170       else if (lex_match_id (lexer, "PRINT"))
171         {
172           lex_match (lexer, '=');
173           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
174             {
175               if (lex_match_id (lexer, "SE"))
176                 {
177                   roc.print_se = true;
178                 }
179               else if (lex_match_id (lexer, "COORDINATES"))
180                 {
181                   roc.print_coords = true;
182                 }
183               else
184                 {
185                   lex_error (lexer, NULL);
186                   goto error;
187                 }
188             }
189         }
190       else if (lex_match_id (lexer, "CRITERIA"))
191         {
192           lex_match (lexer, '=');
193           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
194             {
195               if (lex_match_id (lexer, "CUTOFF"))
196                 {
197                   lex_force_match (lexer, '(');
198                   if (lex_match_id (lexer, "INCLUDE"))
199                     {
200                       roc.exclude = MV_SYSTEM;
201                     }
202                   else if (lex_match_id (lexer, "EXCLUDE"))
203                     {
204                       roc.exclude = MV_USER | MV_SYSTEM;
205                     }
206                   else
207                     {
208                       lex_error (lexer, NULL);
209                       goto error;
210                     }
211                   lex_force_match (lexer, ')');
212                 }
213               else if (lex_match_id (lexer, "TESTPOS"))
214                 {
215                   lex_force_match (lexer, '(');
216                   if (lex_match_id (lexer, "LARGE"))
217                     {
218                       roc.invert = false;
219                     }
220                   else if (lex_match_id (lexer, "SMALL"))
221                     {
222                       roc.invert = true;
223                     }
224                   else
225                     {
226                       lex_error (lexer, NULL);
227                       goto error;
228                     }
229                   lex_force_match (lexer, ')');
230                 }
231               else if (lex_match_id (lexer, "CI"))
232                 {
233                   lex_force_match (lexer, '(');
234                   lex_force_num (lexer);
235                   roc.ci = lex_number (lexer);
236                   lex_get (lexer);
237                   lex_force_match (lexer, ')');
238                 }
239               else if (lex_match_id (lexer, "DISTRIBUTION"))
240                 {
241                   lex_force_match (lexer, '(');
242                   if (lex_match_id (lexer, "FREE"))
243                     {
244                       roc.bi_neg_exp = false;
245                     }
246                   else if (lex_match_id (lexer, "NEGEXPO"))
247                     {
248                       roc.bi_neg_exp = true;
249                     }
250                   else
251                     {
252                       lex_error (lexer, NULL);
253                       goto error;
254                     }
255                   lex_force_match (lexer, ')');
256                 }
257               else
258                 {
259                   lex_error (lexer, NULL);
260                   goto error;
261                 }
262             }
263         }
264       else
265         {
266           lex_error (lexer, NULL);
267           break;
268         }
269     }
270
271   if ( ! run_roc (ds, &roc)) 
272     goto error;
273
274   free (roc.vars);
275   return CMD_SUCCESS;
276
277  error:
278   free (roc.vars);
279   return CMD_FAILURE;
280 }
281
282
283
284
285 static void
286 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
287
288
289 static int
290 run_roc (struct dataset *ds, struct cmd_roc *roc)
291 {
292   struct dictionary *dict = dataset_dict (ds);
293   bool ok;
294   struct casereader *group;
295
296   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
297   while (casegrouper_get_next_group (grouper, &group))
298     {
299       do_roc (roc, group, dataset_dict (ds));
300     }
301   ok = casegrouper_destroy (grouper);
302   ok = proc_commit (ds) && ok;
303
304   return ok;
305 }
306
307 #if 0
308 static void
309 dump_casereader (struct casereader *reader)
310 {
311   struct ccase *c;
312   struct casereader *r = casereader_clone (reader);
313
314   for ( ; (c = casereader_read (r) ); case_unref (c))
315     {
316       int i;
317       for (i = 0 ; i < case_get_value_cnt (c); ++i)
318         {
319           printf ("%g ", case_data_idx (c, i)->f);
320         }
321       printf ("\n");
322     }
323
324   casereader_destroy (r);
325 }
326 #endif
327
328
329 /* 
330    Return true iff the state variable indicates that C has positive actual state.
331
332    As a side effect, this function also accumulates the roc->{pos,neg} and 
333    roc->{pos,neg}_weighted counts.
334  */
335 static bool
336 match_positives (const struct ccase *c, void *aux)
337 {
338   struct cmd_roc *roc = aux;
339   const struct variable *wv = dict_get_weight (roc->dict);
340   const double weight = wv ? case_data (c, wv)->f : 1.0;
341
342   const bool positive =
343   ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
344     var_get_width (roc->state_var)));
345
346   if ( positive )
347     {
348       roc->pos++;
349       roc->pos_weighted += weight;
350     }
351   else
352     {
353       roc->neg++;
354       roc->neg_weighted += weight;
355     }
356
357   return positive;
358 }
359
360
361 #define VALUE  0
362 #define N_EQ   1
363 #define N_PRED 2
364
365 /* Some intermediate state for calculating the cutpoints and the 
366    standard error values */
367 struct roc_state
368 {
369   double auc;  /* Area under the curve */
370
371   double n1;  /* total weight of positives */
372   double n2;  /* total weight of negatives */
373
374   /* intermediates for standard error */
375   double q1hat; 
376   double q2hat;
377
378   /* intermediates for cutpoints */
379   struct casewriter *cutpoint_wtr;
380   struct casereader *cutpoint_rdr;
381   double prev_result;
382   double min;
383   double max;
384 };
385
386 #define CUTPOINT 0
387 #define TP 1
388 #define FN 2
389 #define TN 3
390 #define FP 4
391
392
393 /* 
394    Return a new casereader based upon CUTPOINT_RDR.
395    The number of "positive" cases are placed into
396    the position TRUE_INDEX, and the number of "negative" cases
397    into FALSE_INDEX.
398    POS_COND and RESULT determine the semantics of what is 
399    "positive".
400    WEIGHT is the value of a single count.
401  */
402 static struct casereader *
403 accumulate_counts (struct casereader *cutpoint_rdr, 
404                    double result, double weight, 
405                    bool (*pos_cond) (double, double),
406                    int true_index, int false_index)
407 {
408   const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
409   struct casewriter *w =
410     autopaging_writer_create (proto);
411   struct casereader *r = casereader_clone (cutpoint_rdr);
412   struct ccase *cpc;
413   double prev_cp = SYSMIS;
414
415   for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
416     {
417       struct ccase *new_case;
418       const double cp = case_data_idx (cpc, CUTPOINT)->f;
419
420       assert (cp != SYSMIS);
421
422       /* We don't want duplicates here */
423       if ( cp == prev_cp )
424         continue;
425
426       new_case = case_clone (cpc);
427
428       if ( pos_cond (result, cp))
429         case_data_rw_idx (new_case, true_index)->f += weight;
430       else
431         case_data_rw_idx (new_case, false_index)->f += weight;
432
433       prev_cp = cp;
434
435       casewriter_write (w, new_case);
436     }
437   casereader_destroy (r);
438
439   return casewriter_make_reader (w);
440 }
441
442
443
444 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
445
446 /*
447   This function does 3 things:
448
449   1. Counts the number of cases which are equal to every other case in READER,
450   and those cases for which the relationship between it and every other case
451   satifies PRED (normally either > or <).  VAR is variable defining a case's value
452   for this purpose.
453
454   2. Counts the number of true and false cases in reader, and populates
455   CUTPOINT_RDR accordingly.  TRUE_INDEX and FALSE_INDEX are the indices
456   which receive these values.  POS_COND is the condition defining true
457   and false.
458   
459   3. CC is filled with the cumulative weight of all cases of READER.
460 */
461 static struct casereader *
462 process_group (const struct variable *var, struct casereader *reader,
463                bool (*pred) (double, double),
464                const struct dictionary *dict,
465                double *cc,
466                struct casereader **cutpoint_rdr, 
467                bool (*pos_cond) (double, double),
468                int true_index,
469                int false_index)
470 {
471   const struct variable *w = dict_get_weight (dict);
472
473   struct casereader *r1 =
474     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
475
476   const int weight_idx  = w ? var_get_case_index (w) :
477     caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
478   
479   struct ccase *c1;
480
481   struct casereader *rclone = casereader_clone (r1);
482   struct casewriter *wtr;
483   struct caseproto *proto = caseproto_create ();
484
485   proto = caseproto_add_width (proto, 0);
486   proto = caseproto_add_width (proto, 0);
487   proto = caseproto_add_width (proto, 0);
488
489   wtr = autopaging_writer_create (proto);  
490
491   *cc = 0;
492
493   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
494     {
495       struct ccase *new_case = case_create (proto);
496       struct ccase *c2;
497       struct casereader *r2 = casereader_clone (rclone);
498
499       const double weight1 = case_data_idx (c1, weight_idx)->f;
500       const double d1 = case_data (c1, var)->f;
501       double n_eq = 0.0;
502       double n_pred = 0.0;
503
504       *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
505                                          pos_cond,
506                                          true_index, false_index);
507
508       *cc += weight1;
509
510       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
511         {
512           const double d2 = case_data (c2, var)->f;
513           const double weight2 = case_data_idx (c2, weight_idx)->f;
514
515           if ( d1 == d2 )
516             {
517               n_eq += weight2;
518               continue;
519             }
520           else  if ( pred (d2, d1))
521             {
522               n_pred += weight2;
523             }
524         }
525
526       case_data_rw_idx (new_case, VALUE)->f = d1;
527       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
528       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
529
530       casewriter_write (wtr, new_case);
531
532       casereader_destroy (r2);
533     }
534
535   casereader_destroy (r1);
536   casereader_destroy (rclone);
537
538   return casewriter_make_reader (wtr);
539 }
540
541 /* Some more indeces into case data */
542 #define N_POS_EQ 1  /* number of positive cases with values equal to n */
543 #define N_POS_GT 2  /* number of postive cases with values greater than n */
544 #define N_NEG_EQ 3  /* number of negative cases with values equal to n */
545 #define N_NEG_LT 4  /* number of negative cases with values less than n */
546
547 static bool
548 gt (double d1, double d2)
549 {
550   return d1 > d2;
551 }
552
553
554 static bool
555 ge (double d1, double d2)
556 {
557   return d1 > d2;
558 }
559
560 static bool
561 lt (double d1, double d2)
562 {
563   return d1 < d2;
564 }
565
566
567 /*
568   Return a casereader with width 3,
569   populated with cases based upon READER.
570   The cases will have the values:
571   (N, number of cases equal to N, number of cases greater than N)
572   As a side effect, update RS->n1 with the number of positive cases.
573 */
574 static struct casereader *
575 process_positive_group (const struct variable *var, struct casereader *reader,
576                         const struct dictionary *dict,
577                         struct roc_state *rs)
578 {
579   return process_group (var, reader, gt, dict, &rs->n1,
580                         &rs->cutpoint_rdr,
581                         ge,
582                         TP, FN);
583 }
584
585 /*
586   Return a casereader with width 3,
587   populated with cases based upon READER.
588   The cases will have the values:
589   (N, number of cases equal to N, number of cases less than N)
590   As a side effect, update RS->n2 with the number of negative cases.
591 */
592 static struct casereader *
593 process_negative_group (const struct variable *var, struct casereader *reader,
594                         const struct dictionary *dict,
595                         struct roc_state *rs)
596 {
597   return process_group (var, reader, lt, dict, &rs->n2,
598                         &rs->cutpoint_rdr,
599                         lt,
600                         TN, FP);
601 }
602
603
604
605
606 static void
607 append_cutpoint (struct casewriter *writer, double cutpoint)
608 {
609   struct ccase *cc = case_create (casewriter_get_proto (writer));
610
611   case_data_rw_idx (cc, CUTPOINT)->f = cutpoint;
612   case_data_rw_idx (cc, TP)->f = 0;
613   case_data_rw_idx (cc, FN)->f = 0;
614   case_data_rw_idx (cc, TN)->f = 0;
615   case_data_rw_idx (cc, FP)->f = 0;
616
617   casewriter_write (writer, cc);
618 }
619
620
621 /* 
622    Create and initialise the rs[x].cutpoint_rdr casereaders.  That is, the readers will
623    be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the
624    reader will be populated with its final number of cases.
625    However on exit from this function, only CUTPOINT entries will be set to their final
626    value.  The other entries will be initialised to zero.
627 */
628 static void
629 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
630 {
631   int i;
632   struct casereader *r = casereader_clone (input);
633   struct ccase *c;
634   struct caseproto *proto = caseproto_create ();
635
636   struct subcase ordering;
637   subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND);
638
639   proto = caseproto_add_width (proto, 0); /* cutpoint */
640   proto = caseproto_add_width (proto, 0); /* TP */
641   proto = caseproto_add_width (proto, 0); /* FN */
642   proto = caseproto_add_width (proto, 0); /* TN */
643   proto = caseproto_add_width (proto, 0); /* FP */
644
645   for (i = 0 ; i < roc->n_vars; ++i)
646     {
647       rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
648       rs[i].prev_result = SYSMIS;
649       rs[i].max = -DBL_MAX;
650       rs[i].min = DBL_MAX;
651     }
652
653   for (; (c = casereader_read (r)) != NULL; case_unref (c))
654     {
655       for (i = 0 ; i < roc->n_vars; ++i)
656         {
657           const union value *v = case_data (c, roc->vars[i]); 
658           const double result = v->f;
659
660           if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
661             continue;
662
663           minimize (&rs[i].min, result);
664           maximize (&rs[i].max, result);
665
666           if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
667             {
668               const double mean = (result + rs[i].prev_result ) / 2.0;
669               append_cutpoint (rs[i].cutpoint_wtr, mean);
670             }
671
672           rs[i].prev_result = result;
673         }
674     }
675   casereader_destroy (r);
676
677
678   /* Append the min and max cutpoints */
679   for (i = 0 ; i < roc->n_vars; ++i)
680     {
681       append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
682       append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
683
684       rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
685     }
686 }
687
688 static void
689 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
690 {
691   int i;
692
693   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
694
695   struct casereader *negatives = NULL;
696   struct casereader *positives = NULL;
697
698   struct caseproto *n_proto = caseproto_create ();
699
700   struct subcase up_ordering;
701   struct subcase down_ordering;
702
703   struct casewriter *neg_wtr = NULL;
704
705   struct casereader *input = casereader_create_filter_missing (reader,
706                                                                roc->vars, roc->n_vars,
707                                                                roc->exclude,
708                                                                NULL,
709                                                                NULL);
710
711   input = casereader_create_filter_missing (input,
712                                             &roc->state_var, 1,
713                                             roc->exclude,
714                                             NULL,
715                                             NULL);
716
717   neg_wtr = autopaging_writer_create (casereader_get_proto (input));
718
719   prepare_cutpoints (roc, rs, input);
720
721
722   /* Separate the positive actual state cases from the negative ones */
723   positives = 
724     casereader_create_filter_func (input,
725                                    match_positives,
726                                    NULL,
727                                    roc,
728                                    neg_wtr);
729
730   n_proto = caseproto_create ();
731       
732   n_proto = caseproto_add_width (n_proto, 0);
733   n_proto = caseproto_add_width (n_proto, 0);
734   n_proto = caseproto_add_width (n_proto, 0);
735   n_proto = caseproto_add_width (n_proto, 0);
736   n_proto = caseproto_add_width (n_proto, 0);
737
738   subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
739   subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
740
741   for (i = 0 ; i < roc->n_vars; ++i)
742     {
743       struct casewriter *w = NULL;
744       struct casereader *r = NULL;
745
746       struct ccase *c;
747
748       struct ccase *cpos;
749       struct casereader *n_neg ;
750       const struct variable *var = roc->vars[i];
751
752       struct casereader *neg ;
753       struct casereader *pos = casereader_clone (positives);
754
755
756       struct casereader *n_pos =
757         process_positive_group (var, pos, dict, &rs[i]);
758
759       if ( negatives == NULL)
760         {
761           negatives = casewriter_make_reader (neg_wtr);
762         }
763
764       neg = casereader_clone (negatives);
765
766       n_neg = process_negative_group (var, neg, dict, &rs[i]);
767
768
769       /* Merge the n_pos and n_neg casereaders */
770       w = sort_create_writer (&up_ordering, n_proto);
771       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
772         {
773           struct ccase *pos_case = case_create (n_proto);
774           struct ccase *cneg;
775           const double jpos = case_data_idx (cpos, VALUE)->f;
776
777           while ((cneg = casereader_read (n_neg)))
778             {
779               struct ccase *nc = case_create (n_proto);
780
781               const double jneg = case_data_idx (cneg, VALUE)->f;
782
783               case_data_rw_idx (nc, VALUE)->f = jneg;
784               case_data_rw_idx (nc, N_POS_EQ)->f = 0;
785
786               case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
787
788               *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
789               *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
790
791               casewriter_write (w, nc);
792
793               case_unref (cneg);
794               if ( jneg > jpos)
795                 break;
796             }
797
798           case_data_rw_idx (pos_case, VALUE)->f = jpos;
799           *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
800           *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
801           case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
802           case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
803
804           casewriter_write (w, pos_case);
805         }
806
807 /* These aren't used anymore */
808 #undef N_EQ
809 #undef N_PRED
810
811       r = casewriter_make_reader (w);
812
813       /* Propagate the N_POS_GT values from the positive cases
814          to the negative ones */
815       {
816         double prev_pos_gt = rs[i].n1;
817         w = sort_create_writer (&down_ordering, n_proto);
818
819         for ( ; (c = casereader_read (r) ); case_unref (c))
820           {
821             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
822             struct ccase *nc = case_clone (c);
823
824             if ( n_pos_gt == SYSMIS)
825               {
826                 n_pos_gt = prev_pos_gt;
827                 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
828               }
829             
830             casewriter_write (w, nc);
831             prev_pos_gt = n_pos_gt;
832           }
833
834         r = casewriter_make_reader (w);
835       }
836
837       /* Propagate the N_NEG_LT values from the negative cases
838          to the positive ones */
839       {
840         double prev_neg_lt = rs[i].n2;
841         w = sort_create_writer (&up_ordering, n_proto);
842
843         for ( ; (c = casereader_read (r) ); case_unref (c))
844           {
845             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
846             struct ccase *nc = case_clone (c);
847
848             if ( n_neg_lt == SYSMIS)
849               {
850                 n_neg_lt = prev_neg_lt;
851                 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
852               }
853             
854             casewriter_write (w, nc);
855             prev_neg_lt = n_neg_lt;
856           }
857
858         r = casewriter_make_reader (w);
859       }
860
861       {
862         struct ccase *prev_case = NULL;
863         for ( ; (c = casereader_read (r) ); case_unref (c))
864           {
865             const struct ccase *next_case = casereader_peek (r, 0);
866
867             const double j = case_data_idx (c, VALUE)->f;
868             double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
869             double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
870             double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
871             double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
872
873             if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
874               {
875                 if ( 0 ==  case_data_idx (c, N_POS_EQ)->f)
876                   {
877                     n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
878                     n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
879                   }
880
881                 if ( 0 ==  case_data_idx (c, N_NEG_EQ)->f)
882                   {
883                     n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
884                     n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
885                   }
886               }
887
888             if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
889               {
890                 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
891
892                 rs[i].q1hat +=
893                   n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
894                 rs[i].q2hat +=
895                   n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
896
897               }
898
899             case_unref (prev_case);
900             prev_case = case_clone (c);
901           }
902
903         rs[i].auc /=  rs[i].n1 * rs[i].n2; 
904         if ( roc->invert ) 
905           rs[i].auc = 1 - rs[i].auc;
906
907         if ( roc->bi_neg_exp )
908           {
909             rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
910             rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
911           }
912         else
913           {
914             rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
915             rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
916           }
917       }
918     }
919
920   casereader_destroy (positives);
921   casereader_destroy (negatives);
922
923   output_roc (rs, roc);
924
925   free (rs);
926 }
927
928 static void
929 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
930 {
931   int i;
932   const int n_fields = roc->print_se ? 5 : 1;
933   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
934   const int n_rows = 2 + roc->n_vars;
935   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
936
937   if ( roc->n_vars > 1)
938     tab_title (tbl, _("Area Under the Curve"));
939   else
940     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
941
942   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
943
944   tab_dim (tbl, tab_natural_dimensions, NULL);
945
946   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
947
948   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
949
950   tab_box (tbl,
951            TAL_2, TAL_2,
952            -1, TAL_1,
953            0, 0,
954            n_cols - 1,
955            n_rows - 1);
956
957   if ( roc->print_se )
958     {
959       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
960       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
961
962       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
963       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
964
965       tab_joint_text (tbl, n_cols - 2, 0, 4, 0,
966                       TAT_TITLE | TAB_CENTER | TAT_PRINTF,
967                       _("Asymp. %g%% Confidence Interval"), roc->ci);
968       tab_vline (tbl, 0, n_cols - 1, 0, 0);
969       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
970     }
971
972   if ( roc->n_vars > 1)
973     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
974
975   if ( roc->n_vars > 1)
976     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
977
978
979   for ( i = 0 ; i < roc->n_vars ; ++i )
980     {
981       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
982
983       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
984
985       if ( roc->print_se )
986         {
987           double se ;
988           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
989                                       (12 * rs[i].n1 * rs[i].n2));
990           double ci ;
991           double yy ;
992
993           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
994             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
995
996           se /= rs[i].n1 * rs[i].n2;
997
998           se = sqrt (se);
999
1000           tab_double (tbl, n_cols - 4, 2 + i, 0,
1001                       se,
1002                       NULL);
1003
1004           ci = 1 - roc->ci / 100.0;
1005           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1006
1007           tab_double (tbl, n_cols - 2, 2 + i, 0,
1008                       rs[i].auc - yy,
1009                       NULL);
1010
1011           tab_double (tbl, n_cols - 1, 2 + i, 0,
1012                       rs[i].auc + yy,
1013                       NULL);
1014
1015           tab_double (tbl, n_cols - 3, 2 + i, 0,
1016                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1017                       NULL);
1018         }
1019     }
1020
1021   tab_submit (tbl);
1022 }
1023
1024
1025 static void
1026 show_summary (const struct cmd_roc *roc)
1027 {
1028   const int n_cols = 3;
1029   const int n_rows = 4;
1030   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
1031
1032   tab_title (tbl, _("Case Summary"));
1033
1034   tab_headers (tbl, 1, 0, 2, 0);
1035
1036   tab_dim (tbl, tab_natural_dimensions, NULL);
1037
1038   tab_box (tbl,
1039            TAL_2, TAL_2,
1040            -1, -1,
1041            0, 0,
1042            n_cols - 1,
1043            n_rows - 1);
1044
1045   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1046   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1047
1048
1049   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1050   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1051
1052
1053   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1054   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1055   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1056
1057   tab_joint_text (tbl, 1, 0, 2, 0,
1058                   TAT_TITLE | TAB_CENTER,
1059                   _("Valid N (listwise)"));
1060
1061
1062   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1063   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1064
1065
1066   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1067   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1068
1069   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1070   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1071
1072   tab_submit (tbl);
1073 }
1074
1075
1076 static void
1077 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1078 {
1079   int x = 1;
1080   int i;
1081   const int n_cols = roc->n_vars > 1 ? 4 : 3;
1082   int n_rows = 1;
1083   struct tab_table *tbl ;
1084
1085   for (i = 0; i < roc->n_vars; ++i)
1086     n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1087
1088   tbl = tab_create (n_cols, n_rows, 0);
1089
1090   if ( roc->n_vars > 1)
1091     tab_title (tbl, _("Coordinates of the Curve"));
1092   else
1093     tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1094
1095
1096   tab_headers (tbl, 1, 0, 1, 0);
1097
1098   tab_dim (tbl, tab_natural_dimensions, NULL);
1099
1100   tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1101
1102   if ( roc->n_vars > 1)
1103     tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1104
1105   tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1106   tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1107   tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1108
1109   tab_box (tbl,
1110            TAL_2, TAL_2,
1111            -1, TAL_1,
1112            0, 0,
1113            n_cols - 1,
1114            n_rows - 1);
1115
1116   if ( roc->n_vars > 1)
1117     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1118
1119   for (i = 0; i < roc->n_vars; ++i)
1120     {
1121       struct ccase *cc;
1122       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1123
1124       if ( roc->n_vars > 1)
1125         tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1126
1127       if ( i > 0)
1128         tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1129
1130
1131       for (; (cc = casereader_read (r)) != NULL;
1132            case_unref (cc), x++)
1133         {
1134           const double se = case_data_idx (cc, TP)->f /
1135             (
1136              case_data_idx (cc, TP)->f
1137              +
1138              case_data_idx (cc, FN)->f
1139              );
1140
1141           const double sp = case_data_idx (cc, TN)->f /
1142             (
1143              case_data_idx (cc, TN)->f
1144              +
1145              case_data_idx (cc, FP)->f
1146              );
1147
1148           tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f,
1149                       var_get_print_format (roc->vars[i]));
1150
1151           tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1152           tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1153         }
1154
1155       casereader_destroy (r);
1156     }
1157
1158   tab_submit (tbl);
1159 }
1160
1161
1162 static void
1163 draw_roc (struct roc_state *rs, const struct cmd_roc *roc)
1164 {
1165   int i;
1166
1167   struct chart *roc_chart = chart_create ();
1168
1169   chart_write_title (roc_chart, _("ROC Curve"));
1170   chart_write_xlabel (roc_chart, _("1 - Specificity"));
1171   chart_write_ylabel (roc_chart, _("Sensitivity"));
1172
1173   chart_write_xscale (roc_chart, 0, 1, 5);
1174   chart_write_yscale (roc_chart, 0, 1, 5);
1175
1176   if ( roc->reference )
1177     {
1178       chart_line (roc_chart, 1.0, 0,
1179                   0.0, 1.0,
1180                   CHART_DIM_X);
1181     }
1182
1183   for (i = 0; i < roc->n_vars; ++i)
1184     {
1185       struct ccase *cc;
1186       struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1187
1188       chart_vector_start (roc_chart, var_get_name (roc->vars[i]));
1189       for (; (cc = casereader_read (r)) != NULL;
1190            case_unref (cc))
1191         {
1192           double se = case_data_idx (cc, TP)->f;
1193           double sp = case_data_idx (cc, TN)->f;
1194
1195           se /= case_data_idx (cc, FN)->f +
1196             case_data_idx (cc, TP)->f ;
1197
1198           sp /= case_data_idx (cc, TN)->f +
1199             case_data_idx (cc, FP)->f ;
1200
1201           chart_vector (roc_chart, 1 - sp, se);
1202         }
1203       chart_vector_end (roc_chart);
1204       casereader_destroy (r);
1205     }
1206
1207   chart_write_legend (roc_chart);
1208
1209   chart_submit (roc_chart);
1210 }
1211
1212
1213 static void
1214 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1215 {
1216   show_summary (roc);
1217
1218   if ( roc->curve )
1219     draw_roc (rs, roc);
1220
1221   show_auc (rs, roc);
1222
1223
1224   if ( roc->print_coords )
1225     show_coords (rs, roc);
1226 }
1227