1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2009 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <data/procedure.h>
20 #include <language/lexer/variable-parser.h>
21 #include <language/lexer/value-parser.h>
22 #include <language/command.h>
23 #include <language/lexer/lexer.h>
25 #include <data/casegrouper.h>
26 #include <data/casereader.h>
27 #include <data/casewriter.h>
28 #include <data/dictionary.h>
29 #include <data/format.h>
30 #include <math/sort.h>
31 #include <data/subcase.h>
34 #include <libpspp/misc.h>
36 #include <gsl/gsl_cdf.h>
37 #include <output/table.h>
39 #include <output/charts/plot-chart.h>
40 #include <output/charts/cartesian.h>
/* Marker for translatable strings: _() expands to a gettext () call on
   its argument. */
43 #define _(msgid) gettext (msgid)
/* N_() expands to its argument unchanged. */
44 #define N_(msgid) msgid
49 const struct variable **vars;
50 const struct dictionary *dict;
52 const struct variable *state_var ;
53 union value state_value;
55 /* Plot the roc curve */
57 /* Plot the reference line */
64 bool bi_neg_exp; /* True iff the bi-negative exponential criteria
66 enum mv_class exclude;
68 bool invert ; /* True iff a smaller test result variable indicates
/* Forward declaration: cmd_roc below calls run_roc, which is defined
   later in this file. */
77 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
/* Parses the ROC command from LEXER and runs it on dataset DS.

   Syntax handled by the visible code:
     - a list of test variables, the BY keyword, a state variable and a
       parenthesised state value;
     - then, until the terminating '.', any of the subcommands
         MISSING   INCLUDE | EXCLUDE
         PLOT      CURVE [(REFERENCE)] | NONE
         PRINT     SE | COORDINATES
         CRITERIA  CUTOFF (INCLUDE | EXCLUDE), TESTPOS (LARGE | SMALL),
                   CI (number), DISTRIBUTION (FREE | NEGEXPO)

   The parsed settings are stored in the local `roc' and handed to
   run_roc.

   NOTE(review): this listing omits lines (braces, the declaration of
   `roc', error paths, return statements), so the exact control flow
   and the return values cannot be confirmed from here. */
80 cmd_roc (struct lexer *lexer, struct dataset *ds)
83 const struct dictionary *dict = dataset_dict (ds);
/* Defaults in force before any subcommand is parsed. */
88 roc.print_coords = false;
91 roc.reference = false;
93 roc.bi_neg_exp = false;
95 roc.pos = roc.pos_weighted = 0;
96 roc.neg = roc.neg_weighted = 0;
97 roc.dict = dataset_dict (ds);
/* Test variable list. */
99 if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
100 PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
103 if ( ! lex_force_match (lexer, T_BY))
108 roc.state_var = parse_variable (lexer, dict);
110 if ( !lex_force_match (lexer, '('))
/* The state value is initialised and parsed with the width of the
   state variable. */
115 value_init (&roc.state_value, var_get_width (roc.state_var));
116 parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
119 if ( !lex_force_match (lexer, ')'))
/* Subcommand loop: runs until the terminating '.' token. */
125 while (lex_token (lexer) != '.')
127 lex_match (lexer, '/');
128 if (lex_match_id (lexer, "MISSING"))
130 lex_match (lexer, '=');
131 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
133 if (lex_match_id (lexer, "INCLUDE"))
135 roc.exclude = MV_SYSTEM;
137 else if (lex_match_id (lexer, "EXCLUDE"))
139 roc.exclude = MV_ANY;
143 lex_error (lexer, NULL);
148 else if (lex_match_id (lexer, "PLOT"))
150 lex_match (lexer, '=');
151 if (lex_match_id (lexer, "CURVE"))
154 if (lex_match (lexer, '('))
/* roc.reference is set as soon as '(' is seen, before REFERENCE and
   ')' are matched.  NOTE(review): the results of the two forced
   matches below are not tested on these lines. */
