1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2009 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <language/stats/roc.h>
21 #include <data/procedure.h>
22 #include <language/lexer/variable-parser.h>
23 #include <language/lexer/value-parser.h>
24 #include <language/command.h>
25 #include <language/lexer/lexer.h>
27 #include <data/casegrouper.h>
28 #include <data/casereader.h>
29 #include <data/casewriter.h>
30 #include <data/dictionary.h>
31 #include <data/format.h>
32 #include <math/sort.h>
33 #include <data/subcase.h>
36 #include <libpspp/misc.h>
38 #include <gsl/gsl_cdf.h>
39 #include <output/table.h>
41 #include <output/chart.h>
42 #include <output/charts/roc-chart.h>
/* Translate MSGID at run time by passing it to gettext. */
45 #define _(msgid) gettext (msgid)
/* Expands to MSGID unchanged; no gettext call is made. */
46 #define N_(msgid) msgid
/* NOTE(review): these look like the members of struct cmd_roc (they are
   accessed as roc.vars, roc.dict, ... in cmd_roc); the struct header is not
   visible in this view. */
/* Test variables, filled in by parse_variables_const in cmd_roc. */
51 const struct variable **vars;
/* Dictionary of the dataset the command runs on. */
52 const struct dictionary *dict;
/* Variable parsed after BY; compared against state_value in match_positives. */
54 const struct variable *state_var ;
/* Value of state_var that marks a case as positive. */
55 union value state_value;
57 /* Plot the roc curve */
59 /* Plot the reference line */
66 bool bi_neg_exp; /* True iff the bi-negative exponential criteria
68 enum mv_class exclude;
70 bool invert ; /* True iff a smaller test result variable indicates
/* Forward declaration: run_roc is defined after cmd_roc, which calls it. */
79 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
/* Parse the ROC command from LEXER and run it on DS.
   The syntax handled by the visible code is: the test variables, BY, the
   state variable, a parenthesised state value, and then any of the MISSING,
   PLOT, PRINT and CRITERIA subcommands up to the terminating '.'. */
82 cmd_roc (struct lexer *lexer, struct dataset *ds)
85 const struct dictionary *dict = dataset_dict (ds);
/* Defaults that stay in force when the matching subcommand is absent. */
90 roc.print_coords = false;
93 roc.reference = false;
95 roc.bi_neg_exp = false;
/* Case tallies start at zero; match_positives accumulates into them. */
97 roc.pos = roc.pos_weighted = 0;
98 roc.neg = roc.neg_weighted = 0;
99 roc.dict = dataset_dict (ds);
/* Test variables: numeric only, duplicates rejected. */
101 if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102 PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
105 if ( ! lex_force_match (lexer, T_BY))
110 roc.state_var = parse_variable (lexer, dict);
112 if ( !lex_force_match (lexer, '('))
/* The state value is initialised and parsed with the state variable's width. */
117 value_init (&roc.state_value, var_get_width (roc.state_var));
118 parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
121 if ( !lex_force_match (lexer, ')'))
/* Subcommand loop: runs until the command terminator is reached. */
127 while (lex_token (lexer) != '.')
/* The '/' separator in front of a subcommand is optional here. */
129 lex_match (lexer, '/');
130 if (lex_match_id (lexer, "MISSING"))
132 lex_match (lexer, '=');
133 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
/* INCLUDE selects MV_SYSTEM and EXCLUDE selects MV_ANY as the class of
   missing values that is later passed to mv_is_value_missing. */
