1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2009 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <language/stats/roc.h>
21 #include <data/casegrouper.h>
22 #include <data/casereader.h>
23 #include <data/casewriter.h>
24 #include <data/dictionary.h>
25 #include <data/format.h>
26 #include <data/procedure.h>
27 #include <data/subcase.h>
28 #include <language/command.h>
29 #include <language/lexer/lexer.h>
30 #include <language/lexer/value-parser.h>
31 #include <language/lexer/variable-parser.h>
32 #include <libpspp/misc.h>
33 #include <math/sort.h>
34 #include <output/chart-item.h>
35 #include <output/charts/roc-chart.h>
36 #include <output/tab.h>
38 #include <gsl/gsl_cdf.h>
/* Gettext shorthands: _() translates MSGID at run time via gettext ();
   N_() expands to MSGID unchanged, without translating it. */
41 #define _(msgid) gettext (msgid)
42 #define N_(msgid) msgid
/* Settings and counters of the ROC command, filled in by cmd_roc ().
   The enclosing struct declaration is outside this view; it is used
   elsewhere as struct cmd_roc. */
47 const struct variable **vars; /* Test result variables (numeric). */
48 const struct dictionary *dict; /* Dictionary used to find the weight variable. */
50 const struct variable *state_var; /* Variable holding the actual state. */
51 union value state_value; /* STATE_VAR value that marks a positive case. */
53 /* Plot the roc curve */
55 /* Plot the reference line */
62 bool bi_neg_exp; /* True iff the bi-negative exponential criteria
64 enum mv_class exclude; /* Missing-value classes treated as missing. */
66 bool invert ; /* True iff a smaller test result variable indicates
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
/* Parses the ROC command, whose syntax as handled below is

     ROC test-variables BY state-variable (state-value)
         [/MISSING=...] [/PLOT=...] [/PRINT=...] [/CRITERIA=...]

   and then runs the analysis with run_roc ().  ROC.state_value is
   initialised right after the BY clause and released with
   value_destroy () on each of the two exit paths visible below.
   NOTE(review): the return type, braces and return statements of this
   function are outside this view. */
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 const struct dictionary *dict = dataset_dict (ds);
86 roc.print_coords = false;
89 roc.reference = false;
91 roc.bi_neg_exp = false;
93 roc.pos = roc.pos_weighted = 0;
94 roc.neg = roc.neg_weighted = 0;
95 roc.dict = dataset_dict (ds);
98 lex_match (lexer, '/');
99 if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100 PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103 if ( ! lex_force_match (lexer, T_BY))
108 roc.state_var = parse_variable (lexer, dict);
110 if ( !lex_force_match (lexer, '('))
115 value_init (&roc.state_value, var_get_width (roc.state_var));
/* NOTE(review): parse_value () is called as a bare statement, so any
   failure it reports is not acted on here -- confirm a malformed state
   value is caught elsewhere. */
116 parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
119 if ( !lex_force_match (lexer, ')'))
125 while (lex_token (lexer) != '.')
127 lex_match (lexer, '/');
128 if (lex_match_id (lexer, "MISSING"))
130 lex_match (lexer, '=');
131 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
133 if (lex_match_id (lexer, "INCLUDE"))
135 roc.exclude = MV_SYSTEM;
137 else if (lex_match_id (lexer, "EXCLUDE"))
139 roc.exclude = MV_ANY;
143 lex_error (lexer, NULL);
148 else if (lex_match_id (lexer, "PLOT"))
150 lex_match (lexer, '=');
151 if (lex_match_id (lexer, "CURVE"))
154 if (lex_match (lexer, '('))
156 roc.reference = true;
/* NOTE(review): here and in /CRITERIA below, the lex_force_* calls are
   bare statements, so a failed match does not stop parsing -- confirm
   this is intended. */
157 lex_force_match_id (lexer, "REFERENCE");
158 lex_force_match (lexer, ')');
161 else if (lex_match_id (lexer, "NONE"))
167 lex_error (lexer, NULL);
171 else if (lex_match_id (lexer, "PRINT"))
173 lex_match (lexer, '=');
174 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
176 if (lex_match_id (lexer, "SE"))
180 else if (lex_match_id (lexer, "COORDINATES"))
182 roc.print_coords = true;
186 lex_error (lexer, NULL);
191 else if (lex_match_id (lexer, "CRITERIA"))
193 lex_match (lexer, '=');
194 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
196 if (lex_match_id (lexer, "CUTOFF"))
198 lex_force_match (lexer, '(');
/* NOTE(review): CUTOFF writes the same roc.exclude field as /MISSING,
   so whichever of the two subcommands is parsed last wins.
   /MISSING=EXCLUDE uses MV_ANY while CUTOFF(EXCLUDE) uses
   MV_USER | MV_SYSTEM; presumably the same set -- confirm. */