156 roc.reference = true;
157 lex_force_match_id (lexer, "REFERENCE");
158 lex_force_match (lexer, ')');
161 else if (lex_match_id (lexer, "NONE"))
167 lex_error (lexer, NULL);
171 else if (lex_match_id (lexer, "PRINT"))
173 lex_match (lexer, '=');
174 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
176 if (lex_match_id (lexer, "SE"))
180 else if (lex_match_id (lexer, "COORDINATES"))
182 roc.print_coords = true;
186 lex_error (lexer, NULL);
191 else if (lex_match_id (lexer, "CRITERIA"))
193 lex_match (lexer, '=');
194 while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
196 if (lex_match_id (lexer, "CUTOFF"))
198 lex_force_match (lexer, '(');
199 if (lex_match_id (lexer, "INCLUDE"))
201 roc.exclude = MV_SYSTEM;
203 else if (lex_match_id (lexer, "EXCLUDE"))
205 roc.exclude = MV_USER | MV_SYSTEM;
209 lex_error (lexer, NULL);
212 lex_force_match (lexer, ')');
214 else if (lex_match_id (lexer, "TESTPOS"))
216 lex_force_match (lexer, '(');
217 if (lex_match_id (lexer, "LARGE"))
221 else if (lex_match_id (lexer, "SMALL"))
227 lex_error (lexer, NULL);
230 lex_force_match (lexer, ')');
232 else if (lex_match_id (lexer, "CI"))
234 lex_force_match (lexer, '(');
/* NOTE(review): the result of lex_force_num is discarded here, so
   lex_number is read even if the current token is not a number --
   confirm whether that is intended. */
235 lex_force_num (lexer);
236 roc.ci = lex_number (lexer);
238 lex_force_match (lexer, ')');
240 else if (lex_match_id (lexer, "DISTRIBUTION"))
242 lex_force_match (lexer, '(');
243 if (lex_match_id (lexer, "FREE"))
245 roc.bi_neg_exp = false;
247 else if (lex_match_id (lexer, "NEGEXPO"))
249 roc.bi_neg_exp = true;
253 lex_error (lexer, NULL);
256 lex_force_match (lexer, ')');
260 lex_error (lexer, NULL);
267 lex_error (lexer, NULL);
272 if ( ! run_roc (ds, &roc))
/* The state value is destroyed at two places; presumably one is the
   success path and the other the error path (the lines between them
   are not visible). */
275 value_destroy (&roc.state_value, var_get_width (roc.state_var));
280 value_destroy (&roc.state_value, var_get_width (roc.state_var));
/* Forward declaration of do_roc, which is defined later in this file. */
289 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
/* Runs the ROC analysis described by ROC on the active file of DS.
   The cases obtained from proc_open are divided into groups by
   casegrouper_create_splits, and do_roc is called once per group.
   The final status combines the results of casegrouper_destroy and
   proc_commit.

   NOTE(review): the declaration of `ok' and the return statement are
   on lines missing from this listing. */
293 run_roc (struct dataset *ds, struct cmd_roc *roc)
295 struct dictionary *dict = dataset_dict (ds);
297 struct casereader *group;
299 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
300 while (casegrouper_get_next_group (grouper, &group))
302 do_roc (roc, group, dataset_dict (ds));
304 ok = casegrouper_destroy (grouper);
/* proc_commit is evaluated first, so it runs even when `ok' is already
   false. */
305 ok = proc_commit (ds) && ok;
/* Debugging aid: prints each value of each case with printf ("%g ").
   The cases are read from a clone of READER, and the clone is destroyed
   when the loop finishes.  Every value is printed through its `f'
   member, i.e. treated as a number. */
312 dump_casereader (struct casereader *reader)
315 struct casereader *r = casereader_clone (reader);
317 for ( ; (c = casereader_read (r) ); case_unref (c))
320 for (i = 0 ; i < case_get_value_cnt (c); ++i)
322 printf ("%g ", case_data_idx (c, i)->f);
327 casereader_destroy (r);
333 Return true iff the state variable indicates that C has positive actual state.
335 As a side effect, this function also accumulates the roc->{pos,neg} and
336 roc->{pos,neg}_weighted counts.
/* Predicate on case C; AUX must point to the struct cmd_roc in use.
   C is "positive" when the value of its state variable compares equal
   (value_compare_3way returns 0) to roc->state_value.
   The case weight is the value of the dictionary's weight variable, or
   1.0 when the dictionary has none; it is added to roc->pos_weighted
   or roc->neg_weighted.

   NOTE(review): the branch structure, the unweighted counters and the
   return statement are on lines missing from this listing. */