135 if (lex_match_id (lexer, "INCLUDE"))
137 roc.exclude = MV_SYSTEM;
139 else if (lex_match_id (lexer, "EXCLUDE"))
141 roc.exclude = MV_ANY;
145 lex_error (lexer, NULL);
150 else if (lex_match_id (lexer, "PLOT"))
152 lex_match (lexer, '=');
153 if (lex_match_id (lexer, "CURVE"))
/* CURVE may be followed by "(REFERENCE)" to request the reference line. */
156 if (lex_match (lexer, '('))
158 roc.reference = true;
159 lex_force_match_id (lexer, "REFERENCE");
160 lex_force_match (lexer, ')');
163 else if (lex_match_id (lexer, "NONE"))
169 lex_error (lexer, NULL);
173 else if (lex_match_id (lexer, "PRINT"))
175 lex_match (lexer, '=');
176 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
178 if (lex_match_id (lexer, "SE"))
182 else if (lex_match_id (lexer, "COORDINATES"))
184 roc.print_coords = true;
188 lex_error (lexer, NULL);
193 else if (lex_match_id (lexer, "CRITERIA"))
195 lex_match (lexer, '=');
196 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
198 if (lex_match_id (lexer, "CUTOFF"))
200 lex_force_match (lexer, '(');
201 if (lex_match_id (lexer, "INCLUDE"))
203 roc.exclude = MV_SYSTEM;
205 else if (lex_match_id (lexer, "EXCLUDE"))
/* NOTE(review): MISSING=EXCLUDE above stores MV_ANY while this branch
   stores MV_USER | MV_SYSTEM; confirm the two are meant to be equivalent. */
207 roc.exclude = MV_USER | MV_SYSTEM;
211 lex_error (lexer, NULL);
214 lex_force_match (lexer, ')');
216 else if (lex_match_id (lexer, "TESTPOS"))
218 lex_force_match (lexer, '(');
219 if (lex_match_id (lexer, "LARGE"))
223 else if (lex_match_id (lexer, "SMALL"))
229 lex_error (lexer, NULL);
232 lex_force_match (lexer, ')');
234 else if (lex_match_id (lexer, "CI"))
236 lex_force_match (lexer, '(');
/* NOTE(review): the result of lex_force_num is not tested in the visible
   code before lex_number is read; confirm a non-numeric token is handled. */
237 lex_force_num (lexer);
238 roc.ci = lex_number (lexer);
240 lex_force_match (lexer, ')');
242 else if (lex_match_id (lexer, "DISTRIBUTION"))
244 lex_force_match (lexer, '(');
245 if (lex_match_id (lexer, "FREE"))
247 roc.bi_neg_exp = false;
249 else if (lex_match_id (lexer, "NEGEXPO"))
251 roc.bi_neg_exp = true;
255 lex_error (lexer, NULL);
258 lex_force_match (lexer, ')');
262 lex_error (lexer, NULL);
269 lex_error (lexer, NULL);
/* Parsing is complete: hand the collected settings to run_roc. */
274 if ( ! run_roc (ds, &roc))
/* state_value is released at two places, presumably one per exit path
   (the lines in between are not visible in this view). */
277 value_destroy (&roc.state_value, var_get_width (roc.state_var));
282 value_destroy (&roc.state_value, var_get_width (roc.state_var));
/* Forward declaration: do_roc is defined later in the file but is called
   from run_roc below. */
291 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
/* Run the analysis described by ROC on DS: the cases from proc_open are
   handed to casegrouper_create_splits and do_roc is called once for every
   group the grouper yields. */
295 run_roc (struct dataset *ds, struct cmd_roc *roc)
297 struct dictionary *dict = dataset_dict (ds);
299 struct casereader *group;
301 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
302 while (casegrouper_get_next_group (grouper, &group))
/* Same dataset_dict (ds) call that initialised DICT above. */
304 do_roc (roc, group, dataset_dict (ds));
/* OK ends up true only when both the grouper teardown and the procedure
   commit report success. */
306 ok = casegrouper_destroy (grouper);
307 ok = proc_commit (ds) && ok;
/* Print every value of every case read from a clone of READER with printf,
   treating each value as a number.  Presumably a debugging aid. */
314 dump_casereader (struct casereader *reader)
317 struct casereader *r = casereader_clone (reader);
/* The loop header unrefs each case after its values have been printed. */
319 for ( ; (c = casereader_read (r) ); case_unref (c))
322 for (i = 0 ; i < case_get_value_cnt (c); ++i)
324 printf ("%g ", case_data_idx (c, i)->f);
/* Only the clone is destroyed here. */
329 casereader_destroy (r);
335 Return true iff the state variable indicates that C has positive actual state.
337 As a side effect, this function also accumulates the roc->{pos,neg} and
338 roc->{pos,neg}_weighted counts.
341 match_positives (const struct ccase *c, void *aux)
/* AUX carries the struct cmd_roc of the running command. */
343 struct cmd_roc *roc = aux;
/* When the dictionary has no weight variable every case weighs 1.0. */
344 const struct variable *wv = dict_get_weight (roc->dict);
345 const double weight = wv ? case_data (c, wv)->f : 1.0;
/* C is positive when its state variable compares equal to the state value
   given on the command. */
