1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include "language/commands/roc.h"
21 #include <gsl/gsl_cdf.h>
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/charts/roc-chart.h"
37 #include "output/pivot-table.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) msgid
/* Test result variables: the numeric variables listed before BY.  */
46 const struct variable **vars;
/* Dictionary whose weight variable match_positives() consults.  */
47 const struct dictionary *dict;
/* STATE_VAR is the variable named after BY.  A case has positive actual
   state iff its STATE_VAR value equals STATE_VALUE, the value given in
   parentheses after it.  */
49 const struct variable *state_var;
50 union value state_value;
/* Width of STATE_VALUE; -1 until the state variable has been parsed.
   NOTE(review): this is a size_t, so -1 is stored as SIZE_MAX, yet the
   value comes from var_get_width() and is handed to value_init() and
   value_destroy(), which presumably take an int width.  Declaring the
   field as int would avoid the signed/unsigned round trip.  */
51 size_t state_var_width;
53 /* Plot the roc curve */
55 /* Plot the reference line */
62 bool bi_neg_exp; /* True iff the bi-negative exponential criteria
64 enum mv_class exclude;
66 bool invert; /* True iff a smaller test result variable indicates
/* run_roc() splits the active dataset into groups and calls do_roc() to
   analyse each group.  */
75 static int run_roc (struct dataset *, struct cmd_roc *);
76 static void do_roc (struct cmd_roc *, struct casereader *, struct dictionary *);
/* Parses and runs the ROC command:

     ROC varlist BY state_var (state_value)
       [/MISSING={INCLUDE|EXCLUDE}]
       [/PLOT={CURVE[(REFERENCE)]|NONE}]
       [/PRINT=[SE] [COORDINATES]]
       [/CRITERIA=[CUTOFF({INCLUDE|EXCLUDE})] [TESTPOS({LARGE|SMALL})]
                  [CI(n)] [DISTRIBUTION({FREE|NEGEXPO})]]

   The test variables must be numeric.  Syntax errors are reported through
   the lexer's error functions.  */