339 match_positives (const struct ccase *c, void *aux)
341 struct cmd_roc *roc = aux;
342 const struct variable *wv = dict_get_weight (roc->dict);
343 const double weight = wv ? case_data (c, wv)->f : 1.0;
345 const bool positive =
346 ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
347 var_get_width (roc->state_var)));
352 roc->pos_weighted += weight;
357 roc->neg_weighted += weight;
368 /* Some intermediate state for calculating the cutpoints and the
369 standard error values */
372 double auc; /* Area under the curve */
374 double n1; /* total weight of positives */
375 double n2; /* total weight of negatives */
377 /* intermediates for standard error */
381 /* intermediates for cutpoints */
382 struct casewriter *cutpoint_wtr;
383 struct casereader *cutpoint_rdr;
397 Return a new casereader based upon CUTPOINT_RDR.
398 The number of "positive" cases are placed into
399 the position TRUE_INDEX, and the number of "negative" cases
401 POS_COND and RESULT determine the semantics of what is
403 WEIGHT is the value of a single count.
/* Copies the cases of CUTPOINT_RDR (read through a clone) to a new
   autopaging writer that uses the same prototype, and returns a reader
   on that writer.  In each copied case WEIGHT is added to column
   TRUE_INDEX when POS_COND (RESULT, cutpoint) holds; the addition to
   FALSE_INDEX is presumably the `else' branch, whose keyword is on a
   line missing from this listing.

   NOTE(review): the duplicate check that uses prev_cp, the declaration
   of cpc, and whatever is done with the incoming CUTPOINT_RDR are not
   visible here. */
405 static struct casereader *
406 accumulate_counts (struct casereader *cutpoint_rdr,
407 double result, double weight,
408 bool (*pos_cond) (double, double),
409 int true_index, int false_index)
411 const struct caseproto *proto = casereader_get_proto (cutpoint_rdr);
412 struct casewriter *w =
413 autopaging_writer_create (proto);
414 struct casereader *r = casereader_clone (cutpoint_rdr);
416 double prev_cp = SYSMIS;
418 for ( ; (cpc = casereader_read (r) ); case_unref (cpc))
420 struct ccase *new_case;
421 const double cp = case_data_idx (cpc, CUTPOINT)->f;
423 assert (cp != SYSMIS);
425 /* We don't want duplicates here */
/* The clone is modified and written; cpc itself is released by the
   loop's case_unref. */
429 new_case = case_clone (cpc);
431 if ( pos_cond (result, cp))
432 case_data_rw_idx (new_case, true_index)->f += weight;
434 case_data_rw_idx (new_case, false_index)->f += weight;
438 casewriter_write (w, new_case);
440 casereader_destroy (r);
442 return casewriter_make_reader (w);
/* Forward declaration of output_roc, which is defined later in this
   file. */
447 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
450 This function does 3 things:
452 1. Counts the number of cases which are equal to every other case in READER,
453 and those cases for which the relationship between it and every other case
454 satisfies PRED (normally either > or <). VAR is the variable defining a case's value
457 2. Counts the number of true and false cases in reader, and populates
458 CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices
459 which receive these values. POS_COND is the condition defining true
462 3. CC is filled with the cumulative weight of all cases of READER.
/* Builds and returns a reader whose cases hold, for each case of the
   distinct, VAR-sorted view of READER, the columns VALUE, N_EQ and
   N_PRED.  While doing so it also updates *CUTPOINT_RDR through
   accumulate_counts.

   The algorithm is a nested loop: for every case c1 of the outer reader
   a fresh clone of the same data is scanned in full, so the work grows
   with the square of the number of distinct cases.

   NOTE(review): some parameters, the declarations of c1, c2, n_eq and
   n_pred, and the body of the inner comparison are on lines missing
   from this listing. */