347 const bool positive =
348 ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
349 var_get_width (roc->state_var)));
/* Weighted tallies.  Which of the two runs is presumably selected by
   POSITIVE; the branch lines are not visible in this view. */
354 roc->pos_weighted += weight;
359 roc->neg_weighted += weight;
370 /* Some intermediate state for calculating the cutpoints and the
371 standard error values */
/* NOTE(review): presumably the members of struct roc_state, one instance
   per test variable (do_roc allocates roc->n_vars of them); the struct
   header is not visible in this view. */
374 double auc; /* Area under the curve */
376 double n1; /* total weight of positives */
377 double n2; /* total weight of negatives */
379 /* intermediates for standard error */
383 /* intermediates for cutpoints */
/* Sorting writer that collects the cutpoint cases in prepare_cutpoints. */
384 struct casewriter *cutpoint_wtr;
/* Reader made from cutpoint_wtr once every cutpoint has been appended. */
385 struct casereader *cutpoint_rdr;
392 Return a new casereader based upon CUTPOINT_RDR.
393 The number of "positive" cases are placed into
394 the position TRUE_INDEX, and the number of "negative" cases
396 POS_COND and RESULT determine the semantics of what is
398 WEIGHT is the value of a single count.
400 static struct casereader *
401 accumulate_counts (struct casereader *cutpoint_rdr,
402 double result, double weight,
403 bool (*pos_cond) (double, double),
404 int true_index, int false_index)
/* Copies the cutpoint cases of CUTPOINT_RDR into a new reader.  In each
   copy WEIGHT is added to the value at TRUE_INDEX when
   POS_COND (RESULT, cutpoint) holds, and to the value at FALSE_INDEX
   otherwise (presumably; the else line is not visible in this view).
   The returned reader uses the same case prototype as CUTPOINT_RDR. */
406 const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
407 struct casewriter *w =
408 autopaging_writer_create (proto);
/* The cases are read from a clone of CUTPOINT_RDR. */
409 struct casereader *r = casereader_clone (cutpoint_rdr);
/* SYSMIS stands for "no previous cutpoint seen yet". */
411 double prev_cp = SYSMIS;
413 for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
415 struct ccase *new_case;
416 const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
418 assert (cp != SYSMIS);
420 /* We don't want duplicates here */
/* The counters are updated on a private copy of the cutpoint case. */
424 new_case = case_clone (cpc);
426 if ( pos_cond (result, cp))
427 case_data_rw_idx (new_case, true_index)->f += weight;
429 case_data_rw_idx (new_case, false_index)->f += weight;
433 casewriter_write (w, new_case);
435 casereader_destroy (r);
437 return casewriter_make_reader (w);
/* Forward declaration: output_roc is defined later in the file and is
   called at the end of do_roc. */