199 if (lex_match_id (lexer, "INCLUDE"))
201 roc.exclude = MV_SYSTEM;
203 else if (lex_match_id (lexer, "EXCLUDE"))
205 roc.exclude = MV_USER | MV_SYSTEM;
209 lex_error (lexer, NULL);
212 lex_force_match (lexer, ')');
214 else if (lex_match_id (lexer, "TESTPOS"))
216 lex_force_match (lexer, '(');
217 if (lex_match_id (lexer, "LARGE"))
221 else if (lex_match_id (lexer, "SMALL"))
227 lex_error (lexer, NULL);
230 lex_force_match (lexer, ')');
232 else if (lex_match_id (lexer, "CI"))
234 lex_force_match (lexer, '(');
235 lex_force_num (lexer);
236 roc.ci = lex_number (lexer);
238 lex_force_match (lexer, ')');
240 else if (lex_match_id (lexer, "DISTRIBUTION"))
242 lex_force_match (lexer, '(');
243 if (lex_match_id (lexer, "FREE"))
245 roc.bi_neg_exp = false;
247 else if (lex_match_id (lexer, "NEGEXPO"))
249 roc.bi_neg_exp = true;
253 lex_error (lexer, NULL);
256 lex_force_match (lexer, ')');
260 lex_error (lexer, NULL);
267 lex_error (lexer, NULL);
/* Run the analysis over every split-file group of the active dataset. */
272 if ( ! run_roc (ds, &roc))
275 value_destroy (&roc.state_value, var_get_width (roc.state_var));
281 value_destroy (&roc.state_value, var_get_width (roc.state_var));
290 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
/* Runs the ROC analysis once per split-file group: each group of cases
   read from the active dataset is handed to do_roc ().  OK is the
   conjunction of destroying the grouper and committing the procedure
   (presumably returned on a line outside this view). */
294 run_roc (struct dataset *ds, struct cmd_roc *roc)
296 struct dictionary *dict = dataset_dict (ds);
298 struct casereader *group;
300 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
301 while (casegrouper_get_next_group (grouper, &group))
/* NOTE(review): DICT already holds dataset_dict (ds); the second call
   is redundant. */
303 do_roc (roc, group, dataset_dict (ds));
305 ok = casegrouper_destroy (grouper);
306 ok = proc_commit (ds) && ok;
/* Debugging aid: prints every value of every case in READER to stdout
   with printf ("%g "), reading from a clone of READER that is destroyed
   afterwards.
   NOTE(review): every value is read through ->f, so string columns
   would be misprinted. */
313 dump_casereader (struct casereader *reader)
316 struct casereader *r = casereader_clone (reader);
318 for ( ; (c = casereader_read (r) ); case_unref (c))
321 for (i = 0 ; i < case_get_value_cnt (c); ++i)
323 printf ("%g ", case_data_idx (c, i)->f);
328 casereader_destroy (r);
334 Return true iff the state variable indicates that C has positive actual state.
336 As a side effect, this function also accumulates the roc->{pos,neg} and
337 roc->{pos,neg}_weighted counts.
340 match_positives (const struct ccase *c, void *aux)
342 struct cmd_roc *roc = aux;
343 const struct variable *wv = dict_get_weight (roc->dict);
/* Without a weight variable every case counts once. */
344 const double weight = wv ? case_data (c, wv)->f : 1.0;
/* A case is positive iff its state value equals roc->state_value exactly. */
346 const bool positive =
347 ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
348 var_get_width (roc->state_var)));
/* NOTE(review): these counters are zeroed once in cmd_roc (), and no
   per-group reset is visible in do_roc ().  With SPLIT FILE, later
   groups' case summaries may therefore include earlier groups'
   counts -- confirm. */