464 static struct casereader *
465 process_group (const struct variable *var, struct casereader *reader,
466 bool (*pred) (double, double),
467 const struct dictionary *dict,
469 struct casereader **cutpoint_rdr,
470 bool (*pos_cond) (double, double),
474 const struct variable *w = dict_get_weight (dict);
/* READER is sorted on VAR and then reduced to distinct values. */
476 struct casereader *r1 =
477 casereader_create_distinct (sort_execute_1var (reader, var), var, w);
/* With a weight variable, its case index is used; otherwise the last
   column of r1's prototype is used (presumably a count column added by
   casereader_create_distinct -- TODO confirm). */
479 const int weight_idx = w ? var_get_case_index (w) :
480 caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
484 struct casereader *rclone = casereader_clone (r1);
485 struct casewriter *wtr;
486 struct caseproto *proto = caseproto_create ();
/* Output cases have three width-0 columns. */
488 proto = caseproto_add_width (proto, 0);
489 proto = caseproto_add_width (proto, 0);
490 proto = caseproto_add_width (proto, 0);
492 wtr = autopaging_writer_create (proto);
496 for ( ; (c1 = casereader_read (r1) ); case_unref (c1))
498 struct ccase *new_case = case_create (proto);
/* Each outer case gets its own clone to scan from the beginning. */
500 struct casereader *r2 = casereader_clone (rclone);
502 const double weight1 = case_data_idx (c1, weight_idx)->f;
503 const double d1 = case_data (c1, var)->f;
507 *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
509 true_index, false_index);
513 for ( ; (c2 = casereader_read (r2) ); case_unref (c2))
515 const double d2 = case_data (c2, var)->f;
516 const double weight2 = case_data_idx (c2, weight_idx)->f;
523 else if ( pred (d2, d1))
529 case_data_rw_idx (new_case, VALUE)->f = d1;
530 case_data_rw_idx (new_case, N_EQ)->f = n_eq;
531 case_data_rw_idx (new_case, N_PRED)->f = n_pred;
533 casewriter_write (wtr, new_case);
535 casereader_destroy (r2);
538 casereader_destroy (r1);
539 casereader_destroy (rclone);
541 return casewriter_make_reader (wtr);
544 /* Some more indices into case data */
545 #define N_POS_EQ 1 /* number of positive cases with values equal to n */
546 #define N_POS_GT 2 /* number of positive cases with values greater than n */
547 #define N_NEG_EQ 3 /* number of negative cases with values equal to n */
548 #define N_NEG_LT 4 /* number of negative cases with values less than n */
/* Two-argument comparison on doubles; only the signature is visible in
   this listing.  process_positive_group passes it as the PRED argument
   of process_group.  NOTE(review): presumably d1 > d2 -- confirm
   against the full source. */
551 gt (double d1, double d2)
/* Two-argument comparison on doubles; only the signature is visible.
   NOTE(review): presumably d1 >= d2 -- confirm. */
558 ge (double d1, double d2)
/* Two-argument comparison on doubles; only the signature is visible.
   process_negative_group passes it as the PRED argument of
   process_group.  NOTE(review): presumably d1 < d2 -- confirm. */
564 lt (double d1, double d2)
571 Return a casereader with width 3,
572 populated with cases based upon READER.
573 The cases will have the values:
574 (N, number of cases equal to N, number of cases greater than N)
575 As a side effect, update RS->n1 with the number of positive cases.
/* Wrapper around process_group for the positive cases: passes gt as
   the predicate and &rs->n1 as the fifth argument (presumably the
   cumulative-weight output -- the matching parameter line is not
   visible).  The remaining arguments of the call are on lines missing
   from this listing. */
