Fix bug when positive and negative groups are of different lengths
[pspp-builds.git] / src / language / stats / roc.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "roc.h"
20 #include <data/procedure.h>
21 #include <language/lexer/variable-parser.h>
22 #include <language/lexer/value-parser.h>
23 #include <language/lexer/lexer.h>
24
25 #include <data/casegrouper.h>
26 #include <data/casereader.h>
27 #include <data/casewriter.h>
28 #include <data/dictionary.h>
29 #include <data/format.h>
30 #include <math/sort.h>
31
32 #include <libpspp/misc.h>
33
34 #include <gsl/gsl_cdf.h>
35 #include <output/table.h>
36
37 #include "gettext.h"
38 #define _(msgid) gettext (msgid)
39 #define N_(msgid) msgid
40
41 struct cmd_roc
42 {
43   size_t n_vars;
44   const struct variable **vars;
45
46   struct variable *state_var ;
47   union value state_value;
48
49   /* Plot the roc curve */
50   bool curve;
51   /* Plot the reference line */
52   bool reference;
53
54   double ci;
55
56   bool print_coords;
57   bool print_se;
58   bool bi_neg_exp; /* True iff the bi-negative exponential critieria
59                       should be used */
60   enum mv_class exclude;
61
62   bool invert ; /* True iff a smaller test result variable indicates
63                    a positive result */
64 };
65
66 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
67
68 int
69 cmd_roc (struct lexer *lexer, struct dataset *ds)
70 {
71   struct cmd_roc roc ;
72   const struct dictionary *dict = dataset_dict (ds);
73
74   roc.vars = NULL;
75   roc.n_vars = 0;
76   roc.print_se = false;
77   roc.print_coords = false;
78   roc.exclude = MV_ANY;
79   roc.curve = true;
80   roc.reference = false;
81   roc.ci = 95;
82   roc.bi_neg_exp = false;
83   roc.invert = false;
84
85   if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
86                               PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
87     return 2;
88
89   if ( ! lex_force_match (lexer, T_BY))
90     {
91       return 2;
92     }
93
94   roc.state_var = parse_variable (lexer, dict);
95
96   if ( !lex_force_match (lexer, '('))
97     {
98       return 2;
99     }
100
101   parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
102
103
104   if ( !lex_force_match (lexer, ')'))
105     {
106       return 2;
107     }
108
109
110   while (lex_token (lexer) != '.')
111     {
112       lex_match (lexer, '/');
113       if (lex_match_id (lexer, "MISSING"))
114         {
115           lex_match (lexer, '=');
116           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
117             {
118               if (lex_match_id (lexer, "INCLUDE"))
119                 {
120                   roc.exclude = MV_SYSTEM;
121                 }
122               else if (lex_match_id (lexer, "EXCLUDE"))
123                 {
124                   roc.exclude = MV_ANY;
125                 }
126               else
127                 {
128                   lex_error (lexer, NULL);
129                   return 2;
130                 }
131             }
132         }
133       else if (lex_match_id (lexer, "PLOT"))
134         {
135           lex_match (lexer, '=');
136           if (lex_match_id (lexer, "CURVE"))
137             {
138               roc.curve = true;
139               if (lex_match (lexer, '('))
140                 {
141                   roc.reference = true;
142                   lex_force_match_id (lexer, "REFERENCE");
143                   lex_force_match (lexer, ')');
144                 }
145             }
146           else if (lex_match_id (lexer, "NONE"))
147             {
148               roc.curve = false;
149             }
150           else
151             {
152               lex_error (lexer, NULL);
153               return 2;
154             }
155         }
156       else if (lex_match_id (lexer, "PRINT"))
157         {
158           lex_match (lexer, '=');
159           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
160             {
161               if (lex_match_id (lexer, "SE"))
162                 {
163                   roc.print_se = true;
164                 }
165               else if (lex_match_id (lexer, "COORDINATES"))
166                 {
167                   roc.print_coords = true;
168                 }
169               else
170                 {
171                   lex_error (lexer, NULL);
172                   return 2;
173                 }
174             }
175         }
176       else if (lex_match_id (lexer, "CRITERIA"))
177         {
178           lex_match (lexer, '=');
179           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
180             {
181               if (lex_match_id (lexer, "CUTOFF"))
182                 {
183                   lex_force_match (lexer, '(');
184                   if (lex_match_id (lexer, "INCLUDE"))
185                     {
186                       roc.exclude = MV_SYSTEM;
187                     }
188                   else if (lex_match_id (lexer, "EXCLUDE"))
189                     {
190                       roc.exclude = MV_USER | MV_SYSTEM;
191                     }
192                   else
193                     {
194                       lex_error (lexer, NULL);
195                       return 2;
196                     }
197                   lex_force_match (lexer, ')');
198                 }
199               else if (lex_match_id (lexer, "TESTPOS"))
200                 {
201                   lex_force_match (lexer, '(');
202                   if (lex_match_id (lexer, "LARGE"))
203                     {
204                       roc.invert = false;
205                     }
206                   else if (lex_match_id (lexer, "SMALL"))
207                     {
208                       roc.invert = true;
209                     }
210                   else
211                     {
212                       lex_error (lexer, NULL);
213                       return 2;
214                     }
215                   lex_force_match (lexer, ')');
216                 }
217               else if (lex_match_id (lexer, "CI"))
218                 {
219                   lex_force_match (lexer, '(');
220                   lex_force_num (lexer);
221                   roc.ci = lex_number (lexer);
222                   lex_get (lexer);
223                   lex_force_match (lexer, ')');
224                 }
225               else if (lex_match_id (lexer, "DISTRIBUTION"))
226                 {
227                   lex_force_match (lexer, '(');
228                   if (lex_match_id (lexer, "FREE"))
229                     {
230                       roc.bi_neg_exp = false;
231                     }
232                   else if (lex_match_id (lexer, "NEGEXPO"))
233                     {
234                       roc.bi_neg_exp = true;
235                     }
236                   else
237                     {
238                       lex_error (lexer, NULL);
239                       return 2;
240                     }
241                   lex_force_match (lexer, ')');
242                 }
243               else
244                 {
245                   lex_error (lexer, NULL);
246                   return 2;
247                 }
248             }
249         }
250       else
251         {
252           lex_error (lexer, NULL);
253           break;
254         }
255     }
256
257   run_roc (ds, &roc);
258
259   return 1;
260 }
261
262
263
264
265 static void
266 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
267
268
269 static int
270 run_roc (struct dataset *ds, struct cmd_roc *roc)
271 {
272   struct dictionary *dict = dataset_dict (ds);
273   bool ok;
274   struct casereader *group;
275
276   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
277   while (casegrouper_get_next_group (grouper, &group))
278     {
279       do_roc (roc, group, dataset_dict (ds));
280     }
281   ok = casegrouper_destroy (grouper);
282   ok = proc_commit (ds) && ok;
283
284   return ok;
285 }
286
287
288 static void
289 dump_casereader (struct casereader *reader)
290 {
291   struct ccase *c;
292   struct casereader *r = casereader_clone (reader);
293
294   for ( ; (c = casereader_read (r) ); case_unref (c))
295     {
296       int i;
297       for (i = 0 ; i < case_get_value_cnt (c); ++i)
298         {
299           printf ("%g ", case_data_idx (c, i)->f);
300         }
301       printf ("\n");
302     }
303
304   casereader_destroy (r);
305 }
306
307 static bool
308 match_positives (const struct ccase *c, void *aux)
309 {
310   struct cmd_roc *roc = aux;
311
312   return 0 == value_compare_3way (case_data (c, roc->state_var),
313                                  &roc->state_value,
314                                  var_get_width (roc->state_var));
315 }
316
317
318 #define VALUE  0
319 #define N_EQ   1
320 #define N_PRED 2
321
322 struct roc_state
323 {
324   double auc;
325
326   double n1;
327   double n2;
328
329   double q1hat;
330   double q2hat;
331 };
332
333
334 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
335
336
337
338 static struct casereader *
339 process_group (const struct variable *var, struct casereader *reader,
340                bool (*pred) (double, double),
341                const struct dictionary *dict,
342                double *cc)
343 {
344   const struct variable *w = dict_get_weight (dict);
345   const int weight_idx  = w ? var_get_case_index (w) :
346     caseproto_get_n_widths (casereader_get_proto (reader)) - 1;
347
348   struct casereader *r1 =
349     casereader_create_distinct (sort_execute_1var (reader, var), var, w);
350
351   struct ccase *c1;
352
353   struct casereader *rclone = casereader_clone (r1);
354   struct casewriter *wtr;
355   struct caseproto *proto = caseproto_create ();
356
357   proto = caseproto_add_width (proto, 0);
358   proto = caseproto_add_width (proto, 0);
359   proto = caseproto_add_width (proto, 0);
360
361   wtr = autopaging_writer_create (proto);  
362
363   *cc = 0;
364   
365   for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
366     {
367       struct ccase *c2;
368       struct casereader *r2 = casereader_clone (rclone);
369
370       const double weight1 = case_data_idx (c1, weight_idx)->f;
371       const double d1 = case_data (c1, var)->f;
372       double n_eq = 0.0;
373       double n_pred = 0.0;
374
375
376       struct ccase *new_case = case_create (proto);
377
378       *cc += weight1;
379
380       for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
381         {
382           const double d2 = case_data (c2, var)->f;
383           const double weight2 = case_data_idx (c2, weight_idx)->f;
384
385           if ( d1 == d2 )
386             {
387               n_eq += weight2;
388               continue;
389             }
390           else  if ( pred (d2, d1))
391             {
392               n_pred += weight2;
393             }
394         }
395
396       case_data_rw_idx (new_case, VALUE)->f = d1;
397       case_data_rw_idx (new_case, N_EQ)->f = n_eq;
398       case_data_rw_idx (new_case, N_PRED)->f = n_pred;
399
400       casewriter_write (wtr, new_case);
401
402       casereader_destroy (r2);
403     }
404
405   casereader_destroy (r1);
406   casereader_destroy (rclone);
407
408   return casewriter_make_reader (wtr);
409 }
410
411 static bool
412 gt (double d1, double d2)
413 {
414   return d1 > d2;
415 }
416
417 static bool
418 lt (double d1, double d2)
419 {
420   return d1 < d2;
421 }
422
423
424 static void
425 do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict)
426 {
427   int i;
428
429   struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
430
431   const struct caseproto *proto = casereader_get_proto (input);
432
433   struct casewriter *neg_wtr = autopaging_writer_create (proto);
434
435   struct casereader *negatives = NULL;
436
437   struct casereader *positives = 
438     casereader_create_filter_func (input,
439                                    match_positives,
440                                    NULL,
441                                    roc,
442                                    neg_wtr);
443
444
445   for (i = 0 ; i < roc->n_vars; ++i)
446     {
447       struct ccase *cpos;
448       struct casereader *n_neg ;
449       const struct variable *var = roc->vars[i];
450
451       struct casereader *neg ;
452       struct casereader *pos = casereader_clone (positives);
453
454       struct casereader *n_pos = process_group (var, pos, gt, dict, &rs[i].n1);
455
456       if ( negatives == NULL)
457         {
458           negatives = casewriter_make_reader (neg_wtr);
459         }
460     
461       neg = casereader_clone (negatives);
462
463       n_neg = process_group (var, neg, lt, dict, &rs[i].n2);
464
465       /* Simple join on VALUE */
466       for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
467         {
468           struct ccase *cneg = NULL;
469           double dneg = -DBL_MAX;
470           const double dpos = case_data_idx (cpos, VALUE)->f;
471           while (dneg < dpos)
472             {
473               if ( cneg )
474                 case_unref (cneg);
475
476               cneg = casereader_read (n_neg);
477               if ( ! cneg )
478                 break;
479               dneg = case_data_idx (cneg, VALUE)->f;
480             }
481         
482           if ( dpos == dneg )
483             {
484               double n_pos_eq = case_data_idx (cpos, N_EQ)->f;
485               double n_neg_eq = case_data_idx (cneg, N_EQ)->f;
486               double n_pos_gt = case_data_idx (cpos, N_PRED)->f;
487               double n_neg_lt = case_data_idx (cneg, N_PRED)->f;
488
489               rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
490               rs[i].q1hat +=
491                 n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
492               rs[i].q2hat +=
493                 n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
494             }
495
496           if ( cneg )
497             case_unref (cneg);
498         }
499
500       rs[i].auc /=  rs[i].n1 * rs[i].n2; 
501       if ( roc->invert ) 
502         rs[i].auc = 1 - rs[i].auc;
503
504       if ( roc->bi_neg_exp )
505         {
506           rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
507           rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
508         }
509       else
510         {
511           rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
512           rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
513         }
514     }
515
516   casereader_destroy (positives);
517   casereader_destroy (negatives);
518
519   output_roc (rs, roc);
520
521   free (rs);
522 }
523
524
525
526
527 static void
528 show_auc  (struct roc_state *rs, const struct cmd_roc *roc)
529 {
530   int i;
531   const int n_fields = roc->print_se ? 5 : 1;
532   const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
533   const int n_rows = 2 + roc->n_vars;
534   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
535
536   if ( roc->n_vars > 1)
537     tab_title (tbl, _("Area Under the Curve"));
538   else
539     tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
540
541   tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
542
543   tab_dim (tbl, tab_natural_dimensions, NULL);
544
545   tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
546
547   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
548
549   tab_box (tbl,
550            TAL_2, TAL_2,
551            -1, TAL_1,
552            0, 0,
553            n_cols - 1,
554            n_rows - 1);
555
556   if ( roc->print_se )
557     {
558       tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
559       tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
560
561       tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
562       tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
563
564       tab_joint_text (tbl, n_cols - 2, 0, 4, 0,
565                       TAT_TITLE | TAB_CENTER | TAT_PRINTF,
566                       _("Asymp. %g%% Confidence Interval"), roc->ci);
567       tab_vline (tbl, 0, n_cols - 1, 0, 0);
568       tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
569     }
570
571   if ( roc->n_vars > 1)
572     tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
573
574   if ( roc->n_vars > 1)
575     tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
576
577
578   for ( i = 0 ; i < roc->n_vars ; ++i )
579     {
580       tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
581
582       tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
583
584       if ( roc->print_se )
585         {
586
587           double se ;
588           const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
589                                       (12 * rs[i].n1 * rs[i].n2));
590           double ci ;
591           double yy ;
592
593           se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
594             (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
595
596           se /= rs[i].n1 * rs[i].n2;
597
598           se = sqrt (se);
599
600           tab_double (tbl, n_cols - 4, 2 + i, 0,
601                       se,
602                       NULL);
603
604           ci = 1 - roc->ci / 100.0;
605           yy = gsl_cdf_gaussian_Qinv (ci, se) ;
606
607           tab_double (tbl, n_cols - 2, 2 + i, 0,
608                       rs[i].auc - yy,
609                       NULL);
610
611           tab_double (tbl, n_cols - 1, 2 + i, 0,
612                       rs[i].auc + yy,
613                       NULL);
614
615           tab_double (tbl, n_cols - 3, 2 + i, 0,
616                       2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
617                       NULL);
618         }
619     }
620
621   tab_submit (tbl);
622 }
623
624
625 static void
626 show_summary (const struct cmd_roc *roc)
627 {
628   const int n_cols = 3;
629   const int n_rows = 4;
630   struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
631
632   tab_title (tbl, _("Case Summary"));
633
634   tab_headers (tbl, 1, 0, 2, 0);
635
636   tab_dim (tbl, tab_natural_dimensions, NULL);
637
638   tab_box (tbl,
639            TAL_2, TAL_2,
640            -1, -1,
641            0, 0,
642            n_cols - 1,
643            n_rows - 1);
644
645   tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
646   tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
647
648
649   tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
650   tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
651
652
653   tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
654   tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
655   tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
656
657   tab_joint_text (tbl, 1, 0, 2, 0,
658                   TAT_TITLE | TAB_CENTER,
659                   _("Valid N (listwise)"));
660
661
662   tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
663   tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
664
665
666 #if 0
667   tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
668   tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
669
670   tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
671   tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
672 #endif
673
674   tab_submit (tbl);
675 }
676
677
678 static void
679 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
680 {
681   show_summary (roc);
682
683 #if 0
684
685   if ( roc->curve )
686     draw_roc (rs, roc);
687 #endif
688
689   show_auc (rs, roc);
690
691 #if 0
692   if ( roc->print_coords )
693     show_coords (rs, roc);
694 #endif
695 }