1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2009 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <language/stats/roc.h>
21 #include <data/procedure.h>
22 #include <language/lexer/variable-parser.h>
23 #include <language/lexer/value-parser.h>
24 #include <language/command.h>
25 #include <language/lexer/lexer.h>
27 #include <data/casegrouper.h>
28 #include <data/casereader.h>
29 #include <data/casewriter.h>
30 #include <data/dictionary.h>
31 #include <data/format.h>
32 #include <math/sort.h>
33 #include <data/subcase.h>
36 #include <libpspp/misc.h>
38 #include <gsl/gsl_cdf.h>
39 #include <output/table.h>
41 #include <output/chart.h>
42 #include <output/charts/roc-chart.h>
45 #define _(msgid) gettext (msgid)
46 #define N_(msgid) msgid
51 const struct variable **vars;
52 const struct dictionary *dict;
54 const struct variable *state_var ;
55 union value state_value;
57 /* Plot the roc curve */
59 /* Plot the reference line */
66 bool bi_neg_exp; /* True iff the bi-negative exponential critieria
68 enum mv_class exclude;
70 bool invert ; /* True iff a smaller test result variable indicates
79 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
82 cmd_roc (struct lexer *lexer, struct dataset *ds)
85 const struct dictionary *dict = dataset_dict (ds);
90 roc.print_coords = false;
93 roc.reference = false;
95 roc.bi_neg_exp = false;
97 roc.pos = roc.pos_weighted = 0;
98 roc.neg = roc.neg_weighted = 0;
99 roc.dict = dataset_dict (ds);
101 if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
102 PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
105 if ( ! lex_force_match (lexer, T_BY))
110 roc.state_var = parse_variable (lexer, dict);
112 if ( !lex_force_match (lexer, '('))
117 parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
120 if ( !lex_force_match (lexer, ')'))
126 while (lex_token (lexer) != '.')
128 lex_match (lexer, '/');
129 if (lex_match_id (lexer, "MISSING"))
131 lex_match (lexer, '=');
132 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
134 if (lex_match_id (lexer, "INCLUDE"))
136 roc.exclude = MV_SYSTEM;
138 else if (lex_match_id (lexer, "EXCLUDE"))
140 roc.exclude = MV_ANY;
144 lex_error (lexer, NULL);
149 else if (lex_match_id (lexer, "PLOT"))
151 lex_match (lexer, '=');
152 if (lex_match_id (lexer, "CURVE"))
155 if (lex_match (lexer, '('))
157 roc.reference = true;
158 lex_force_match_id (lexer, "REFERENCE");
159 lex_force_match (lexer, ')');
162 else if (lex_match_id (lexer, "NONE"))
168 lex_error (lexer, NULL);
172 else if (lex_match_id (lexer, "PRINT"))
174 lex_match (lexer, '=');
175 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
177 if (lex_match_id (lexer, "SE"))
181 else if (lex_match_id (lexer, "COORDINATES"))
183 roc.print_coords = true;
187 lex_error (lexer, NULL);
192 else if (lex_match_id (lexer, "CRITERIA"))
194 lex_match (lexer, '=');
195 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
197 if (lex_match_id (lexer, "CUTOFF"))
199 lex_force_match (lexer, '(');
200 if (lex_match_id (lexer, "INCLUDE"))
202 roc.exclude = MV_SYSTEM;
204 else if (lex_match_id (lexer, "EXCLUDE"))
206 roc.exclude = MV_USER | MV_SYSTEM;
210 lex_error (lexer, NULL);
213 lex_force_match (lexer, ')');
215 else if (lex_match_id (lexer, "TESTPOS"))
217 lex_force_match (lexer, '(');
218 if (lex_match_id (lexer, "LARGE"))
222 else if (lex_match_id (lexer, "SMALL"))
228 lex_error (lexer, NULL);
231 lex_force_match (lexer, ')');
233 else if (lex_match_id (lexer, "CI"))
235 lex_force_match (lexer, '(');
236 lex_force_num (lexer);
237 roc.ci = lex_number (lexer);
239 lex_force_match (lexer, ')');
241 else if (lex_match_id (lexer, "DISTRIBUTION"))
243 lex_force_match (lexer, '(');
244 if (lex_match_id (lexer, "FREE"))
246 roc.bi_neg_exp = false;
248 else if (lex_match_id (lexer, "NEGEXPO"))
250 roc.bi_neg_exp = true;
254 lex_error (lexer, NULL);
257 lex_force_match (lexer, ')');
261 lex_error (lexer, NULL);
268 lex_error (lexer, NULL);
273 if ( ! run_roc (ds, &roc))
288 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
292 run_roc (struct dataset *ds, struct cmd_roc *roc)
294 struct dictionary *dict = dataset_dict (ds);
296 struct casereader *group;
298 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
299 while (casegrouper_get_next_group (grouper, &group))
301 do_roc (roc, group, dataset_dict (ds));
303 ok = casegrouper_destroy (grouper);
304 ok = proc_commit (ds) && ok;
311 dump_casereader (struct casereader *reader)
314 struct casereader *r = casereader_clone (reader);
316 for ( ; (c = casereader_read (r) ); case_unref (c))
319 for (i = 0 ; i < case_get_value_cnt (c); ++i)
321 printf ("%g ", case_data_idx (c, i)->f);
326 casereader_destroy (r);
332 Return true iff the state variable indicates that C has positive actual state.
334 As a side effect, this function also accumulates the roc->{pos,neg} and
335 roc->{pos,neg}_weighted counts.
338 match_positives (const struct ccase *c, void *aux)
340 struct cmd_roc *roc = aux;
341 const struct variable *wv = dict_get_weight (roc->dict);
342 const double weight = wv ? case_data (c, wv)->f : 1.0;
344 const bool positive =
345 ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
346 var_get_width (roc->state_var)));
351 roc->pos_weighted += weight;
356 roc->neg_weighted += weight;
367 /* Some intermediate state for calculating the cutpoints and the
368 standard error values */
371 double auc; /* Area under the curve */
373 double n1; /* total weight of positives */
374 double n2; /* total weight of negatives */
376 /* intermediates for standard error */
380 /* intermediates for cutpoints */
381 struct casewriter *cutpoint_wtr;
382 struct casereader *cutpoint_rdr;
389 Return a new casereader based upon CUTPOINT_RDR.
390 The number of "positive" cases is placed into
391 the position TRUE_INDEX, and the number of "negative" cases
393 POS_COND and RESULT determine the semantics of what is
395 WEIGHT is the value of a single count.
397 static struct casereader *
398 accumulate_counts (struct casereader *cutpoint_rdr,
399 double result, double weight,
400 bool (*pos_cond) (double, double),
401 int true_index, int false_index)
403 const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
404 struct casewriter *w =
405 autopaging_writer_create (proto);
406 struct casereader *r = casereader_clone (cutpoint_rdr);
408 double prev_cp = SYSMIS;
410 for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
412 struct ccase *new_case;
413 const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
415 assert (cp != SYSMIS);
417 /* We don't want duplicates here */
421 new_case = case_clone (cpc);
423 if ( pos_cond (result, cp))
424 case_data_rw_idx (new_case, true_index)->f += weight;
426 case_data_rw_idx (new_case, false_index)->f += weight;
430 casewriter_write (w, new_case);
432 casereader_destroy (r);
434 return casewriter_make_reader (w);
439 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
442 This function does 3 things:
444 1. Counts the number of cases which are equal to every other case in READER,
445 and those cases for which the relationship between it and every other case
446 satisfies PRED (normally either > or <). VAR is the variable defining a case's value
449 2. Counts the number of true and false cases in reader, and populates
450 CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices
451 which receive these values. POS_COND is the condition defining true
454 3. CC is filled with the cumulative weight of all cases of READER.
456 static struct casereader *
457 process_group (const struct variable *var, struct casereader *reader,
458 bool (*pred) (double, double),
459 const struct dictionary *dict,
461 struct casereader **cutpoint_rdr,
462 bool (*pos_cond) (double, double),
466 const struct variable *w = dict_get_weight (dict);
468 struct casereader *r1 =
469 casereader_create_distinct (sort_execute_1var (reader, var), var, w);
471 const int weight_idx = w ? var_get_case_index (w) :
472 caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
476 struct casereader *rclone = casereader_clone (r1);
477 struct casewriter *wtr;
478 struct caseproto *proto = caseproto_create ();
480 proto = caseproto_add_width (proto, 0);
481 proto = caseproto_add_width (proto, 0);
482 proto = caseproto_add_width (proto, 0);
484 wtr = autopaging_writer_create (proto);
488 for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
490 struct ccase *new_case = case_create (proto);
492 struct casereader *r2 = casereader_clone (rclone);
494 const double weight1 = case_data_idx (c1, weight_idx)->f;
495 const double d1 = case_data (c1, var)->f;
499 *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
501 true_index, false_index);
505 for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
507 const double d2 = case_data (c2, var)->f;
508 const double weight2 = case_data_idx (c2, weight_idx)->f;
515 else if ( pred (d2, d1))
521 case_data_rw_idx (new_case, VALUE)->f = d1;
522 case_data_rw_idx (new_case, N_EQ)->f = n_eq;
523 case_data_rw_idx (new_case, N_PRED)->f = n_pred;
525 casewriter_write (wtr, new_case);
527 casereader_destroy (r2);
530 casereader_destroy (r1);
531 casereader_destroy (rclone);
533 return casewriter_make_reader (wtr);
536 /* More indices into case data, for the 5-column merged cases (n_proto) built in do_roc; "n" below denotes the value stored at index VALUE */
537 #define N_POS_EQ 1 /* number of positive cases with values equal to n */
538 #define N_POS_GT 2 /* number of positive cases with values greater than n */
539 #define N_NEG_EQ 3 /* number of negative cases with values equal to n */
540 #define N_NEG_LT 4 /* number of negative cases with values less than n */
543 gt (double d1, double d2)
550 ge (double d1, double d2)
556 lt (double d1, double d2)
563 Return a casereader with width 3,
564 populated with cases based upon READER.
565 The cases will have the values:
566 (N, number of cases equal to N, number of cases greater than N)
567 As a side effect, update RS->n1 with the total weight of the positive cases.
569 static struct casereader *
570 process_positive_group (const struct variable *var, struct casereader *reader,
571 const struct dictionary *dict,
572 struct roc_state *rs)
574 return process_group (var, reader, gt, dict, &rs->n1,
581 Return a casereader with width 3,
582 populated with cases based upon READER.
583 The cases will have the values:
584 (N, number of cases equal to N, number of cases less than N)
585 As a side effect, update RS->n2 with the total weight of the negative cases.
587 static struct casereader *
588 process_negative_group (const struct variable *var, struct casereader *reader,
589 const struct dictionary *dict,
590 struct roc_state *rs)
592 return process_group (var, reader, lt, dict, &rs->n2,
602 append_cutpoint (struct casewriter *writer, double cutpoint)
604 struct ccase *cc = case_create (casewriter_get_proto (writer));
606 case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
607 case_data_rw_idx (cc, ROC_TP)->f = 0;
608 case_data_rw_idx (cc, ROC_FN)->f = 0;
609 case_data_rw_idx (cc, ROC_TN)->f = 0;
610 case_data_rw_idx (cc, ROC_FP)->f = 0;
612 casewriter_write (writer, cc);
617 Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will
618 be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
619 reader will be populated with its final number of cases.
620 However on exit from this function, only ROC_CUTPOINT entries will be set to their final
621 value. The other entries will be initialised to zero.
624 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
627 struct casereader *r = casereader_clone (input);
629 struct caseproto *proto = caseproto_create ();
631 struct subcase ordering;
632 subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
634 proto = caseproto_add_width (proto, 0); /* cutpoint */
635 proto = caseproto_add_width (proto, 0); /* ROC_TP */
636 proto = caseproto_add_width (proto, 0); /* ROC_FN */
637 proto = caseproto_add_width (proto, 0); /* ROC_TN */
638 proto = caseproto_add_width (proto, 0); /* ROC_FP */
640 for (i = 0 ; i < roc->n_vars; ++i)
642 rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
643 rs[i].prev_result = SYSMIS;
644 rs[i].max = -DBL_MAX;
648 for (; (c = casereader_read (r)) != NULL; case_unref (c))
650 for (i = 0 ; i < roc->n_vars; ++i)
652 const union value *v = case_data (c, roc->vars[i]);
653 const double result = v->f;
655 if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
658 minimize (&rs[i].min, result);
659 maximize (&rs[i].max, result);
661 if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
663 const double mean = (result + rs[i].prev_result ) / 2.0;
664 append_cutpoint (rs[i].cutpoint_wtr, mean);
667 rs[i].prev_result = result;
670 casereader_destroy (r);
673 /* Append the min and max cutpoints */
674 for (i = 0 ; i < roc->n_vars; ++i)
676 append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
677 append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
679 rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
684 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
688 struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
690 struct casereader *negatives = NULL;
691 struct casereader *positives = NULL;
693 struct caseproto *n_proto = caseproto_create ();
695 struct subcase up_ordering;
696 struct subcase down_ordering;
698 struct casewriter *neg_wtr = NULL;
700 struct casereader *input = casereader_create_filter_missing (reader,
701 roc->vars, roc->n_vars,
706 input = casereader_create_filter_missing (input,
712 neg_wtr = autopaging_writer_create (casereader_get_proto (input));
714 prepare_cutpoints (roc, rs, input);
717 /* Separate the positive actual state cases from the negative ones */
719 casereader_create_filter_func (input,
725 n_proto = caseproto_create ();
727 n_proto = caseproto_add_width (n_proto, 0);
728 n_proto = caseproto_add_width (n_proto, 0);
729 n_proto = caseproto_add_width (n_proto, 0);
730 n_proto = caseproto_add_width (n_proto, 0);
731 n_proto = caseproto_add_width (n_proto, 0);
733 subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
734 subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
736 for (i = 0 ; i < roc->n_vars; ++i)
738 struct casewriter *w = NULL;
739 struct casereader *r = NULL;
744 struct casereader *n_neg ;
745 const struct variable *var = roc->vars[i];
747 struct casereader *neg ;
748 struct casereader *pos = casereader_clone (positives);
751 struct casereader *n_pos =
752 process_positive_group (var, pos, dict, &rs[i]);
754 if ( negatives == NULL)
756 negatives = casewriter_make_reader (neg_wtr);
759 neg = casereader_clone (negatives);
761 n_neg = process_negative_group (var, neg, dict, &rs[i]);
764 /* Merge the n_pos and n_neg casereaders */
765 w = sort_create_writer (&up_ordering, n_proto);
766 for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
768 struct ccase *pos_case = case_create (n_proto);
770 const double jpos = case_data_idx (cpos, VALUE)->f;
772 while ((cneg = casereader_read (n_neg)))
774 struct ccase *nc = case_create (n_proto);
776 const double jneg = case_data_idx (cneg, VALUE)->f;
778 case_data_rw_idx (nc, VALUE)->f = jneg;
779 case_data_rw_idx (nc, N_POS_EQ)->f = 0;
781 case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
783 *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
784 *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
786 casewriter_write (w, nc);
793 case_data_rw_idx (pos_case, VALUE)->f = jpos;
794 *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
795 *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
796 case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
797 case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
799 casewriter_write (w, pos_case);
802 /* These aren't used anymore */
806 r = casewriter_make_reader (w);
808 /* Propagate the N_POS_GT values from the positive cases
809 to the negative ones */
811 double prev_pos_gt = rs[i].n1;
812 w = sort_create_writer (&down_ordering, n_proto);
814 for ( ; (c = casereader_read (r) ); case_unref (c))
816 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
817 struct ccase *nc = case_clone (c);
819 if ( n_pos_gt == SYSMIS)
821 n_pos_gt = prev_pos_gt;
822 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
825 casewriter_write (w, nc);
826 prev_pos_gt = n_pos_gt;
829 r = casewriter_make_reader (w);
832 /* Propagate the N_NEG_LT values from the negative cases
833 to the positive ones */
835 double prev_neg_lt = rs[i].n2;
836 w = sort_create_writer (&up_ordering, n_proto);
838 for ( ; (c = casereader_read (r) ); case_unref (c))
840 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
841 struct ccase *nc = case_clone (c);
843 if ( n_neg_lt == SYSMIS)
845 n_neg_lt = prev_neg_lt;
846 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
849 casewriter_write (w, nc);
850 prev_neg_lt = n_neg_lt;
853 r = casewriter_make_reader (w);
857 struct ccase *prev_case = NULL;
858 for ( ; (c = casereader_read (r) ); case_unref (c))
860 const struct ccase *next_case = casereader_peek (r, 0);
862 const double j = case_data_idx (c, VALUE)->f;
863 double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
864 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
865 double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
866 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
868 if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
870 if ( 0 == case_data_idx (c, N_POS_EQ)->f)
872 n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
873 n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
876 if ( 0 == case_data_idx (c, N_NEG_EQ)->f)
878 n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
879 n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
883 if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
885 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
888 n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
890 n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
894 case_unref (prev_case);
895 prev_case = case_clone (c);
898 rs[i].auc /= rs[i].n1 * rs[i].n2;
900 rs[i].auc = 1 - rs[i].auc;
902 if ( roc->bi_neg_exp )
904 rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
905 rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
909 rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
910 rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
915 casereader_destroy (positives);
916 casereader_destroy (negatives);
918 output_roc (rs, roc);
924 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
927 const int n_fields = roc->print_se ? 5 : 1;
928 const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
929 const int n_rows = 2 + roc->n_vars;
930 struct tab_table *tbl = tab_create (n_cols, n_rows);
932 if ( roc->n_vars > 1)
933 tab_title (tbl, _("Area Under the Curve"));
935 tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
937 tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
939 tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
941 tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
943 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
954 tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
955 tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
957 tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
958 tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
960 tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
961 TAT_TITLE | TAB_CENTER,
962 _("Asymp. %g%% Confidence Interval"), roc->ci);
963 tab_vline (tbl, 0, n_cols - 1, 0, 0);
964 tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
967 if ( roc->n_vars > 1)
968 tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
970 if ( roc->n_vars > 1)
971 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
974 for ( i = 0 ; i < roc->n_vars ; ++i )
976 tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
978 tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
983 const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
984 (12 * rs[i].n1 * rs[i].n2));
988 se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
989 (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
991 se /= rs[i].n1 * rs[i].n2;
995 tab_double (tbl, n_cols - 4, 2 + i, 0,
999 ci = 1 - roc->ci / 100.0;
1000 yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1002 tab_double (tbl, n_cols - 2, 2 + i, 0,
1006 tab_double (tbl, n_cols - 1, 2 + i, 0,
1010 tab_double (tbl, n_cols - 3, 2 + i, 0,
1011 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
1021 show_summary (const struct cmd_roc *roc)
1023 const int n_cols = 3;
1024 const int n_rows = 4;
1025 struct tab_table *tbl = tab_create (n_cols, n_rows);
1027 tab_title (tbl, _("Case Summary"));
1029 tab_headers (tbl, 1, 0, 2, 0);
1031 tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1040 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1041 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1044 tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1045 tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1048 tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1049 tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1050 tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1052 tab_joint_text (tbl, 1, 0, 2, 0,
1053 TAT_TITLE | TAB_CENTER,
1054 _("Valid N (listwise)"));
1057 tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1058 tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
1061 tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1062 tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1064 tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1065 tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
1072 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1076 const int n_cols = roc->n_vars > 1 ? 4 : 3;
1078 struct tab_table *tbl ;
1080 for (i = 0; i < roc->n_vars; ++i)
1081 n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1083 tbl = tab_create (n_cols, n_rows);
1085 if ( roc->n_vars > 1)
1086 tab_title (tbl, _("Coordinates of the Curve"));
1088 tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1091 tab_headers (tbl, 1, 0, 1, 0);
1093 tab_dim (tbl, tab_natural_dimensions, NULL, NULL);
1095 tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1097 if ( roc->n_vars > 1)
1098 tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1100 tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1101 tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1102 tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1111 if ( roc->n_vars > 1)
1112 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1114 for (i = 0; i < roc->n_vars; ++i)
1117 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1119 if ( roc->n_vars > 1)
1120 tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1123 tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1126 for (; (cc = casereader_read (r)) != NULL;
1127 case_unref (cc), x++)
1129 const double se = case_data_idx (cc, ROC_TP)->f /
1131 case_data_idx (cc, ROC_TP)->f
1133 case_data_idx (cc, ROC_FN)->f
1136 const double sp = case_data_idx (cc, ROC_TN)->f /
1138 case_data_idx (cc, ROC_TN)->f
1140 case_data_idx (cc, ROC_FP)->f
1143 tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1144 var_get_print_format (roc->vars[i]));
1146 tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1147 tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1150 casereader_destroy (r);
1158 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1164 struct roc_chart *rc;
1167 rc = roc_chart_create (roc->reference);
1168 for (i = 0; i < roc->n_vars; i++)
1169 roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1170 rs[i].cutpoint_rdr);
1171 chart_submit (roc_chart_get_chart (rc));
1176 if ( roc->print_coords )
1177 show_coords (rs, roc);