577 static struct casereader *
578 process_positive_group (const struct variable *var, struct casereader *reader,
579 const struct dictionary *dict,
580 struct roc_state *rs)
582 return process_group (var, reader, gt, dict, &rs->n1,
589 Return a casereader with width 3,
590 populated with cases based upon READER.
591 The cases will have the values:
592 (N, number of cases equal to N, number of cases less than N)
593 As a side effect, update RS->n2 with the number of negative cases.
/* Wrapper around process_group for the negative cases: passes lt as
   the predicate and &rs->n2 as the fifth argument (presumably the
   cumulative-weight output -- the matching parameter line is not
   visible).  The remaining arguments of the call are on lines missing
   from this listing. */
595 static struct casereader *
596 process_negative_group (const struct variable *var, struct casereader *reader,
597 const struct dictionary *dict,
598 struct roc_state *rs)
600 return process_group (var, reader, lt, dict, &rs->n2,
/* Writes one new case to WRITER, created from WRITER's own prototype:
   the CUTPOINT column is set to `cutpoint' and the TP, FN, TN and FP
   columns are set to zero. */
610 append_cutpoint (struct casewriter *writer, double cutpoint)
612 struct ccase *cc = case_create (casewriter_get_proto (writer));
614 case_data_rw_idx (cc, CUTPOINT)->f = cutpoint;
615 case_data_rw_idx (cc, TP)->f = 0;
616 case_data_rw_idx (cc, FN)->f = 0;
617 case_data_rw_idx (cc, TN)->f = 0;
618 case_data_rw_idx (cc, FP)->f = 0;
620 casewriter_write (writer, cc);
625 Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will
626 be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the
627 reader will be populated with its final number of cases.
628 However on exit from this function, only CUTPOINT entries will be set to their final
629 value. The other entries will be initialised to zero.
/* Fills in rs[i].cutpoint_rdr for each of the roc->n_vars test
   variables, reading cases from a clone of INPUT.

   For each variable, a cutpoint is appended at the mean of each pair of
   consecutive, differing values that get past the missing-value test;
   afterwards two more cutpoints, (min - 1) and (max + 1), are appended.
   The cases are written through a sort writer ordered ascending on the
   CUTPOINT column.

   NOTE(review): the declarations of i and c and the action taken when
   a value is missing are on lines absent from this listing. */
632 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
635 struct casereader *r = casereader_clone (input);
637 struct caseproto *proto = caseproto_create ();
639 struct subcase ordering;
640 subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND);
642 proto = caseproto_add_width (proto, 0); /* cutpoint */
643 proto = caseproto_add_width (proto, 0); /* TP */
644 proto = caseproto_add_width (proto, 0); /* FN */
645 proto = caseproto_add_width (proto, 0); /* TN */
646 proto = caseproto_add_width (proto, 0); /* FP */
648 for (i = 0 ; i < roc->n_vars; ++i)
650 rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
651 rs[i].prev_result = SYSMIS;
/* NOTE(review): max starts at -DBL_MAX; the matching initialisation of
   rs[i].min is not visible here -- confirm it exists, since minimize ()
   is applied to it below. */
652 rs[i].max = -DBL_MAX;
656 for (; (c = casereader_read (r)) != NULL; case_unref (c))
658 for (i = 0 ; i < roc->n_vars; ++i)
660 const union value *v = case_data (c, roc->vars[i]);
661 const double result = v->f;
/* Missing values are tested against the class selected in
   roc->exclude. */
663 if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
666 minimize (&rs[i].min, result);
667 maximize (&rs[i].max, result);
/* No cutpoint is produced for the first value (prev_result is still
   SYSMIS) or for a repeat of the previous value. */
669 if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
671 const double mean = (result + rs[i].prev_result ) / 2.0;
672 append_cutpoint (rs[i].cutpoint_wtr, mean);
675 rs[i].prev_result = result;
678 casereader_destroy (r);
681 /* Append the min and max cutpoints */
682 for (i = 0 ; i < roc->n_vars; ++i)
684 append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
685 append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
687 rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
/* Performs the ROC analysis for one group of cases (READER) and then
   calls output_roc.

   Steps visible in this listing:
     1. Allocate one zero-filled roc_state per test variable.
     2. Filter READER for missing values and build the cutpoint readers
        with prepare_cutpoints.
     3. Separate the positive cases from the negative ones; the
        negatives reader is made from neg_wtr.
     4. For each test variable: compute per-value counts for the
        positive and negative cases, merge them into one stream of
        five-column cases (VALUE, N_POS_EQ, N_POS_GT, N_NEG_EQ,
        N_NEG_LT), fill in the SYSMIS placeholders by two propagation
        passes, then accumulate the AUC and the q1hat/q2hat
        intermediates.
     5. Destroy the positives and negatives readers and output.

   NOTE(review): many lines are missing from this listing (declarations
   of i, c, cpos and cneg, call arguments, braces, `else' keywords), so
   several branches are only partly visible. */
692 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
696 struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
698 struct casereader *negatives = NULL;
699 struct casereader *positives = NULL;
701 struct caseproto *n_proto = caseproto_create ();
703 struct subcase up_ordering;
704 struct subcase down_ordering;
706 struct casewriter *neg_wtr = NULL;
708 struct casereader *input = casereader_create_filter_missing (reader,
709 roc->vars, roc->n_vars,
714 input = casereader_create_filter_missing (input,
720 neg_wtr = autopaging_writer_create (casereader_get_proto (input));
722 prepare_cutpoints (roc, rs, input);
725 /* Separate the positive actual state cases from the negative ones */
727 casereader_create_filter_func (input,
/* NOTE(review): n_proto was already created at its declaration; unless
   a missing line releases it, that first prototype is leaked by this
   reassignment. */
733 n_proto = caseproto_create ();
735 n_proto = caseproto_add_width (n_proto, 0);
736 n_proto = caseproto_add_width (n_proto, 0);
737 n_proto = caseproto_add_width (n_proto, 0);
738 n_proto = caseproto_add_width (n_proto, 0);
739 n_proto = caseproto_add_width (n_proto, 0);
741 subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
742 subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
744 for (i = 0 ; i < roc->n_vars; ++i)
746 struct casewriter *w = NULL;
747 struct casereader *r = NULL;
752 struct casereader *n_neg ;
753 const struct variable *var = roc->vars[i];
755 struct casereader *neg ;
756 struct casereader *pos = casereader_clone (positives);
759 struct casereader *n_pos =
760 process_positive_group (var, pos, dict, &rs[i]);
/* The negatives reader is created from neg_wtr only once, on the first
   pass through this loop; later passes reuse it through clones. */
762 if ( negatives == NULL)
764 negatives = casewriter_make_reader (neg_wtr);
767 neg = casereader_clone (negatives);
769 n_neg = process_negative_group (var, neg, dict, &rs[i]);
772 /* Merge the n_pos and n_neg casereaders */
773 w = sort_create_writer (&up_ordering, n_proto);
774 for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos))
776 struct ccase *pos_case = case_create (n_proto);
778 const double jpos = case_data_idx (cpos, VALUE)->f;
780 while ((cneg = casereader_read (n_neg)))
782 struct ccase *nc = case_create (n_proto);
784 const double jneg = case_data_idx (cneg, VALUE)->f;
/* A row from the negative side: no positives equal to it, and
   N_POS_GT is left as SYSMIS to be filled in by the propagation pass
   further down. */
786 case_data_rw_idx (nc, VALUE)->f = jneg;
787 case_data_rw_idx (nc, N_POS_EQ)->f = 0;
789 case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
791 *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
792 *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
794 casewriter_write (w, nc);
/* A row from the positive side: the mirror image, with N_NEG_LT left
   as SYSMIS. */
801 case_data_rw_idx (pos_case, VALUE)->f = jpos;
802 *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
803 *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
804 case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
805 case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
807 casewriter_write (w, pos_case);
810 /* These aren't used anymore */
814 r = casewriter_make_reader (w);
816 /* Propagate the N_POS_GT values from the positive cases
817 to the negative ones */
/* The carried value starts at n1; the rewritten cases go through a
   writer that sorts descending on VALUE. */
819 double prev_pos_gt = rs[i].n1;
820 w = sort_create_writer (&down_ordering, n_proto);
822 for ( ; (c = casereader_read (r) ); case_unref (c))
824 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
825 struct ccase *nc = case_clone (c);
827 if ( n_pos_gt == SYSMIS)
829 n_pos_gt = prev_pos_gt;
830 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
833 casewriter_write (w, nc);
834 prev_pos_gt = n_pos_gt;
837 r = casewriter_make_reader (w);
840 /* Propagate the N_NEG_LT values from the negative cases
841 to the positive ones */
/* Same scheme in the other direction: the carried value starts at n2
   and the output is sorted ascending on VALUE again. */
843 double prev_neg_lt = rs[i].n2;
844 w = sort_create_writer (&up_ordering, n_proto);
846 for ( ; (c = casereader_read (r) ); case_unref (c))
848 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
849 struct ccase *nc = case_clone (c);
851 if ( n_neg_lt == SYSMIS)
853 n_neg_lt = prev_neg_lt;
854 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
857 casewriter_write (w, nc);
858 prev_neg_lt = n_neg_lt;
861 r = casewriter_make_reader (w);
865 struct ccase *prev_case = NULL;
866 for ( ; (c = casereader_read (r) ); case_unref (c))
868 const struct ccase *next_case = casereader_peek (r, 0);
870 const double j = case_data_idx (c, VALUE)->f;
871 double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
872 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
873 double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
874 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
/* When this row has the same VALUE as the previous one, a side whose
   "equal" count is zero here takes its counts from the previous row. */
876 if ( prev_case && j == case_data_idx (prev_case, VALUE)->f)
878 if ( 0 == case_data_idx (c, N_POS_EQ)->f)
880 n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
881 n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
884 if ( 0 == case_data_idx (c, N_NEG_EQ)->f)
886 n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
887 n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
/* Accumulate only on the last row of a run of equal VALUEs. */
891 if ( NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
893 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
896 n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
898 n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
902 case_unref (prev_case);
903 prev_case = case_clone (c);
/* Normalise the accumulated sum by the product of the two group
   weights.  NOTE(review): no guard against n1 * n2 == 0 is visible. */
906 rs[i].auc /= rs[i].n1 * rs[i].n2;
/* NOTE(review): presumably conditional on roc->invert; the guarding
   line is not visible. */
908 rs[i].auc = 1 - rs[i].auc;
/* With the bi-negative exponential criterion q1hat and q2hat are
   derived from the AUC alone; the two divisions after it are
   presumably the `else' branch (keyword not visible). */
910 if ( roc->bi_neg_exp )
912 rs[i].q1hat = rs[i].auc / ( 2 - rs[i].auc);
913 rs[i].q2hat = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc);
917 rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
918 rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
923 casereader_destroy (positives);
924 casereader_destroy (negatives);
926 output_roc (rs, roc);
/* Builds the "Area Under the Curve" table from the per-variable results
   in RS.

   Layout: one data row per test variable below two header rows.  With
   roc->print_se there are five value columns (area, standard error,
   asymptotic significance, lower and upper confidence bound),
   otherwise only the area.  A leading column naming the variable is
   added when there is more than one test variable; with a single
   variable its name goes into the title instead.

   NOTE(review): the conditions guarding the standard-error columns,
   the declarations of i, se, ci and yy, and the call that submits the
   table are on lines missing from this listing. */
932 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
935 const int n_fields = roc->print_se ? 5 : 1;
936 const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields;
937 const int n_rows = 2 + roc->n_vars;
938 struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
940 if ( roc->n_vars > 1)
941 tab_title (tbl, _("Area Under the Curve"));
943 tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0]));
945 tab_headers (tbl, n_cols - n_fields, 0, 1, 0);
947 tab_dim (tbl, tab_natural_dimensions, NULL);
949 tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area"));
951 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
962 tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error"));
963 tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig."));
965 tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound"));
966 tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound"));
968 tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0,
969 TAT_TITLE | TAB_CENTER,
970 _("Asymp. %g%% Confidence Interval"), roc->ci);
971 tab_vline (tbl, 0, n_cols - 1, 0, 0);
972 tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1);
975 if ( roc->n_vars > 1)
976 tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test"));
978 if ( roc->n_vars > 1)
979 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
982 for ( i = 0 ; i < roc->n_vars ; ++i )
984 tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i]));
986 tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL);
/* sd_0_5 is used only for the significance value at the end of the
   loop body, which compares the AUC with 0.5. */