353 roc->pos_weighted += weight;
358 roc->neg_weighted += weight;
369 /* Some intermediate state for calculating the cutpoints and the
370 standard error values */
373 double auc; /* Area under the curve */
375 double n1; /* total weight of positives */
376 double n2; /* total weight of negatives */
378 /* intermediates for standard error */
382 /* intermediates for cutpoints */
383 struct casewriter *cutpoint_wtr; /* Sort writer filled by prepare_cutpoints (). */
384 struct casereader *cutpoint_rdr; /* Columns ROC_CUTPOINT, ROC_TP, ROC_FN, ROC_TN, ROC_FP. */
391 Return a new casereader based upon CUTPOINT_RDR.
392 The number of "positive" cases is placed into
393 the position TRUE_INDEX, and the number of "negative" cases
395 POS_COND and RESULT determine the semantics of what is
397 WEIGHT is the value of a single count.
399 static struct casereader *
400 accumulate_counts (struct casereader *cutpoint_rdr,
401 double result, double weight,
402 bool (*pos_cond) (double, double),
403 int true_index, int false_index)
405 const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
406 struct casewriter *w =
407 autopaging_writer_create (proto);
408 struct casereader *r = casereader_clone (cutpoint_rdr);
410 double prev_cp = SYSMIS;
412 for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
414 struct ccase *new_case;
415 const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
417 assert (cp != SYSMIS);
419 /* We don't want duplicates here */
423 new_case = case_clone (cpc);
/* Credit WEIGHT to TRUE_INDEX when POS_COND (RESULT, CP) holds, and
   otherwise (via an else outside this view, presumably) to FALSE_INDEX. */
425 if ( pos_cond (result, cp))
426 case_data_rw_idx (new_case, true_index)->f += weight;
428 case_data_rw_idx (new_case, false_index)->f += weight;
432 casewriter_write (w, new_case);
434 casereader_destroy (r);
/* NOTE(review): only the clone R is destroyed.  CUTPOINT_RDR itself is
   not released here, and process_group () overwrites its only pointer
   to it with this function's return value -- check that the old reader
   is not leaked. */
436 return casewriter_make_reader (w);
441 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
444 This function does 3 things:
446 1. Counts the number of cases which are equal to every other case in READER,
447 and those cases for which the relationship between it and every other case
448 satisfies PRED (normally either > or <). VAR is variable defining a case's value
451 2. Counts the number of true and false cases in reader, and populates
452 CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices
453 which receive these values. POS_COND is the condition defining true
456 3. CC is filled with the cumulative weight of all cases of READER.
458 static struct casereader *
459 process_group (const struct variable *var, struct casereader *reader,
460 bool (*pred) (double, double),
461 const struct dictionary *dict,
463 struct casereader **cutpoint_rdr,
464 bool (*pos_cond) (double, double),
468 const struct variable *w = dict_get_weight (dict);
/* Sort READER by VAR and reduce it with casereader_create_distinct ()
   (presumably one case per distinct value, with weights combined --
   confirm). */
470 struct casereader *r1 =
471 casereader_create_distinct (sort_execute_1var (reader, var), var, w);
/* Without a weight variable, each case's weight is read from R1's last
   column, presumably a count column that casereader_create_distinct ()
   appends -- confirm. */
473 const int weight_idx = w ? var_get_case_index (w) :
474 caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
478 struct casereader *rclone = casereader_clone (r1);
479 struct casewriter *wtr;
/* Output cases have three numeric columns, filled below at VALUE, N_EQ
   and N_PRED. */
480 struct caseproto *proto = caseproto_create ();
482 proto = caseproto_add_width (proto, 0);
483 proto = caseproto_add_width (proto, 0);
484 proto = caseproto_add_width (proto, 0);
486 wtr = autopaging_writer_create (proto);
/* NOTE(review): for every case of R1 the inner loop re-reads a fresh
   clone of all of R1, so this is quadratic in the number of distinct
   values. */
490 for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
492 struct ccase *new_case = case_create (proto);
494 struct casereader *r2 = casereader_clone (rclone);
496 const double weight1 = case_data_idx (c1, weight_idx)->f;
497 const double d1 = case_data (c1, var)->f;
/* Credit this value's weight to every cutpoint's TRUE_INDEX or
   FALSE_INDEX count.  NOTE(review): the previous *CUTPOINT_RDR is
   replaced here without being destroyed -- confirm it is not leaked. */
501 *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
503 true_index, false_index);
507 for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
509 const double d2 = case_data (c2, var)->f;
510 const double weight2 = case_data_idx (c2, weight_idx)->f;
517 else if ( pred (d2, d1))
523 case_data_rw_idx (new_case, VALUE)->f = d1;
524 case_data_rw_idx (new_case, N_EQ)->f = n_eq;
525 case_data_rw_idx (new_case, N_PRED)->f = n_pred;
527 casewriter_write (wtr, new_case);
529 casereader_destroy (r2);
532 casereader_destroy (r1);
533 casereader_destroy (rclone);
535 return casewriter_make_reader (wtr);
538 /* Some more indices into case data */
/* Column layout of the merged five-column cases built in do_roc ().  The
   VALUE column (defined outside this view, presumably index 0) holds the
   test value n. */