442 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
445 This function does 3 things:
447 1. Counts the number of cases which are equal to every other case in READER,
448 and those cases for which the relationship between it and every other case
449 satisfies PRED (normally either > or <). VAR is the variable defining a case's value
452 2. Counts the number of true and false cases in reader, and populates
453 CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices
454 which receive these values. POS_COND is the condition defining true
457 3. CC is filled with the cumulative weight of all cases of READER.
459 static struct casereader *
460 process_group (const struct variable *var, struct casereader *reader,
461 bool (*pred) (double, double),
462 const struct dictionary *dict,
464 struct casereader **cutpoint_rdr,
465 bool (*pos_cond) (double, double),
469 const struct variable *w = dict_get_weight (dict);
/* R1 reads READER sorted on VAR and passed through
   casereader_create_distinct (presumably one case per distinct VAR
   value, carrying its weight). */
471 struct casereader *r1 =
472 casereader_create_distinct (sort_execute_1var (reader, var), var, w);
/* The weight is read from the weight variable's case index when there is
   one, otherwise from the last value of R1's case prototype. */
474 const int weight_idx = w ? var_get_case_index (w) :
475 caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
/* RCLONE is the pristine copy from which the inner loop takes a fresh
   reader on every pass. */
479 struct casereader *rclone = casereader_clone (r1);
480 struct casewriter *wtr;
/* Output cases hold three numeric (width 0) values; they are addressed as
   VALUE, N_EQ and N_PRED below. */
481 struct caseproto *proto = caseproto_create ();
483 proto = caseproto_add_width (proto, 0);
484 proto = caseproto_add_width (proto, 0);
485 proto = caseproto_add_width (proto, 0);
487 wtr = autopaging_writer_create (proto);
/* Outer loop: one output case per case read from R1. */
491 for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
493 struct ccase *new_case = case_create (proto);
495 struct casereader *r2 = casereader_clone (rclone);
497 const double weight1 = case_data_idx (c1, weight_idx)->f;
498 const double d1 = case_data (c1, var)->f;
/* Fold this value and its weight into the cutpoint counters; the reader
   behind CUTPOINT_RDR is replaced by the one accumulate_counts returns. */
502 *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
504 true_index, false_index);
/* Inner loop: compare D1 against every case of the group again. */
508 for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
510 const double d2 = case_data (c2, var)->f;
511 const double weight2 = case_data_idx (c2, weight_idx)->f;
518 else if ( pred (d2, d1))
/* Emit (value, count equal to it, count satisfying PRED against it). */
524 case_data_rw_idx (new_case, VALUE)->f = d1;
525 case_data_rw_idx (new_case, N_EQ)->f = n_eq;
526 case_data_rw_idx (new_case, N_PRED)->f = n_pred;
528 casewriter_write (wtr, new_case);
530 casereader_destroy (r2);
533 casereader_destroy (r1);
534 casereader_destroy (rclone);
536 return casewriter_make_reader (wtr);
539 /* Some more indices into case data */
540 #define N_POS_EQ 1 /* number of positive cases with values equal to n */
541 #define N_POS_GT 2 /* number of positive cases with values greater than n */
542 #define N_NEG_EQ 3 /* number of negative cases with values equal to n */
543 #define N_NEG_LT 4 /* number of negative cases with values less than n */
546 gt (double d1, double d2)
553 ge (double d1, double d2)
559 lt (double d1, double d2)
566 Return a casereader with width 3,
567 populated with cases based upon READER.
568 The cases will have the values:
569 (N, number of cases equal to N, number of cases greater than N)
570 As a side effect, update RS->n1 with the number of positive cases.
572 static struct casereader *
573 process_positive_group (const struct variable *var, struct casereader *reader,
574 const struct dictionary *dict,
575 struct roc_state *rs)
/* Thin wrapper: calls process_group with gt as the comparison predicate
   and the address of rs->n1 as an output argument.  The remaining
   arguments of the call are not visible in this view. */