991 const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
992 (12 * rs[i].n1 * rs[i].n2));
/* Standard-error intermediate, built from the AUC and q1hat/q2hat and
   divided by n1 * n2.  NOTE(review): any final square root is not
   visible here. */
996 se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc)) +
997 (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc));
999 se /= rs[i].n1 * rs[i].n2;
1003 tab_double (tbl, n_cols - 4, 2 + i, 0,
/* roc->ci is a percentage; it is converted to a tail probability for
   the inverse upper-tail Gaussian with standard deviation se. */
1007 ci = 1 - roc->ci / 100.0;
1008 yy = gsl_cdf_gaussian_Qinv (ci, se) ;
1010 tab_double (tbl, n_cols - 2, 2 + i, 0,
1014 tab_double (tbl, n_cols - 1, 2 + i, 0,
/* Twice the upper-tail probability of |auc - 0.5| / sd_0_5. */
1018 tab_double (tbl, n_cols - 3, 2 + i, 0,
1019 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)),
/* Builds the 3-column, 4-row "Case Summary" table: the counts of
   positive and negative cases accumulated in ROC, each shown
   unweighted (with format F_8_0) and weighted.  The row header in the
   first column is labelled with the state variable.

   NOTE(review): the call that submits the table is on a line missing
   from this listing. */
1029 show_summary (const struct cmd_roc *roc)
1031 const int n_cols = 3;
1032 const int n_rows = 4;
1033 struct tab_table *tbl = tab_create (n_cols, n_rows, 0);
1035 tab_title (tbl, _("Case Summary"));
1037 tab_headers (tbl, 1, 0, 2, 0);
1039 tab_dim (tbl, tab_natural_dimensions, NULL);
1048 tab_hline (tbl, TAL_2, 0, n_cols - 1, 2);
1049 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1052 tab_hline (tbl, TAL_2, 1, n_cols - 1, 1);
1053 tab_vline (tbl, TAL_1, 2, 1, n_rows - 1);
1056 tab_text (tbl, 0, 1, TAT_TITLE | TAB_LEFT, var_to_string (roc->state_var));
1057 tab_text (tbl, 1, 1, TAT_TITLE, _("Unweighted"));
1058 tab_text (tbl, 2, 1, TAT_TITLE, _("Weighted"));
1060 tab_joint_text (tbl, 1, 0, 2, 0,
1061 TAT_TITLE | TAB_CENTER,
1062 _("Valid N (listwise)"));
1065 tab_text (tbl, 0, 2, TAB_LEFT, _("Positive"));
1066 tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
/* Column 1: unweighted counts; column 2: weighted counts. */
1069 tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
1070 tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
1072 tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
1073 tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
/* Builds the "Coordinates of the Curve" table: one row per cutpoint of
   each test variable, showing the cutpoint, the sensitivity and
   (1 - specificity).  The cutpoints are read from a clone of
   rs[i].cutpoint_rdr, and the clone is destroyed afterwards.  A leading
   "Test variable" column is present only when there is more than one
   test variable.

   NOTE(review): the declarations of i, x, cc and n_rows, the operators
   inside the se/sp expressions and the call that submits the table are
   on lines missing from this listing. */
1080 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1084 const int n_cols = roc->n_vars > 1 ? 4 : 3;
1086 struct tab_table *tbl ;
/* The table needs one row for every cutpoint of every variable. */
1088 for (i = 0; i < roc->n_vars; ++i)
1089 n_rows += casereader_count_cases (rs[i].cutpoint_rdr);
1091 tbl = tab_create (n_cols, n_rows, 0);
1093 if ( roc->n_vars > 1)
1094 tab_title (tbl, _("Coordinates of the Curve"));
1096 tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0]));
1099 tab_headers (tbl, 1, 0, 1, 0);
1101 tab_dim (tbl, tab_natural_dimensions, NULL);
1103 tab_hline (tbl, TAL_2, 0, n_cols - 1, 1);
1105 if ( roc->n_vars > 1)
1106 tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable"));
1108 tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to"));
1109 tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity"));
1110 tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity"));
1119 if ( roc->n_vars > 1)
1120 tab_vline (tbl, TAL_2, 1, 0, n_rows - 1);
1122 for (i = 0; i < roc->n_vars; ++i)
1125 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1127 if ( roc->n_vars > 1)
1128 tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i]));
1131 tab_hline (tbl, TAL_1, 0, n_cols - 1, x);
1134 for (; (cc = casereader_read (r)) != NULL;
1135 case_unref (cc), x++)
/* Presumably se = TP / (TP + FN) and sp = TN / (TN + FP); the lines
   holding the parentheses and `+' operators are not visible. */