539 #define N_POS_EQ 1 /* number of positive cases with values equal to n */
540 #define N_POS_GT 2 /* number of positive cases with values greater than n */
541 #define N_NEG_EQ 3 /* number of negative cases with values equal to n */
542 #define N_NEG_LT 4 /* number of negative cases with values less than n */
/* Comparison predicates matching bool (*) (double, double), the type of
   process_group ()'s PRED and POS_COND parameters.  gt is passed by
   process_positive_group () and lt by process_negative_group ().  Their
   bodies are outside this view; presumably gt is d1 > d2, ge is
   d1 >= d2 and lt is d1 < d2 -- confirm. */
545 gt (double d1, double d2)
552 ge (double d1, double d2)
558 lt (double d1, double d2)
565 Return a casereader with width 3,
566 populated with cases based upon READER.
567 The cases will have the values:
568 (N, number of cases equal to N, number of cases greater than N)
569 As a side effect, update RS->n1 with the total weight of positive cases.
571 static struct casereader *
572 process_positive_group (const struct variable *var, struct casereader *reader,
573 const struct dictionary *dict,
574 struct roc_state *rs)
576 return process_group (var, reader, gt, dict, &rs->n1,
583 Return a casereader with width 3,
584 populated with cases based upon READER.
585 The cases will have the values:
586 (N, number of cases equal to N, number of cases less than N)
587 As a side effect, update RS->n2 with the total weight of negative cases.
589 static struct casereader *
590 process_negative_group (const struct variable *var, struct casereader *reader,
591 const struct dictionary *dict,
592 struct roc_state *rs)
594 return process_group (var, reader, lt, dict, &rs->n2,
604 append_cutpoint (struct casewriter *writer, double cutpoint)
/* Appends to WRITER a new case whose ROC_CUTPOINT column is CUTPOINT and
   whose ROC_TP, ROC_FN, ROC_TN and ROC_FP counters start at zero. */
606 struct ccase *cc = case_create (casewriter_get_proto (writer));
608 case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
609 case_data_rw_idx (cc, ROC_TP)->f = 0;
610 case_data_rw_idx (cc, ROC_FN)->f = 0;
611 case_data_rw_idx (cc, ROC_TN)->f = 0;
612 case_data_rw_idx (cc, ROC_FP)->f = 0;
614 casewriter_write (writer, cc);
619 Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will
620 be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
621 reader will be populated with its final number of cases.
622 However on exit from this function, only ROC_CUTPOINT entries will be set to their final
623 value. The other entries will be initialised to zero.
626 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
629 struct casereader *r = casereader_clone (input);
631 struct caseproto *proto = caseproto_create ();
633 struct subcase ordering;
634 subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
/* Each test variable gets its own writer, which sorts its cutpoints into
   ascending ROC_CUTPOINT order. */
636 proto = caseproto_add_width (proto, 0); /* cutpoint */
637 proto = caseproto_add_width (proto, 0); /* ROC_TP */
638 proto = caseproto_add_width (proto, 0); /* ROC_FN */
639 proto = caseproto_add_width (proto, 0); /* ROC_TN */
640 proto = caseproto_add_width (proto, 0); /* ROC_FP */
642 for (i = 0 ; i < roc->n_vars; ++i)
644 rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
645 rs[i].prev_result = SYSMIS;
646 rs[i].max = -DBL_MAX;
650 for (; (c = casereader_read (r)) != NULL; case_unref (c))
652 for (i = 0 ; i < roc->n_vars; ++i)
654 const union value *v = case_data (c, roc->vars[i]);
655 const double result = v->f;
657 if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
660 minimize (&rs[i].min, result);
661 maximize (&rs[i].max, result);
/* NOTE(review): a cutpoint is added halfway between consecutive
   differing results in INPUT's reading order.  INPUT is not sorted by
   each test variable here, so these need not be midpoints between
   adjacent distinct values, and repeats are possible -- verify against
   the intended cutpoint definition. */
663 if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
665 const double mean = (result + rs[i].prev_result ) / 2.0;
666 append_cutpoint (rs[i].cutpoint_wtr, mean);
669 rs[i].prev_result = result;
672 casereader_destroy (r);
675 /* Append sentinel cutpoints one unit below the minimum and one unit above the maximum */
676 for (i = 0 ; i < roc->n_vars; ++i)
678 append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
679 append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
681 rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
/* Performs the ROC analysis for one split-file group READER, whose
   dictionary is DICT.  The steps are:
     - filter out missing values (casereader_create_filter_missing);
     - build the cutpoint tables (prepare_cutpoints ());
     - separate positive from negative actual-state cases;
     - compute each test variable's area under the curve and the
       intermediates for its standard error;
     - call output_roc (). */
686 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
690 struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
692 struct casereader *negatives = NULL;
693 struct casereader *positives = NULL;
695 struct caseproto *n_proto = caseproto_create ();
697 struct subcase up_ordering;
698 struct subcase down_ordering;
700 struct casewriter *neg_wtr = NULL;
702 struct casereader *input = casereader_create_filter_missing (reader,
703 roc->vars, roc->n_vars,
708 input = casereader_create_filter_missing (input,
714 neg_wtr = autopaging_writer_create (casereader_get_proto (input));
716 prepare_cutpoints (roc, rs, input);
719 /* Separate the positive actual state cases from the negative ones */
721 casereader_create_filter_func (input,
727 n_proto = caseproto_create ();
/* NOTE(review): N_PROTO was already created by its initialiser above.
   Re-creating it appears to leak the first caseproto unless it is
   released on a line outside this view -- confirm. */
729 n_proto = caseproto_add_width (n_proto, 0);
730 n_proto = caseproto_add_width (n_proto, 0);
731 n_proto = caseproto_add_width (n_proto, 0);
732 n_proto = caseproto_add_width (n_proto, 0);
733 n_proto = caseproto_add_width (n_proto, 0);
735 subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
736 subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
738 for (i = 0 ; i < roc->n_vars; ++i)
740 struct casewriter *w = NULL;
741 struct casereader *r = NULL;
746 struct casereader *n_neg ;
747 const struct variable *var = roc->vars[i];
749 struct casereader *neg ;
750 struct casereader *pos = casereader_clone (positives);
753 struct casereader *n_pos =
754 process_positive_group (var, pos, dict, &rs[i]);
756 if ( negatives == NULL)
758 negatives = casewriter_make_reader (neg_wtr);
761 neg = casereader_clone (negatives);
763 n_neg = process_negative_group (var, neg, dict, &rs[i]);
766 /* Merge the n_pos and n_neg casereaders */
767 w = sort_create_writer (&up_ordering, n_proto);
768 for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
770 struct ccase *pos_case = case_create (n_proto);
772 const double jpos = case_data_idx (cpos, VALUE)->f;
774 while ((cneg = casereader_read (n_neg)))
776 struct ccase *nc = case_create (n_proto);
778 const double jneg = case_data_idx (cneg, VALUE)->f;
780 case_data_rw_idx (nc, VALUE)->f = jneg;
781 case_data_rw_idx (nc, N_POS_EQ)->f = 0;
/* SYSMIS marks a count that the propagation passes below fill in. */
783 case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
785 *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
786 *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
788 casewriter_write (w, nc);
795 case_data_rw_idx (pos_case, VALUE)->f = jpos;
796 *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
797 *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
798 case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
799 case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
801 casewriter_write (w, pos_case);
804 /* These aren't used anymore */
808 r = casewriter_make_reader (w);
810 /* Propagate the N_POS_GT values from the positive cases
811 to the negative ones */
813 double prev_pos_gt = rs[i].n1;
/* R came from an ascending sort writer, so this pass visits VALUEs in
   ascending order.  A SYSMIS N_POS_GT is carried forward from the
   previous case, starting from N1.  The output is re-sorted into
   descending order for the next pass. */
814 w = sort_create_writer (&down_ordering, n_proto);
816 for ( ; (c = casereader_read (r) ); case_unref (c))
818 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
819 struct ccase *nc = case_clone (c);
821 if ( n_pos_gt == SYSMIS)
823 n_pos_gt = prev_pos_gt;
824 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
827 casewriter_write (w, nc);
828 prev_pos_gt = n_pos_gt;
831 r = casewriter_make_reader (w);
834 /* Propagate the N_NEG_LT values from the negative cases
835 to the positive ones */
837 double prev_neg_lt = rs[i].n2;
/* This pass visits VALUEs in descending order.  A SYSMIS N_NEG_LT is
   carried forward from the previous case, starting from N2.  The output
   is sorted ascending again. */
838 w = sort_create_writer (&up_ordering, n_proto);
840 for ( ; (c = casereader_read (r) ); case_unref (c))
842 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
843 struct ccase *nc = case_clone (c);
845 if ( n_neg_lt == SYSMIS)
847 n_neg_lt = prev_neg_lt;
848 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
851 casewriter_write (w, nc);
852 prev_neg_lt = n_neg_lt;
855 r = casewriter_make_reader (w);
/* Final pass, in ascending VALUE order: accumulate the AUC and the
   standard-error intermediates for this test variable. */
859 struct ccase *prev_case = NULL;
860 for ( ; (c = casereader_read (r) ); case_unref (c))
862 const struct ccase *next_case = casereader_peek (r, 0);
864 const double j = case_data_idx (c, VALUE)->f;
865 double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
866 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
867 double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
868 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
/* When a positive and a negative case share VALUE J, the side whose
   *_EQ count is zero takes its counts from the previous case. */
870 if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
872 if ( 0 == case_data_idx (c, N_POS_EQ)->f)
874 n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
875 n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
878 if ( 0 == case_data_idx (c, N_NEG_EQ)->f)
880 n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
881 n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
/* Accumulate once per distinct VALUE, at its last case.  A tie between
   a positive and a negative case counts one half. */
885 if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
887 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
890 n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
892 n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
896 case_unref (prev_case);
897 prev_case = case_clone (c);
/* Normalise by the product of the positive and negative total weights. */
900 rs[i].auc /= rs[i].n1 * rs[i].n2;
/* The condition guarding this flip is outside this view (presumably
   roc->invert, i.e. TESTPOS(SMALL) -- confirm). */
902 rs[i].auc = 1 - rs[i].auc;
/* Q1 and Q2 for the standard error.  Under the bi-negative exponential
   model they depend only on the AUC (Hanley & McNeil, 1982:
   Q1 = A / (2 - A), Q2 = 2A^2 / (1 + A)).  Otherwise the sums
   accumulated above are normalised. */
904 if ( roc->bi_neg_exp )
906 rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
907 rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
911 rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
912 rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
917 casereader_destroy (positives);
918 casereader_destroy (negatives);
920 output_roc (rs, roc);
/* Outputs the "Area Under the Curve" table, with one row per test
   variable.  The table has one value column (Area) or, when
   roc->print_se is set, five (Area, Std. Error, Asymptotic Sig., and the
   lower and upper confidence bounds).  A variable-name column is
   prepended when there are several test variables. */
926 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
929 const int n_fields = roc->print_se ? 5 : 1;
930 const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
931 const int n_rows = 2 + roc->n_vars;
932 struct tab_table *tbl = tab_create (n_cols, n_rows);
934 if ( roc->n_vars > 1)
935 tab_title (tbl, _("Area Under the Curve"));
937 tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
939 tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
942 tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
944 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
955 tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
956 tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
958 tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
959 tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
961 tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
962 TAT_TITLE | TAB_CENTER,
963 _("Asymp. %g%% Confidence Interval"), roc->ci);
964 tab_vline (tbl, 0, n_cols - 1, 0, 0);
965 tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
968 if ( roc->n_vars > 1)
969 tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
971 if ( roc->n_vars > 1)
972 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
975 for ( i = 0 ; i < roc->n_vars ; ++i )
977 tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
979 tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
/* Standard error of the AUC under the null hypothesis AUC = 0.5.  It is
   used for the two-sided significance below. */
984 const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
985 (12 * rs[i].n1 * rs[i].n2));
/* Variance of the AUC estimate, using the Hanley & McNeil formula
   (A(1-A) + (n1-1)(Q1-A^2) + (n2-1)(Q2-A^2)) / (n1 n2).  Presumably it
   is square-rooted on a line outside this view before being used as
   the standard error. */
989 se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
990 (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
992 se /= rs[i].n1 * rs[i].n2;
996 tab_double (tbl, n_cols - 4, 2 + i, 0,
1000 ci = 1 - roc->ci / 100.0;
/* NOTE(review): the whole of 1 - roc->ci/100 is used as a single
   upper-tail probability, which gives a one-sided quantile.  A
   two-sided roc->ci% interval would normally use
   (1 - roc->ci/100) / 2 -- verify against reference output. */
1001 yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1003 tab_double (tbl, n_cols - 2, 2 + i, 0,
1007 tab_double (tbl, n_cols - 1, 2 + i, 0,
1011 tab_double (tbl, n_cols - 3, 2 + i, 0,
1012 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
/* Outputs the 3x4 "Case Summary" table.  It gives unweighted
   (roc->pos, roc->neg) and weighted (roc->pos_weighted,
   roc->neg_weighted) counts of cases with positive and negative actual
   state. */
1022 show_summary (const struct cmd_roc *roc)
1024 const int n_cols = 3;
1025 const int n_rows = 4;
1026 struct tab_table *tbl = tab_create (n_cols, n_rows);
1028 tab_title (tbl, _("Case Summary"));
1030 tab_headers (tbl, 1, 0, 2, 0);
1039 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1040 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1043 tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1044 tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1047 tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1048 tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1049 tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1051 tab_joint_text (tbl, 1, 0, 2, 0,
1052 TAT_TITLE | TAB_CENTER,
1053 _("Valid N (listwise)"));
1056 tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1057 tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
/* Unweighted counts use F8.0.  The weighted counts pass 0, i.e. a null
   format pointer (NULL is spelled out elsewhere in this file). */
1060 tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1061 tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1063 tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1064 tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
/* Outputs the "Coordinates of the Curve" table.  For each test variable
   and each case of rs[i].cutpoint_rdr, it shows the cutpoint, the
   sensitivity and 1 - specificity.  A variable-name column is added
   when there are several test variables.
   NOTE(review): the cutpoint column header is always "Positive if
   greater than or equal to"; nothing visible here adjusts it for
   TESTPOS(SMALL) -- confirm. */
1071 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1075 const int n_cols = roc->n_vars > 1 ? 4 : 3;
1077 struct tab_table *tbl ;
1079 for (i = 0; i < roc->n_vars; ++i)
1080 n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1082 tbl = tab_create (n_cols, n_rows);
1084 if ( roc->n_vars > 1)
1085 tab_title (tbl, _("Coordinates of the Curve"));
1087 tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1090 tab_headers (tbl, 1, 0, 1, 0);
1092 tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1094 if ( roc->n_vars > 1)
1095 tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1097 tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1098 tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1099 tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1108 if ( roc->n_vars > 1)
1109 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1111 for (i = 0; i < roc->n_vars; ++i)
1114 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1116 if ( roc->n_vars > 1)
1117 tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1120 tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
/* In this loop SE is the sensitivity, computed from ROC_TP and ROC_FN
   (presumably TP / (TP + FN)); it is not a standard error.  SP is the
   specificity, computed from ROC_TN and ROC_FP.
   NOTE(review): a zero denominator would give NaN -- confirm it cannot
   occur. */
1123 for (; (cc = casereader_read (r)) != NULL;
1124 case_unref (cc), x++)
1126 const double se = case_data_idx (cc, ROC_TP)->f /
1128 case_data_idx (cc, ROC_TP)->f
1130 case_data_idx (cc, ROC_FN)->f
1133 const double sp = case_data_idx (cc, ROC_TN)->f /
1135 case_data_idx (cc, ROC_TN)->f
1137 case_data_idx (cc, ROC_FP)->f
1140 tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f,
1141 var_get_print_format (roc->vars[i]));
1143 tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1144 tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1147 casereader_destroy (r);
/* Emits the ROC output for one group.  Under a condition outside this
   view it submits an ROC chart with one curve per test variable; the
   reference line is drawn iff roc->reference.  It shows the
   coordinates table when roc->print_coords is set.  Other output calls
   may be on lines not shown here. */
1155 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1161 struct roc_chart *rc;
1164 rc = roc_chart_create (roc->reference);
1165 for (i = 0; i < roc->n_vars; i++)
1166 roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1167 rs[i].cutpoint_rdr);
1168 roc_chart_submit (rc);
1173 if ( roc->print_coords )
1174 show_coords (rs, roc);