577 return process_group (var, reader, gt, dict, &rs->n1,
584 Return a casereader with width 3,
585 populated with cases based upon READER.
586 The cases will have the values:
587 (N, number of cases equal to N, number of cases less than N)
588 As a side effect, update RS->n2 with the number of negative cases.
590 static struct casereader *
591 process_negative_group (const struct variable *var, struct casereader *reader,
592 const struct dictionary *dict,
593 struct roc_state *rs)
/* Thin wrapper: calls process_group with lt as the comparison predicate
   and the address of rs->n2 as an output argument.  The remaining
   arguments of the call are not visible in this view. */
595 return process_group (var, reader, lt, dict, &rs->n2,
/* Write one new case to WRITER whose ROC_CUTPOINT value is CUTPOINT and
   whose four counters (ROC_TP, ROC_FN, ROC_TN, ROC_FP) are zero. */
605 append_cutpoint (struct casewriter *writer, double cutpoint)
/* The case is created with WRITER's own prototype. */
607 struct ccase *cc = case_create (casewriter_get_proto (writer));
609 case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
610 case_data_rw_idx (cc, ROC_TP)->f = 0;
611 case_data_rw_idx (cc, ROC_FN)->f = 0;
612 case_data_rw_idx (cc, ROC_TN)->f = 0;
613 case_data_rw_idx (cc, ROC_FP)->f = 0;
615 casewriter_write (writer, cc);
620 Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will
621 be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
622 reader will be populated with its final number of cases.
623 However on exit from this function, only ROC_CUTPOINT entries will be set to their final
624 value. The other entries will be initialised to zero.
627 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
/* INPUT is read through a clone, so the caller's reader is not advanced by
   this function. */
630 struct casereader *r = casereader_clone (input);
632 struct caseproto *proto = caseproto_create ();
/* The cutpoint writers sort their cases on ROC_CUTPOINT, ascending. */
634 struct subcase ordering;
635 subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
637 proto = caseproto_add_width (proto, 0); /* cutpoint */
638 proto = caseproto_add_width (proto, 0); /* ROC_TP */
639 proto = caseproto_add_width (proto, 0); /* ROC_FN */
640 proto = caseproto_add_width (proto, 0); /* ROC_TN */
641 proto = caseproto_add_width (proto, 0); /* ROC_FP */
/* Per-variable setup.  prev_result == SYSMIS means "no value seen yet";
   max starts at -DBL_MAX so that any value raises it (the matching
   initialisation of min is not visible in this view). */
643 for (i = 0 ; i < roc->n_vars; ++i)
645 rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
646 rs[i].prev_result = SYSMIS;
647 rs[i].max = -DBL_MAX;
/* Single pass over the cases, handling every test variable for each case. */
651 for (; (c = casereader_read (r)) != NULL; case_unref (c))
653 for (i = 0 ; i < roc->n_vars; ++i)
655 const union value *v = case_data (c, roc->vars[i]);
656 const double result = v->f;
/* Values that are missing according to roc->exclude are tested for here;
   the statement controlled by this test is not visible in this view. */
658 if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
661 minimize (&rs[i].min, result);
662 maximize (&rs[i].max, result);
/* Whenever the value differs from the previously read one, the midpoint
   of the two becomes a cutpoint. */
664 if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
666 const double mean = (result + rs[i].prev_result ) / 2.0;
667 append_cutpoint (rs[i].cutpoint_wtr, mean);
670 rs[i].prev_result = result;
673 casereader_destroy (r);
676 /* Append the min and max cutpoints */
677 for (i = 0 ; i < roc->n_vars; ++i)
/* One cutpoint below the smallest and one above the largest value seen. */
679 append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
680 append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
/* From here on the cutpoints are available through cutpoint_rdr. */
682 rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
/* Process one group of cases (READER): for every test variable in ROC,
   compute the area under the curve and the q1hat/q2hat intermediates into
   a per-variable struct roc_state, then pass the results to output_roc. */
687 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
/* One zero-initialised state record per test variable. */
691 struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
693 struct casereader *negatives = NULL;
694 struct casereader *positives = NULL;
696 struct caseproto *n_proto = caseproto_create ();
698 struct subcase up_ordering;
699 struct subcase down_ordering;
701 struct casewriter *neg_wtr = NULL;
/* Cases with missing values in the test variables are filtered out first;
   a second missing-value filter is then stacked on the result (its
   variable arguments are not visible in this view). */
703 struct casereader *input = casereader_create_filter_missing (reader,
704 roc->vars, roc->n_vars,
709 input = casereader_create_filter_missing (input,
/* NEG_WTR uses the same case prototype as INPUT; it is turned into the
   NEGATIVES reader further down. */
715 neg_wtr = autopaging_writer_create (casereader_get_proto (input));
717 prepare_cutpoints (roc, rs, input);
720 /* Separate the positive actual state cases from the negative ones */
722 casereader_create_filter_func (input,
/* NOTE(review): N_PROTO already received a caseproto_create () result at
   its declaration; this assignment overwrites that pointer, so the first
   prototype looks leaked.  Confirm and drop one of the two calls. */
728 n_proto = caseproto_create ();
/* Five numeric (width 0) values per case, addressed below as VALUE,
   N_POS_EQ, N_POS_GT, N_NEG_EQ and N_NEG_LT. */
730 n_proto = caseproto_add_width (n_proto, 0);
731 n_proto = caseproto_add_width (n_proto, 0);
732 n_proto = caseproto_add_width (n_proto, 0);
733 n_proto = caseproto_add_width (n_proto, 0);
734 n_proto = caseproto_add_width (n_proto, 0);
/* Sort orders on VALUE used by the three passes below. */
736 subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
737 subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
739 for (i = 0 ; i < roc->n_vars; ++i)
741 struct casewriter *w = NULL;
742 struct casereader *r = NULL;
747 struct casereader *n_neg ;
748 const struct variable *var = roc->vars[i];
750 struct casereader *neg ;
/* Each variable works on its own clone of the positive cases. */
751 struct casereader *pos = casereader_clone (positives);
754 struct casereader *n_pos =
755 process_positive_group (var, pos, dict, &rs[i]);
/* NEGATIVES is built from NEG_WTR only once, on the first variable. */
757 if ( negatives == NULL)
759 negatives = casewriter_make_reader (neg_wtr);
762 neg = casereader_clone (negatives);
764 n_neg = process_negative_group (var, neg, dict, &rs[i]);
767 /* Merge the n_pos and n_neg casereaders */
768 w = sort_create_writer (&up_ordering, n_proto);
769 for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
771 struct ccase *pos_case = case_create (n_proto);
773 const double jpos = case_data_idx (cpos, VALUE)->f;
775 while ((cneg = casereader_read (n_neg)))
777 struct ccase *nc = case_create (n_proto);
779 const double jneg = case_data_idx (cneg, VALUE)->f;
/* A row that comes from the negative reader has no positive counts yet:
   N_POS_EQ is 0 and N_POS_GT is SYSMIS until it is filled in by the
   propagation pass below. */
781 case_data_rw_idx (nc, VALUE)->f = jneg;
782 case_data_rw_idx (nc, N_POS_EQ)->f = 0;
784 case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
786 *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
787 *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
789 casewriter_write (w, nc);
/* Mirror image for a row that comes from the positive reader: N_NEG_EQ is
   0 and N_NEG_LT is SYSMIS for now. */
796 case_data_rw_idx (pos_case, VALUE)->f = jpos;
797 *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
798 *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
799 case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
800 case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
802 casewriter_write (w, pos_case);
805 /* These aren't used anymore */
809 r = casewriter_make_reader (w);
811 /* Propagate the N_POS_GT values from the positive cases
812 to the negative ones */
/* Every SYSMIS N_POS_GT takes the value carried over from the row read
   just before it; the carry starts at n1.  The rewritten rows go to a
   writer sorted by VALUE in descending order. */
814 double prev_pos_gt = rs[i].n1;
815 w = sort_create_writer (&down_ordering, n_proto);
817 for ( ; (c = casereader_read (r) ); case_unref (c))
819 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
820 struct ccase *nc = case_clone (c);
822 if ( n_pos_gt == SYSMIS)
824 n_pos_gt = prev_pos_gt;
825 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
828 casewriter_write (w, nc);
829 prev_pos_gt = n_pos_gt;
832 r = casewriter_make_reader (w);
835 /* Propagate the N_NEG_LT values from the negative cases
836 to the positive ones */
/* Same scheme for N_NEG_LT: the carry starts at n2 and the rewritten rows
   go to a writer sorted by VALUE in ascending order. */
838 double prev_neg_lt = rs[i].n2;
839 w = sort_create_writer (&up_ordering, n_proto);
841 for ( ; (c = casereader_read (r) ); case_unref (c))
843 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
844 struct ccase *nc = case_clone (c);
846 if ( n_neg_lt == SYSMIS)
848 n_neg_lt = prev_neg_lt;
849 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
852 casewriter_write (w, nc);
853 prev_neg_lt = n_neg_lt;
856 r = casewriter_make_reader (w);
/* Final pass: accumulate the sums.  PREV_CASE remembers the row before
   the current one so that two rows with the same VALUE can be combined. */
860 struct ccase *prev_case = NULL;
861 for ( ; (c = casereader_read (r) ); case_unref (c))
/* Look ahead one row without consuming it. */
863 const struct ccase *next_case = casereader_peek (r, 0);
865 const double j = case_data_idx (c, VALUE)->f;
866 double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
867 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
868 double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
869 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
/* When the previous row has the same VALUE, whichever half of the counts
   is zero in this row is taken from the previous row instead. */
871 if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
873 if ( 0 == case_data_idx (c, N_POS_EQ)->f)
875 n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
876 n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
879 if ( 0 == case_data_idx (c, N_NEG_EQ)->f)
881 n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
882 n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
/* The sums are only updated on the last row of a run of equal VALUEs.
   The left-hand sides of the two expressions after the auc update are not
   visible in this view (presumably q1hat and q2hat). */
886 if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
888 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
891 n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
893 n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
897 case_unref (prev_case);
898 prev_case = case_clone (c);
/* Scale the accumulated sum by the product of the two total weights. */
901 rs[i].auc /= rs[i].n1 * rs[i].n2;
/* NOTE(review): the condition guarding this complement is not visible in
   this view; presumably it applies only when roc->invert is set. */
903 rs[i].auc = 1 - rs[i].auc;
/* With the bi-negative exponential option q1hat and q2hat are derived from
   the area alone; the two divisions further down presumably belong to the
   else branch and scale the accumulated sums instead. */
905 if ( roc->bi_neg_exp )
907 rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
908 rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
912 rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
913 rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
918 casereader_destroy (positives);
919 casereader_destroy (negatives);
921 output_roc (rs, roc);
/* Build the "Area Under the Curve" table: one row per test variable with
   its area and, when roc->print_se is set, four more columns (standard
   error, asymptotic significance, lower and upper confidence bound). */
927 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
/* Five data columns with standard errors, otherwise only the area. */
930 const int n_fields = roc->print_se ? 5 : 1;
/* With several test variables one extra leading column names them. */
931 const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
/* Two heading rows followed by one row per test variable. */
932 const int n_rows = 2 + roc->n_vars;
933 struct tab_table *tbl = tab_create (n_cols, n_rows);
/* The variable name goes into the title only when it is the single test
   variable (presumably; the else line is not visible in this view). */
935 if ( roc->n_vars > 1)
936 tab_title (tbl, _("Area Under the Curve"));
938 tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
940 tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
942 tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
944 tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
946 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
/* Headings for the four extra columns, counted back from the right edge. */
957 tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
958 tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
960 tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
961 tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
/* NOTE(review): one coordinate here is the literal 4 while the others are
   derived from n_cols; n_cols is 6 when there are several test variables.
   Confirm against the tab_joint_text_format parameter order. */
963 tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
964 TAT_TITLE | TAB_CENTER,
965 _("Asymp. %g%% Confidence Interval"), roc->ci);
966 tab_vline (tbl, 0, n_cols - 1, 0, 0);
967 tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
970 if ( roc->n_vars > 1)
971 tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
973 if ( roc->n_vars > 1)
974 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
/* One table row per test variable, starting below the two heading rows. */
977 for ( i = 0 ; i < roc->n_vars ; ++i )
979 tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
981 tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
/* sd_0_5 is the divisor used for the significance column further down,
   where the area is compared against 0.5. */
986 const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
987 (12 * rs[i].n1 * rs[i].n2));
/* SE is assembled from the area and the q1hat/q2hat intermediates and then
   divided by n1 * n2; any later step applied to it is not visible in this
   view. */
991 se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
992 (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
994 se /= rs[i].n1 * rs[i].n2;
998 tab_double (tbl, n_cols - 4, 2 + i, 0,
/* Turn the confidence level in percent into a tail probability and look up
   the matching Gaussian quantile with SE as the sigma argument. */
1002 ci = 1 - roc->ci / 100.0;
1003 yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1005 tab_double (tbl, n_cols - 2, 2 + i, 0,
1009 tab_double (tbl, n_cols - 1, 2 + i, 0,
/* Significance: twice the upper tail of the unit Gaussian at
   |auc - 0.5| / sd_0_5. */
1013 tab_double (tbl, n_cols - 3, 2 + i, 0,
1014 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
/* Build the "Case Summary" table: the unweighted and weighted numbers of
   positive and negative cases taken from ROC. */
1024 show_summary (const struct cmd_roc *roc)
/* Fixed layout: a label column plus two count columns, and two heading
   rows plus the Positive and Negative rows. */
1026 const int n_cols = 3;
1027 const int n_rows = 4;
1028 struct tab_table *tbl = tab_create (n_cols, n_rows);
1030 tab_title (tbl, _("Case Summary"));
1032 tab_headers (tbl, 1, 0, 2, 0);
1034 tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1043 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1044 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1047 tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1048 tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
/* Heading cells: the state variable's name and the two count headings. */
1051 tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1052 tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1053 tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1055 tab_joint_text (tbl, 1, 0, 2, 0,
1056 TAT_TITLE | TAB_CENTER,
1057 _("Valid N (listwise)"));
1060 tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1061 tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
/* The unweighted counts are printed with the F_8_0 format; the weighted
   ones pass 0 as the format argument. */
1064 tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1065 tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1067 tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1068 tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
/* Build the "Coordinates of the Curve" table: one row per cutpoint of every
   test variable, showing the cutpoint and two values derived from the
   ROC_TP/ROC_FN and ROC_TN/ROC_FP counters. */
1075 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
/* With several test variables one extra leading column names them. */
1079 const int n_cols = roc->n_vars > 1 ? 4 : 3;
1081 struct tab_table *tbl ;
/* The row count grows by the number of cutpoint cases of each variable
   (its starting value is not visible in this view). */
1083 for (i = 0; i < roc->n_vars; ++i)
1084 n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1086 tbl = tab_create (n_cols, n_rows);
1088 if ( roc->n_vars > 1)
1089 tab_title (tbl, _("Coordinates of the Curve"));
1091 tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1094 tab_headers (tbl, 1, 0, 1, 0);
1096 tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1098 tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1100 if ( roc->n_vars > 1)
1101 tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
/* The three data headings are placed counting back from the right edge. */
1103 tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1104 tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1105 tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1114 if ( roc->n_vars > 1)
1115 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1117 for (i = 0; i < roc->n_vars; ++i)
/* Read the cutpoints through a clone of the variable's cutpoint reader. */
1120 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1122 if ( roc->n_vars > 1)
1123 tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1126 tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
/* X is the current table row; it advances once per cutpoint case. */
1129 for (; (cc = casereader_read (r)) != NULL;
1130 case_unref (cc), x++)
/* SE is ROC_TP divided by an expression built from ROC_TP and ROC_FN
   (presumably their sum; the operator lines are not visible here). */
1132 const double se = case_data_idx (cc, ROC_TP)->f /
1134 case_data_idx (cc, ROC_TP)->f
1136 case_data_idx (cc, ROC_FN)->f
/* SP is formed the same way from ROC_TN and ROC_FP. */
1139 const double sp = case_data_idx (cc, ROC_TN)->f /
1141 case_data_idx (cc, ROC_TN)->f
1143 case_data_idx (cc, ROC_FP)->f
/* The cutpoint is shown in the test variable's own print format. */
1146 tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1147 var_get_print_format (roc->vars[i]));
1149 tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1150 tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1153 casereader_destroy (r);
/* Produce the output for one processed group: the ROC chart and, when
   requested, the coordinates table.  Other output calls may sit on lines
   that are not visible in this view. */
1161 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1167 struct roc_chart *rc;
/* The chart is created with or without the reference line, then receives
   one entry per test variable built from that variable's cutpoint reader. */
1170 rc = roc_chart_create (roc->reference);
1171 for (i = 0; i < roc->n_vars; i++)
1172 roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1173 rs[i].cutpoint_rdr);
1174 chart_submit (roc_chart_get_chart (rc));
/* The coordinates table is only produced for PRINT=COORDINATES. */
1179 if ( roc->print_coords )
1180 show_coords (rs, roc);