1137 const double se = case_data_idx (cc, TP)->f /
1139 case_data_idx (cc, TP)->f
1141 case_data_idx (cc, FN)->f
1144 const double sp = case_data_idx (cc, TN)->f /
1146 case_data_idx (cc, TN)->f
1148 case_data_idx (cc, FP)->f
/* The cutpoint is shown in the test variable's own print format. */
1151 tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f,
1152 var_get_print_format (roc->vars[i]));
1154 tab_double (tbl, n_cols - 2, x, 0, se, NULL);
1155 tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL);
1158 casereader_destroy (r);
/* Creates and submits the "ROC Curve" chart.  Both axes are scaled
   from 0 to 1; the x axis is (1 - specificity) and the y axis is
   sensitivity.  One vector is drawn per test variable, labelled with
   the variable's name, using the cutpoints read from a clone of
   rs[i].cutpoint_rdr.  A reference line is added when roc->reference
   is set.

   NOTE(review): the declarations of i and cc, the remaining arguments
   of chart_line and the loop's increment expression are on lines
   missing from this listing. */
1166 draw_roc (struct roc_state *rs, const struct cmd_roc *roc)
1170 struct chart *roc_chart = chart_create ();
1172 chart_write_title (roc_chart, _("ROC Curve"));
1173 chart_write_xlabel (roc_chart, _("1 - Specificity"));
1174 chart_write_ylabel (roc_chart, _("Sensitivity"));
1176 chart_write_xscale (roc_chart, 0, 1, 5);
1177 chart_write_yscale (roc_chart, 0, 1, 5);
1179 if ( roc->reference )
1181 chart_line (roc_chart, 1.0, 0,
1186 for (i = 0; i < roc->n_vars; ++i)
1189 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1191 chart_vector_start (roc_chart, var_get_name (roc->vars[i]));
1192 for (; (cc = casereader_read (r)) != NULL;
/* se = TP / (FN + TP); sp = TN / (TN + FP). */
1195 double se = case_data_idx (cc, TP)->f;
1196 double sp = case_data_idx (cc, TN)->f;
1198 se /= case_data_idx (cc, FN)->f +
1199 case_data_idx (cc, TP)->f ;
1201 sp /= case_data_idx (cc, TN)->f +
1202 case_data_idx (cc, FP)->f ;
1204 chart_vector (roc_chart, 1 - sp, se);
1206 chart_vector_end (roc_chart);
1207 casereader_destroy (r);
1210 chart_write_legend (roc_chart);
1212 chart_submit (roc_chart);
/* Output entry point for one analysed group.  The only step visible in
   this listing is the coordinates table, which is produced when
   roc->print_coords is set.  NOTE(review): the other output calls and
   the end of the function are on lines not shown here. */
1217 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1227 if ( roc->print_coords )
1228 show_coords (rs, roc);