80 cmd_roc (struct lexer *lexer, struct dataset *ds)
82 const struct dictionary *dict = dataset_dict (ds);
84 struct cmd_roc roc = {
89 .state_var_width = -1,
/* A slash before the variable list is optional.  */
92 lex_match (lexer, T_SLASH);
93 if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
94 PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
97 if (!lex_force_match (lexer, T_BY))
100 roc.state_var = parse_variable (lexer, dict);
104 if (!lex_force_match (lexer, T_LPAREN))
/* Parse the state value at the state variable's width so that
   match_positives() can compare it with case data.  */
107 roc.state_var_width = var_get_width (roc.state_var);
108 value_init (&roc.state_value, roc.state_var_width);
109 if (!parse_value (lexer, &roc.state_value, roc.state_var)
110 || !lex_force_match (lexer, T_RPAREN))
/* Optional subcommands, each optionally preceded by a slash.  */
113 while (lex_token (lexer) != T_ENDCMD)
115 lex_match (lexer, T_SLASH);
116 if (lex_match_id (lexer, "MISSING"))
118 lex_match (lexer, T_EQUALS);
119 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
/* INCLUDE excludes only system-missing values; EXCLUDE excludes
   user-missing values as well.  */
121 if (lex_match_id (lexer, "INCLUDE"))
122 roc.exclude = MV_SYSTEM;
123 else if (lex_match_id (lexer, "EXCLUDE"))
124 roc.exclude = MV_ANY;
127 lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
132 else if (lex_match_id (lexer, "PLOT"))
134 lex_match (lexer, T_EQUALS);
135 if (lex_match_id (lexer, "CURVE"))
/* CURVE(REFERENCE) also requests the reference line.  */
138 if (lex_match (lexer, T_LPAREN))
140 roc.reference = true;
141 if (!lex_force_match_id (lexer, "REFERENCE")
142 || !lex_force_match (lexer, T_RPAREN))
146 else if (lex_match_id (lexer, "NONE"))
150 lex_error_expecting (lexer, "CURVE", "NONE");
154 else if (lex_match_id (lexer, "PRINT"))
156 lex_match (lexer, T_EQUALS);
157 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
159 if (lex_match_id (lexer, "SE"))
161 else if (lex_match_id (lexer, "COORDINATES"))
162 roc.print_coords = true;
165 lex_error_expecting (lexer, "SE", "COORDINATES");
170 else if (lex_match_id (lexer, "CRITERIA"))
172 lex_match (lexer, T_EQUALS);
173 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
175 if (lex_match_id (lexer, "CUTOFF"))
177 if (!lex_force_match (lexer, T_LPAREN))
/* NOTE(review): CUTOFF writes the same ROC.exclude field as /MISSING,
   so whichever of the two appears last wins.  MV_USER | MV_SYSTEM here
   is presumably the same set as MV_ANY under /MISSING.  */
179 if (lex_match_id (lexer, "INCLUDE"))
180 roc.exclude = MV_SYSTEM;
181 else if (lex_match_id (lexer, "EXCLUDE"))
182 roc.exclude = MV_USER | MV_SYSTEM;
185 lex_error_expecting (lexer, "INCLUDE", "EXCLUDE");
188 if (!lex_force_match (lexer, T_RPAREN))
191 else if (lex_match_id (lexer, "TESTPOS"))
193 if (!lex_force_match (lexer, T_LPAREN))
/* Whether large or small test results indicate a positive state
   (presumably recorded in ROC.invert).  */
195 if (lex_match_id (lexer, "LARGE"))
197 else if (lex_match_id (lexer, "SMALL"))
201 lex_error_expecting (lexer, "LARGE", "SMALL");
204 if (!lex_force_match (lexer, T_RPAREN))
207 else if (lex_match_id (lexer, "CI"))
/* NOTE(review): the confidence level is not range-checked here, but
   show_auc() uses 1 - ci / 100 as a probability, so values outside the
   open interval (0, 100) presumably yield meaningless bounds.  */
209 if (!lex_force_match (lexer, T_LPAREN))
211 if (!lex_force_num (lexer))
213 roc.ci = lex_number (lexer);
215 if (!lex_force_match (lexer, T_RPAREN))
218 else if (lex_match_id (lexer, "DISTRIBUTION"))
/* FREE: distribution-free standard error; NEGEXPO: bi-negative
   exponential model.  */
220 if (!lex_force_match (lexer, T_LPAREN))
222 if (lex_match_id (lexer, "FREE"))
223 roc.bi_neg_exp = false;
224 else if (lex_match_id (lexer, "NEGEXPO"))
225 roc.bi_neg_exp = true;
228 lex_error_expecting (lexer, "FREE", "NEGEXPO");
231 if (!lex_force_match (lexer, T_RPAREN))
236 lex_error_expecting (lexer, "CUTOFF", "TESTPOS", "CI",
244 lex_error_expecting (lexer, "MISSING", "PLOT", "PRINT", "CRITERIA");
/* The two value_destroy() calls below presumably belong to the success
   and error exits; both use the recorded width (still -1 if parsing
   stopped before the state variable).  */
249 if (!run_roc (ds, &roc))
253 value_destroy (&roc.state_value, roc.state_var_width);
259 value_destroy (&roc.state_value, roc.state_var_width);
/* Runs the analysis on DS: opens the procedure, calls do_roc() once for each
   group produced by casegrouper_create_splits(), then commits.  OK combines
   the grouper's status with proc_commit()'s.  */
265 run_roc (struct dataset *ds, struct cmd_roc *roc)
267 struct dictionary *dict = dataset_dict (ds);
268 struct casereader *group;
270 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
/* NOTE(review): dataset_dict (ds) below is the same dictionary as DICT;
   passing DICT would read more clearly.  */
271 while (casegrouper_get_next_group (grouper, &group))
272 do_roc (roc, group, dataset_dict (ds));
274 bool ok = casegrouper_destroy (grouper);
/* proc_commit() is evaluated first, so it runs even if the grouper
   reported an error.  */
275 ok = proc_commit (ds) && ok;
/* Debugging aid: writes the values of every case in READER to stdout with
   "%g".  It reads from a clone of READER and destroys only the clone.
   NOTE(review): every value is fetched with case_num_idx(), so this
   presumably assumes the cases contain only numeric values.  */
281 dump_casereader (struct casereader *reader)
284 struct casereader *r = casereader_clone (reader);
286 for (; (c = casereader_read (r)); case_unref (c))
288 for (size_t i = 0; i < case_get_n_values (c); ++i)
289 printf ("%g ", case_num_idx (c, i));
293 casereader_destroy (r);
299 Return true iff the state variable indicates that C has positive actual state.
301 As a side effect, this function also accumulates the roc->{pos,neg} and
302 roc->{pos,neg}_weighted counts.
305 match_positives (const struct ccase *c, void *aux)
307 struct cmd_roc *roc = aux;
/* Without a weight variable, every case has weight 1.  */
308 const struct variable *wv = dict_get_weight (roc->dict);
309 const double weight = wv ? case_num (c, wv) : 1.0;
/* Compare the case's state value with the value from the command line,
   at the state variable's width.  */
311 const bool positive =
312 (0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
313 var_get_width (roc->state_var)));
/* Positive case: add its weight to the weighted positive count.  */
318 roc->pos_weighted += weight;
/* Negative case: add its weight to the weighted negative count.  */
323 roc->neg_weighted += weight;
334 /* Some intermediate state for calculating the cutpoints and the
335 standard error values.  prepare_cutpoints() allocates one per test variable. */
338 double auc; /* Area under the curve */
340 double n1; /* total weight of positives */
341 double n2; /* total weight of negatives */
343 /* intermediates for standard error */
347 /* intermediates for cutpoints */
/* Writer that collects the cutpoints, sorted in ascending order.  */
348 struct casewriter *cutpoint_wtr;
/* Reader made from CUTPOINT_WTR: one case per cutpoint holding
   (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP).  */
349 struct casereader *cutpoint_rdr;
356 Return a new casereader based upon INPUT, which is destroyed.
357 The number of "positive" cases are placed into
358 the position TRUE_INDEX, and the number of "negative" cases
360 POS_COND and RESULT determine the semantics of what is
362 WEIGHT is the value of a single count.
364 static struct casereader *
365 accumulate_counts (struct casereader *input,
366 double result, double weight,
367 bool (*pos_cond) (double, double),
368 int true_index, int false_index)
/* NOTE(review): process_group() calls this once per distinct test value,
   and every call copies the whole cutpoint reader, so the total cost is
   proportional to (distinct values) x (cutpoints).  */
370 const struct caseproto *proto = casereader_get_proto (input);
371 struct casewriter *w =
372 autopaging_writer_create (proto);
/* Cutpoint of the previous case, presumably compared with CP to skip
   duplicate cutpoints.  It starts as SYSMIS, which the assertion below
   rules out as a real cutpoint.  */
374 double prev_cp = SYSMIS;
376 for (; (cpc = casereader_read (input)); case_unref (cpc))
378 struct ccase *new_case;
379 const double cp = case_num_idx (cpc, ROC_CUTPOINT);
381 assert (cp != SYSMIS);
383 /* We don't want duplicates here */
387 new_case = case_clone (cpc);
/* WEIGHT goes to TRUE_INDEX when POS_COND (RESULT, CP) holds, otherwise
   to FALSE_INDEX.  */
389 int index = pos_cond (result, cp) ? true_index : false_index;
390 *case_num_rw_idx (new_case, index) += weight;
394 casewriter_write (w, new_case);
/* INPUT has been fully read; the caller receives the new reader.  */
396 casereader_destroy (input);
398 return casewriter_make_reader (w);
/* Forward declaration: output_roc() is defined below and called from
   do_roc().  */
403 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
406 This function does 3 things:
408 1. Counts, for each value in READER, the cases whose value equals it
409 and the cases whose value D2 satisfies PRED (D2, value) (PRED is normally
410 either > or <).  VAR is the variable defining a case's value
413 2. Counts the number of true and false cases in reader, and populates
414 CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices
415 which receive these values. POS_COND is the condition defining true
418 3. CC is filled with the cumulative weight of all cases of READER.
420 static struct casereader *
421 process_group (const struct variable *var, struct casereader *reader,
422 bool (*pred) (double, double),
423 const struct dictionary *dict,
425 struct casereader **cutpoint_rdr,
426 bool (*pos_cond) (double, double),
430 const struct variable *w = dict_get_weight (dict);
/* Sort READER by VAR and collapse it to one case per distinct value,
   presumably carrying the combined weight of the collapsed cases.  */
432 struct casereader *r1 =
433 casereader_create_distinct (sort_execute_1var (reader, var), var, w);
/* Where that weight lives: the weight variable's own column or, for
   unweighted data, presumably the column that
   casereader_create_distinct() appends last.  */
435 const int weight_idx = w ? var_get_case_index (w) :
436 caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
440 struct casereader *rclone = casereader_clone (r1);
441 struct casewriter *wtr;
/* Output cases hold three numbers: VALUE, N_EQ and N_PRED.  */
442 struct caseproto *proto = caseproto_create ();
444 proto = caseproto_add_width (proto, 0);
445 proto = caseproto_add_width (proto, 0);
446 proto = caseproto_add_width (proto, 0);
448 wtr = autopaging_writer_create (proto);
452 for (; (c1 = casereader_read (r1)); case_unref (c1))
454 struct ccase *new_case = case_create (proto);
/* Every distinct value is compared against all distinct values again
   through a fresh clone, so this loop is quadratic in the number of
   distinct values.  */
456 struct casereader *r2 = casereader_clone (rclone);
458 const double weight1 = case_num_idx (c1, weight_idx);
459 const double d1 = case_num (c1, var);
/* Add this value's weight to the cutpoint counts.  */
463 *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
465 true_index, false_index);
469 for (; (c2 = casereader_read (r2)); case_unref (c2))
471 const double d2 = case_num (c2, var);
472 const double weight2 = case_num_idx (c2, weight_idx);
/* A D2 equal to D1 presumably adds to N_EQ; otherwise D2 adds to
   N_PRED when PRED (D2, D1) holds.  */
479 else if (pred (d2, d1))
485 *case_num_rw_idx (new_case, VALUE) = d1;
486 *case_num_rw_idx (new_case, N_EQ) = n_eq;
487 *case_num_rw_idx (new_case, N_PRED) = n_pred;
489 casewriter_write (wtr, new_case);
491 casereader_destroy (r2);
494 casereader_destroy (r1);
495 casereader_destroy (rclone);
497 caseproto_unref (proto);
499 return casewriter_make_reader (wtr);
502 /* Some more indices into case data: the five-value cases built by do_roc(), which also hold VALUE */
503 #define N_POS_EQ 1 /* number of positive cases with values equal to n */
504 #define N_POS_GT 2 /* number of positive cases with values greater than n */
505 #define N_NEG_EQ 3 /* number of negative cases with values equal to n */
506 #define N_NEG_LT 4 /* number of negative cases with values less than n */
/* Comparison callbacks for process_group().  gt and lt are passed as PRED
   by process_positive_group() and process_negative_group().  Judging by
   their names, they presumably return d1 > d2, d1 >= d2 and d1 < d2.  */
509 gt (double d1, double d2)
516 ge (double d1, double d2)
522 lt (double d1, double d2)
529 Return a casereader with width 3,
530 populated with cases based upon READER.
531 The cases will have the values:
532 (N, number of cases equal to N, number of cases greater than N)
533 As a side effect, set RS->n1 to the total weight of the positive cases.
535 static struct casereader *
536 process_positive_group (const struct variable *var, struct casereader *reader,
537 const struct dictionary *dict,
538 struct roc_state *rs)
/* With GT as PRED, N_PRED is the weight of cases with greater values;
   the total weight goes to RS->n1.  */
540 return process_group (var, reader, gt, dict, &rs->n1,
547 Return a casereader with width 3,
548 populated with cases based upon READER.
549 The cases will have the values:
550 (N, number of cases equal to N, number of cases less than N)
551 As a side effect, set RS->n2 to the total weight of the negative cases.
553 static struct casereader *
554 process_negative_group (const struct variable *var, struct casereader *reader,
555 const struct dictionary *dict,
556 struct roc_state *rs)
/* With LT as PRED, N_PRED is the weight of cases with smaller values;
   the total weight goes to RS->n2.  */
558 return process_group (var, reader, lt, dict, &rs->n2,
/* Appends to WRITER a case for CUTPOINT whose ROC_TP, ROC_FN, ROC_TN and
   ROC_FP counts start at zero; accumulate_counts() adds to them later.  */
568 append_cutpoint (struct casewriter *writer, double cutpoint)
570 struct ccase *cc = case_create (casewriter_get_proto (writer));
572 *case_num_rw_idx (cc, ROC_CUTPOINT) = cutpoint;
573 *case_num_rw_idx (cc, ROC_TP) = 0;
574 *case_num_rw_idx (cc, ROC_FN) = 0;
575 *case_num_rw_idx (cc, ROC_TN) = 0;
576 *case_num_rw_idx (cc, ROC_FP) = 0;
578 casewriter_write (writer, cc);
582 Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the
583 readers will be created with width 5, ready to take the values (cutpoint,
584 ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the reader will be populated with its
585 final number of cases. However on exit from this function, only
586 ROC_CUTPOINT entries will be set to their final value. The other entries
587 will be initialised to zero.
589 static struct roc_state *
590 prepare_cutpoints (struct cmd_roc *roc, struct casereader *input)
/* Scan a clone so that the caller can still read INPUT afterwards.  */
592 struct casereader *r = casereader_clone (input);
/* The cutpoint writers sort their cases by ascending cutpoint.  */
595 struct subcase ordering;
596 subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
598 struct caseproto *proto = caseproto_create ();
599 proto = caseproto_add_width (proto, 0); /* cutpoint */
600 proto = caseproto_add_width (proto, 0); /* ROC_TP */
601 proto = caseproto_add_width (proto, 0); /* ROC_FN */
602 proto = caseproto_add_width (proto, 0); /* ROC_TN */
603 proto = caseproto_add_width (proto, 0); /* ROC_FP */
/* One state per test variable, handed back to the caller.  */
605 struct roc_state *rs = xnmalloc (roc->n_vars, sizeof *rs);
606 for (size_t i = 0; i < roc->n_vars; ++i)
607 rs[i] = (struct roc_state) {
608 .cutpoint_wtr = sort_create_writer (&ordering, proto),
609 .prev_result = SYSMIS,
614 caseproto_unref (proto);
615 subcase_uninit (&ordering);
617 for (; (c = casereader_read (r)) != NULL; case_unref (c))
618 for (size_t i = 0; i < roc->n_vars; ++i)
620 const union value *v = case_data (c, roc->vars[i]);
621 const double result = v->f;
/* Values missing for this variable are presumably skipped, subject to
   ROC->exclude.  */
623 if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v)
627 minimize (&rs[i].min, result);
628 maximize (&rs[i].max, result);
/* Each change of value between consecutive cases contributes the midpoint
   of the two values as a cutpoint.
   NOTE(review): R is read in INPUT's order, which is not sorted by this
   variable here (and cannot be sorted by every test variable at once).
   With unsorted data the midpoints therefore fall between consecutive
   cases rather than between adjacent distinct values, so some cutpoints
   between neighbouring values may be missing.  Confirm whether sorted
   input is assumed.  */
630 if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
632 const double mean = (result + rs[i].prev_result) / 2.0;
633 append_cutpoint (rs[i].cutpoint_wtr, mean);
636 rs[i].prev_result = result;
638 casereader_destroy (r);
640 /* Append the min and max cutpoints */
641 for (size_t i = 0; i < roc->n_vars; ++i)
643 append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
644 append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
646 rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
/* Performs the ROC analysis for one group of cases, READER (which is
   consumed), and outputs the results.  DICT supplies the weight
   variable.  */
653 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
/* Drop cases whose test variables or state variable are missing, as
   selected by ROC->exclude.  */
655 struct casereader *input = casereader_create_filter_missing (
656 reader, roc->vars, roc->n_vars, roc->exclude, NULL, NULL);
657 input = casereader_create_filter_missing (
658 input, &roc->state_var, 1, roc->exclude, NULL, NULL);
660 struct roc_state *rs = prepare_cutpoints (roc, input);
662 /* Separate the positive actual state cases from the negative ones */
/* Reading POSITIVES yields the cases for which match_positives() returns
   true and presumably diverts the others to NEG_WTR.  match_positives()
   also accumulates the case counts in ROC.  */
663 struct casewriter *neg_wtr
664 = autopaging_writer_create (casereader_get_proto (input));
665 struct casereader *positives = casereader_create_filter_func (
666 input, match_positives, NULL, roc, neg_wtr);
/* Merged cases hold five numbers: VALUE, N_POS_EQ, N_POS_GT, N_NEG_EQ
   and N_NEG_LT.  */
668 struct caseproto *n_proto = caseproto_create ();
669 for (size_t i = 0; i < 5; i++)
670 n_proto = caseproto_add_width (n_proto, 0);
672 struct subcase up_ordering;
673 struct subcase down_ordering;
674 subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
675 subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
677 struct casereader *negatives = NULL;
678 for (size_t i = 0; i < roc->n_vars; ++i)
680 const struct variable *var = roc->vars[i];
682 struct casereader *pos = casereader_clone (positives);
684 struct casereader *n_pos_reader =
685 process_positive_group (var, pos, dict, &rs[i]);
/* NEG_WTR is only complete after POSITIVES has been read once (by
   process_positive_group() above), which is why the negative reader is
   made here.  Presumably this happens only on the first iteration,
   since a writer can be turned into a reader only once.  */
688 negatives = casewriter_make_reader (neg_wtr);
690 struct casereader *neg = casereader_clone (negatives);
691 struct casereader *n_neg_reader
692 = process_negative_group (var, neg, dict, &rs[i]);
694 /* Merge the n_pos and n_neg casereaders */
695 struct casewriter *w = sort_create_writer (&up_ordering, n_proto);
697 for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
699 struct ccase *pos_case = case_create (n_proto);
700 const double jpos = case_num_idx (cpos, VALUE);
/* Negative-group cases get N_POS_EQ = 0 and a SYSMIS N_POS_GT, which
   the first propagation pass below fills in.  */
703 while ((cneg = casereader_read (n_neg_reader)))
705 struct ccase *nc = case_create (n_proto);
707 const double jneg = case_num_idx (cneg, VALUE);
709 *case_num_rw_idx (nc, VALUE) = jneg;
710 *case_num_rw_idx (nc, N_POS_EQ) = 0;
712 *case_num_rw_idx (nc, N_POS_GT) = SYSMIS;
714 *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
715 *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
717 casewriter_write (w, nc);
/* Positive-group cases get N_NEG_EQ = 0 and a SYSMIS N_NEG_LT, which
   the second propagation pass below fills in.  */
724 *case_num_rw_idx (pos_case, VALUE) = jpos;
725 *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
726 *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
727 *case_num_rw_idx (pos_case, N_NEG_EQ) = 0;
728 *case_num_rw_idx (pos_case, N_NEG_LT) = SYSMIS;
730 casewriter_write (w, pos_case);
733 casereader_destroy (n_pos_reader);
734 casereader_destroy (n_neg_reader);
736 struct casereader *r = casewriter_make_reader (w);
738 /* Propagate the N_POS_GT values from the positive cases
739 to the negative ones */
/* R came from an ascending writer, so this pass reads in ascending VALUE
   order.  A SYSMIS N_POS_GT takes the value of the nearest earlier case,
   or RS[I].N1 (all positive weight) if there is none.  The output is
   sorted in descending order for the next pass.  */
740 double prev_pos_gt = rs[i].n1;
741 w = sort_create_writer (&down_ordering, n_proto);
744 for (; (c = casereader_read (r)); case_unref (c))
746 double n_pos_gt = case_num_idx (c, N_POS_GT);
747 struct ccase *nc = case_clone (c);
749 if (n_pos_gt == SYSMIS)
751 n_pos_gt = prev_pos_gt;
752 *case_num_rw_idx (nc, N_POS_GT) = n_pos_gt;
755 casewriter_write (w, nc);
756 prev_pos_gt = n_pos_gt;
758 casereader_destroy (r);
759 r = casewriter_make_reader (w);
761 /* Propagate the N_NEG_LT values from the negative cases
762 to the positive ones */
/* This pass reads in descending VALUE order, so a SYSMIS N_NEG_LT takes
   the value of the nearest earlier (larger-valued) case, or RS[I].N2 if
   there is none.  The output is sorted in ascending order again.  */
763 double prev_neg_lt = rs[i].n2;
764 w = sort_create_writer (&up_ordering, n_proto);
766 for (; (c = casereader_read (r)); case_unref (c))
768 double n_neg_lt = case_num_idx (c, N_NEG_LT);
769 struct ccase *nc = case_clone (c);
771 if (n_neg_lt == SYSMIS)
773 n_neg_lt = prev_neg_lt;
774 *case_num_rw_idx (nc, N_NEG_LT) = n_neg_lt;
777 casewriter_write (w, nc);
778 prev_neg_lt = n_neg_lt;
781 casereader_destroy (r);
782 r = casewriter_make_reader (w);
/* Final pass, in ascending VALUE order.  When a value has both a positive
   and a negative case, a case with a zero EQ count borrows the previous
   case's counts.  The sums are accumulated only on the last case with
   that value.  */
784 struct ccase *prev_case = NULL;
785 for (; (c = casereader_read (r)); case_unref (c))
787 struct ccase *next_case = casereader_peek (r, 0);
789 const double j = case_num_idx (c, VALUE);
790 double n_pos_eq = case_num_idx (c, N_POS_EQ);
791 double n_pos_gt = case_num_idx (c, N_POS_GT);
792 double n_neg_eq = case_num_idx (c, N_NEG_EQ);
793 double n_neg_lt = case_num_idx (c, N_NEG_LT);
795 if (prev_case && j == case_num_idx (prev_case, VALUE))
797 if (0 == case_num_idx (c, N_POS_EQ))
799 n_pos_eq = case_num_idx (prev_case, N_POS_EQ);
800 n_pos_gt = case_num_idx (prev_case, N_POS_GT);
803 if (0 == case_num_idx (c, N_NEG_EQ))
805 n_neg_eq = case_num_idx (prev_case, N_NEG_EQ);
806 n_neg_lt = case_num_idx (prev_case, N_NEG_LT);
810 if (NULL == next_case || j != case_num_idx (next_case, VALUE))
/* Each positive above this value counts fully for the negatives at this
   value, and each tie counts half (the Mann-Whitney statistic).  */
812 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
815 n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
817 n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
/* Release the reference returned by casereader_peek(), and keep a copy
   of C as PREV_CASE for the tie check on the next iteration.  */
821 case_unref (next_case);
822 case_unref (prev_case);
823 prev_case = case_clone (c);
825 casereader_destroy (r);
826 case_unref (prev_case);
/* Normalise by the weighted number of positive/negative pairs.
   NOTE(review): this divides by zero when the group has no positive or
   no negative cases.  */
828 rs[i].auc /= rs[i].n1 * rs[i].n2;
/* Presumably applied only for TESTPOS(SMALL), i.e. when ROC->invert is
   set.  */
830 rs[i].auc = 1 - rs[i].auc;
/* Bi-negative exponential model (presumably when ROC->bi_neg_exp):
   Q1 = A / (2 - A) and Q2 = 2A^2 / (1 + A), as in Hanley and McNeil.  */
834 rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
835 rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
/* Otherwise (presumably the distribution-free case), scale the sums
   accumulated above.  */
839 rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
840 rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
/* POSITIVES and NEGATIVES were only read through clones; release the
   originals now.  */
844 casereader_destroy (positives);
845 casereader_destroy (negatives);
847 caseproto_unref (n_proto);
848 subcase_uninit (&up_ordering);
849 subcase_uninit (&down_ordering);
/* Produce the output, then release the per-variable cutpoint readers.  */
851 output_roc (rs, roc);
853 for (size_t i = 0; i < roc->n_vars; ++i)
854 casereader_destroy (rs[i].cutpoint_rdr);
/* Outputs the "Area Under the Curve" table, with one row per test variable.
   The Area column comes from RS[i].auc.  The standard error, asymptotic
   significance and confidence interval columns are computed below.  */
860 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
862 struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
864 struct pivot_dimension *statistics = pivot_dimension_create (
865 table, PIVOT_AXIS_COLUMN, N_("Statistics"),
866 N_("Area"), PIVOT_RC_OTHER);
869 pivot_category_create_leaves (
871 N_("Std. Error"), PIVOT_RC_OTHER,
872 N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
873 struct pivot_category *interval = pivot_category_create_group__ (
875 pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
877 pivot_category_create_leaves (interval,
878 N_("Lower Bound"), PIVOT_RC_OTHER,
879 N_("Upper Bound"), PIVOT_RC_OTHER);
882 struct pivot_dimension *variables = pivot_dimension_create (
883 table, PIVOT_AXIS_ROW, N_("Variable under test"));
884 variables->root->show_label = true;
886 for (size_t i = 0; i < roc->n_vars; ++i)
888 int var_idx = pivot_category_create_leaf (
889 variables->root, pivot_value_new_variable (roc->vars[i]));
891 pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
/* Variance of the AUC (Hanley and McNeil):
     (A(1-A) + (n1-1)(Q1-A^2) + (n2-1)(Q2-A^2)) / (n1 n2).
   NOTE(review): SE holds this variance at first.  Confirm that it is
   converted to a standard deviation (square root) before use below.  */
895 double se = (rs[i].auc * (1 - rs[i].auc)
896 + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
897 + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
898 se /= rs[i].n1 * rs[i].n2;
/* NOTE(review): gsl_cdf_gaussian_Qinv (P, se) is the upper-tail quantile,
   so with P = 1 - ci/100 the half-width is about 1.645 SE for CI = 95,
   which is a one-sided bound.  A two-sided interval would use
   P = (1 - ci/100) / 2 (about 1.96 SE).  Confirm which is intended.  */
901 double ci = 1 - roc->ci / 100.0;
902 double yy = gsl_cdf_gaussian_Qinv (ci, se);
/* Two-tailed normal test of AUC = 0.5, using the standard deviation
   under that null hypothesis, sqrt ((n1 + n2 + 1) / (12 n1 n2)).  */
904 double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
905 (12 * rs[i].n1 * rs[i].n2));
906 double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
908 double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
/* NOTE(review): this loop's I shadows the outer loop's I.  It is valid C
   but triggers -Wshadow; renaming it (e.g. to J) would be clearer.  */
909 for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
910 pivot_table_put2 (table, i + 1, var_idx,
911 pivot_value_new_number (entries[i]));
915 pivot_table_submit (table);
/* Outputs the "Case Summary" table: unweighted and weighted counts of
   positive and negative cases, labelled by the state variable.
   NOTE(review): the counts are accumulated by match_positives(), and no
   reset of them is visible in do_roc().  If this table is produced once
   per SPLIT FILE group, it may show running totals across groups —
   confirm.  */
919 show_summary (const struct cmd_roc *roc)
921 struct pivot_table *table = pivot_table_create (N_("Case Summary"));
923 struct pivot_dimension *statistics = pivot_dimension_create (
924 table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
925 N_("Unweighted"), PIVOT_RC_INTEGER,
926 N_("Weighted"), PIVOT_RC_OTHER);
927 statistics->root->show_label = true;
929 struct pivot_dimension *cases = pivot_dimension_create__ (
930 table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
931 cases->root->show_label = true;
932 pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
/* Entries are { statistic, case row, value }: statistic 0 = unweighted,
   1 = weighted; row 0 = positive, 1 = negative.  */
943 { 1, 0, roc->pos_weighted },
944 { 1, 1, roc->neg_weighted },
946 for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
948 const struct entry *e = &entries[i];
949 pivot_table_put2 (table, e->stat_idx, e->case_idx,
950 pivot_value_new_number (e->x));
952 pivot_table_submit (table);
/* Outputs the "Coordinates of the Curve" table.  For each test variable
   and each of its cutpoints it shows the cutpoint, the sensitivity
   TP / (TP + FN) and 1 - specificity, where specificity is
   TN / (TN + FP).  */
956 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
958 struct pivot_table *table = pivot_table_create (
959 N_("Coordinates of the Curve"));
961 pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
962 N_("Positive if greater than or equal to"),
963 N_("Sensitivity"), N_("1 - Specificity"));
965 struct pivot_dimension *coordinates = pivot_dimension_create (
966 table, PIVOT_AXIS_ROW, N_("Coordinates"));
967 coordinates->hide_all_labels = true;
969 struct pivot_dimension *variables = pivot_dimension_create (
970 table, PIVOT_AXIS_ROW, N_("Test variable"));
971 variables->root->show_label = true;
975 for (size_t i = 0; i < roc->n_vars; ++i)
/* Read a clone so that the cutpoint reader stays usable by the caller.  */
977 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
979 int var_idx = pivot_category_create_leaf (
980 variables->root, pivot_value_new_variable (roc->vars[i]));
984 for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
/* NOTE(review): TP + FN and TN + FP are presumably the total positive and
   negative weights, so a group with no positive (or no negative) cases
   gives a zero denominator and a NaN here.  */
986 const double se = case_num_idx (cc, ROC_TP) /
987 (case_num_idx (cc, ROC_TP) + case_num_idx (cc, ROC_FN));
989 const double sp = case_num_idx (cc, ROC_TN) /
990 (case_num_idx (cc, ROC_TN) + case_num_idx (cc, ROC_FP));
/* All test variables share the coordinate rows, so a new row is created
   only when this variable has more cutpoints than any earlier one.  */
992 if (coord_idx >= n_coords)
994 assert (coord_idx == n_coords);
995 pivot_category_create_leaf (
996 coordinates->root, pivot_value_new_integer (++n_coords));
1000 table, 0, coord_idx, var_idx,
1001 pivot_value_new_var_value (roc->vars[i],
1002 case_data_idx (cc, ROC_CUTPOINT)));
1004 pivot_table_put3 (table, 1, coord_idx, var_idx,
1005 pivot_value_new_number (se));
1006 pivot_table_put3 (table, 2, coord_idx, var_idx,
1007 pivot_value_new_number (1 - sp));
1011 casereader_destroy (r);
1014 pivot_table_submit (table);
/* Produces the output for one group of cases from the per-variable states
   RS.  This includes the ROC chart (presumably only with PLOT=CURVE) and,
   when PRINT=COORDINATES was given, the coordinates table.  */
1018 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
/* One curve per test variable, plus the reference line if
   CURVE(REFERENCE) was requested.  The cutpoint readers are passed
   without cloning.  Because show_coords() and do_roc() use them
   afterwards, roc_chart_add_var() presumably clones them.  */
1024 struct roc_chart *rc = roc_chart_create (roc->reference);
1025 for (size_t i = 0; i < roc->n_vars; i++)
1026 roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1027 rs[i].cutpoint_rdr);
1028 roc_chart_submit (rc);
1033 if (roc->print_coords)
1034 show_coords (rs, roc);