1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include "language/stats/roc.h"
21 #include <gsl/gsl_cdf.h>
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/charts/roc-chart.h"
37 #include "output/pivot-table.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) msgid
46 const struct variable **vars;
47 const struct dictionary *dict;
49 const struct variable *state_var;
50 union value state_value;
51 size_t state_var_width;
53 /* Plot the roc curve */
55 /* Plot the reference line */
62 bool bi_neg_exp; /* True iff the bi-negative exponential critieria
64 enum mv_class exclude;
66 bool invert ; /* True iff a smaller test result variable indicates
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
/* Parse the ROC command's syntax from LEXER into a struct cmd_roc and run
   the analysis on DS through run_roc.
   NOTE(review): this listing omits some lines of the function (braces,
   error exits, a few assignments), so the comments below describe only the
   statements that are visible. */
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 const struct dictionary *dict = dataset_dict (ds);
/* Defaults; the subcommands parsed below may override them. */
86 roc.print_coords = false;
89 roc.reference = false;
91 roc.bi_neg_exp = false;
93 roc.pos = roc.pos_weighted = 0;
94 roc.neg = roc.neg_weighted = 0;
95 roc.dict = dataset_dict (ds);
97 roc.state_var_width = -1;
/* Mandatory part: the test variables (numeric, no duplicates), then BY, the
   state variable and, in parentheses, the state value that match_positives
   later compares each case against. */
99 lex_match (lexer, T_SLASH);
100 if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
101 PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
104 if (! lex_force_match (lexer, T_BY))
109 roc.state_var = parse_variable (lexer, dict);
115 if (!lex_force_match (lexer, T_LPAREN))
120 roc.state_var_width = var_get_width (roc.state_var);
121 value_init (&roc.state_value, roc.state_var_width);
122 parse_value (lexer, &roc.state_value, roc.state_var);
125 if (!lex_force_match (lexer, T_RPAREN))
/* Optional subcommands, each optionally preceded by a slash, until the end
   of the command. */
130 while (lex_token (lexer) != T_ENDCMD)
132 lex_match (lexer, T_SLASH);
/* MISSING: INCLUDE sets roc.exclude to MV_SYSTEM, EXCLUDE sets it to
   MV_ANY. */
133 if (lex_match_id (lexer, "MISSING"))
135 lex_match (lexer, T_EQUALS);
136 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
138 if (lex_match_id (lexer, "INCLUDE"))
140 roc.exclude = MV_SYSTEM;
142 else if (lex_match_id (lexer, "EXCLUDE"))
144 roc.exclude = MV_ANY;
148 lex_error (lexer, NULL);
/* PLOT: CURVE, optionally followed by (REFERENCE), or NONE.  Note that
   roc.reference is set as soon as the left parenthesis is seen, before
   REFERENCE itself has been matched. */
153 else if (lex_match_id (lexer, "PLOT"))
155 lex_match (lexer, T_EQUALS);
156 if (lex_match_id (lexer, "CURVE"))
159 if (lex_match (lexer, T_LPAREN))
161 roc.reference = true;
162 if (! lex_force_match_id (lexer, "REFERENCE"))
164 if (! lex_force_match (lexer, T_RPAREN))
168 else if (lex_match_id (lexer, "NONE"))
174 lex_error (lexer, NULL);
/* PRINT: accepts SE and COORDINATES; COORDINATES turns on
   roc.print_coords.  The effect of SE is not visible in this listing. */
178 else if (lex_match_id (lexer, "PRINT"))
180 lex_match (lexer, T_EQUALS);
181 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
183 if (lex_match_id (lexer, "SE"))
187 else if (lex_match_id (lexer, "COORDINATES"))
189 roc.print_coords = true;
193 lex_error (lexer, NULL);
/* CRITERIA: a list of KEYWORD (argument) settings. */
198 else if (lex_match_id (lexer, "CRITERIA"))
200 lex_match (lexer, T_EQUALS);
201 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
/* CUTOFF (INCLUDE | EXCLUDE) also sets roc.exclude.
   NOTE(review): EXCLUDE here uses MV_USER | MV_SYSTEM whereas MISSING
   EXCLUDE above uses MV_ANY; presumably the same set -- confirm. */
203 if (lex_match_id (lexer, "CUTOFF"))
205 if (! lex_force_match (lexer, T_LPAREN))
207 if (lex_match_id (lexer, "INCLUDE"))
209 roc.exclude = MV_SYSTEM;
211 else if (lex_match_id (lexer, "EXCLUDE"))
213 roc.exclude = MV_USER | MV_SYSTEM;
217 lex_error (lexer, NULL);
220 if (! lex_force_match (lexer, T_RPAREN))
/* TESTPOS (LARGE | SMALL); the assignments made for each keyword are not
   visible in this listing. */
223 else if (lex_match_id (lexer, "TESTPOS"))
225 if (! lex_force_match (lexer, T_LPAREN))
227 if (lex_match_id (lexer, "LARGE"))
231 else if (lex_match_id (lexer, "SMALL"))
237 lex_error (lexer, NULL);
240 if (! lex_force_match (lexer, T_RPAREN))
/* CI (number): stored in roc.ci; show_auc divides it by 100, so it is a
   percentage. */
243 else if (lex_match_id (lexer, "CI"))
245 if (!lex_force_match (lexer, T_LPAREN))
247 if (! lex_force_num (lexer))
249 roc.ci = lex_number (lexer);
251 if (!lex_force_match (lexer, T_RPAREN))
/* DISTRIBUTION (FREE | NEGEXPO): clears or sets roc.bi_neg_exp. */
254 else if (lex_match_id (lexer, "DISTRIBUTION"))
256 if (!lex_force_match (lexer, T_LPAREN))
258 if (lex_match_id (lexer, "FREE"))
260 roc.bi_neg_exp = false;
262 else if (lex_match_id (lexer, "NEGEXPO"))
264 roc.bi_neg_exp = true;
268 lex_error (lexer, NULL);
271 if (!lex_force_match (lexer, T_RPAREN))
276 lex_error (lexer, NULL);
283 lex_error (lexer, NULL);
/* A zero result from run_roc is tested for here; what follows the test is
   not visible (presumably the error exit). */
288 if (! run_roc (ds, &roc))
/* The state value is destroyed at two places, presumably the normal and the
   error exit. */
292 value_destroy (&roc.state_value, roc.state_var_width);
298 value_destroy (&roc.state_value, roc.state_var_width);
307 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
/* Open a procedure on DS, split its cases into groups with
   casegrouper_create_splits, and call do_roc once per group.
   OK ends up true only if both casegrouper_destroy and proc_commit succeed;
   presumably it is the value returned (the return statement is not visible
   in this listing). */
311 run_roc (struct dataset *ds, struct cmd_roc *roc)
313 struct dictionary *dict = dataset_dict (ds);
315 struct casereader *group;
317 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
318 while (casegrouper_get_next_group (grouper, &group))
320 do_roc (roc, group, dataset_dict (ds));
322 ok = casegrouper_destroy (grouper);
/* proc_commit is evaluated first so that it runs even when OK is already
   false. */
323 ok = proc_commit (ds) && ok;
/* Debugging aid: print with printf every value of every case obtained from
   a clone of READER, treating each value as a double ("%g").  The clone is
   destroyed afterwards; READER itself is only passed to casereader_clone. */
330 dump_casereader (struct casereader *reader)
333 struct casereader *r = casereader_clone (reader);
335 for (; (c = casereader_read (r)); case_unref (c))
338 for (i = 0 ; i < case_get_value_cnt (c); ++i)
340 printf ("%g ", case_data_idx (c, i)->f);
345 casereader_destroy (r);
351 Return true iff the state variable indicates that C has positive actual state.
353 As a side effect, this function also accumulates the roc->{pos,neg} and
354 roc->{pos,neg}_weighted counts.
/* AUX must point to the struct cmd_roc of the running command; it is
   updated in place. */
357 match_positives (const struct ccase *c, void *aux)
359 struct cmd_roc *roc = aux;
/* Case weight: the weight variable's value in C, or 1.0 when the dictionary
   has no weight variable. */
360 const struct variable *wv = dict_get_weight (roc->dict);
361 const double weight = wv ? case_data (c, wv)->f : 1.0;
/* C is positive when its state-variable value compares equal to the state
   value given on the command. */
363 const bool positive =
364 (0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
365 var_get_width (roc->state_var)));
/* Weighted totals; the branch that selects between them and the unweighted
   counters are on lines not visible in this listing. */
370 roc->pos_weighted += weight;
375 roc->neg_weighted += weight;
386 /* Some intermediate state for calculating the cutpoints and the
387 standard error values */
390 double auc; /* Area under the curve */
392 double n1; /* total weight of positives */
393 double n2; /* total weight of negatives */
395 /* intermediates for standard error */
399 /* intermediates for cutpoints */
400 struct casewriter *cutpoint_wtr;
401 struct casereader *cutpoint_rdr;
408 Return a new casereader based upon CUTPOINT_RDR.
409 The number of "positive" cases is placed into
410 the position TRUE_INDEX, and the number of "negative" cases
412 POS_COND and RESULT determine the semantics of what is
414 WEIGHT is the value of a single count.
416 static struct casereader *
417 accumulate_counts (struct casereader *input,
418 double result, double weight,
419 bool (*pos_cond) (double, double),
420 int true_index, int false_index)
/* The output uses the same case prototype as INPUT. */
422 const struct caseproto *proto = casereader_get_proto (input);
423 struct casewriter *w =
424 autopaging_writer_create (proto);
426 double prev_cp = SYSMIS;
/* Copy every cutpoint case of INPUT, adding WEIGHT to one of its two count
   columns. */
428 for (; (cpc = casereader_read (input)); case_unref (cpc))
430 struct ccase *new_case;
431 const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f;
433 assert (cp != SYSMIS);
/* The test that uses PREV_CP to skip duplicates is not visible in this
   listing. */
435 /* We don't want duplicates here */
439 new_case = case_clone (cpc);
/* POS_COND (RESULT, CP) selects TRUE_INDEX.  NOTE(review): the `else' that
   presumably separates the two updates below is on a line not visible
   here. */
441 if (pos_cond (result, cp))
442 case_data_rw_idx (new_case, true_index)->f += weight;
444 case_data_rw_idx (new_case, false_index)->f += weight;
448 casewriter_write (w, new_case);
/* INPUT is consumed: it is destroyed here and replaced by the returned
   reader. */
450 casereader_destroy (input);
452 return casewriter_make_reader (w);
457 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
460 This function does 3 things:
462 1. Counts the number of cases which are equal to every other case in READER,
463 and those cases for which the relationship between it and every other case
464 satisfies PRED (normally either > or <). VAR is the variable defining a case's value
467 2. Counts the number of true and false cases in reader, and populates
468 CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices
469 which receive these values. POS_COND is the condition defining true
472 3. CC is filled with the cumulative weight of all cases of READER.
474 static struct casereader *
475 process_group (const struct variable *var, struct casereader *reader,
476 bool (*pred) (double, double),
477 const struct dictionary *dict,
479 struct casereader **cutpoint_rdr,
480 bool (*pos_cond) (double, double),
/* Weight variable of DICT; the ternary below handles the case where there
   is none. */
484 const struct variable *w = dict_get_weight (dict);
/* READER sorted on VAR and passed through casereader_create_distinct;
   presumably this leaves one case per distinct value of VAR -- confirm. */
486 struct casereader *r1 =
487 casereader_create_distinct (sort_execute_1var (reader, var), var, w);
/* Column holding each case's weight: the weight variable's own case index,
   or otherwise the last column of R1's prototype. */
489 const int weight_idx = w ? var_get_case_index (w) :
490 caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
/* R1 drives the outer loop; RCLONE is cloned afresh for every inner loop. */
494 struct casereader *rclone = casereader_clone (r1);
495 struct casewriter *wtr;
496 struct caseproto *proto = caseproto_create ();
/* Output cases have three numeric (width 0) columns, written below through
   the indices VALUE, N_EQ and N_PRED. */
498 proto = caseproto_add_width (proto, 0);
499 proto = caseproto_add_width (proto, 0);
500 proto = caseproto_add_width (proto, 0);
502 wtr = autopaging_writer_create (proto);
/* Outer loop: one output case per case of R1. */
506 for (; (c1 = casereader_read (r1)); case_unref (c1))
508 struct ccase *new_case = case_create (proto);
510 struct casereader *r2 = casereader_clone (rclone);
512 const double weight1 = case_data_idx (c1, weight_idx)->f;
513 const double d1 = case_data (c1, var)->f;
/* Fold this case's value and weight into the caller's cutpoint counts;
   accumulate_counts replaces *CUTPOINT_RDR with a new reader. */
517 *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
519 true_index, false_index);
/* Inner loop: compare every case (d2) with the current one (d1).  The
   statements that update n_eq and n_pred are not visible in this listing;
   presumably PRED (d2, d1) selects the cases counted in n_pred. */
523 for (; (c2 = casereader_read (r2)); case_unref (c2))
525 const double d2 = case_data (c2, var)->f;
526 const double weight2 = case_data_idx (c2, weight_idx)->f;
533 else if (pred (d2, d1))
/* Emit (value, n_eq, n_pred) for the current case. */
539 case_data_rw_idx (new_case, VALUE)->f = d1;
540 case_data_rw_idx (new_case, N_EQ)->f = n_eq;
541 case_data_rw_idx (new_case, N_PRED)->f = n_pred;
543 casewriter_write (wtr, new_case);
545 casereader_destroy (r2);
549 casereader_destroy (r1);
550 casereader_destroy (rclone);
552 caseproto_unref (proto);
554 return casewriter_make_reader (wtr);
557 /* Some more indices into case data */
558 #define N_POS_EQ 1 /* number of positive cases with values equal to n */
559 #define N_POS_GT 2 /* number of positive cases with values greater than n */
560 #define N_NEG_EQ 3 /* number of negative cases with values equal to n */
561 #define N_NEG_LT 4 /* number of negative cases with values less than n */
564 gt (double d1, double d2)
571 ge (double d1, double d2)
577 lt (double d1, double d2)
584 Return a casereader with width 3,
585 populated with cases based upon READER.
586 The cases will have the values:
587 (N, number of cases equal to N, number of cases greater than N)
588 As a side effect, update RS->n1 with the number of positive cases.
590 static struct casereader *
591 process_positive_group (const struct variable *var, struct casereader *reader,
592 const struct dictionary *dict,
593 struct roc_state *rs)
/* Delegates to process_group with `gt' as the predicate and RS->n1 as the
   destination passed by address; the remaining arguments are on lines not
   visible in this listing. */
595 return process_group (var, reader, gt, dict, &rs->n1,
602 Return a casereader with width 3,
603 populated with cases based upon READER.
604 The cases will have the values:
605 (N, number of cases equal to N, number of cases less than N)
606 As a side effect, update RS->n2 with the number of negative cases.
608 static struct casereader *
609 process_negative_group (const struct variable *var, struct casereader *reader,
610 const struct dictionary *dict,
611 struct roc_state *rs)
/* Delegates to process_group with `lt' as the predicate and RS->n2 as the
   destination passed by address; the remaining arguments are on lines not
   visible in this listing. */
613 return process_group (var, reader, lt, dict, &rs->n2,
/* Write to WRITER a new case, built from WRITER's own prototype, whose
   ROC_CUTPOINT field is CUTPOINT and whose four count fields (ROC_TP,
   ROC_FN, ROC_TN, ROC_FP) are zero. */
623 append_cutpoint (struct casewriter *writer, double cutpoint)
625 struct ccase *cc = case_create (casewriter_get_proto (writer));
627 case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint;
628 case_data_rw_idx (cc, ROC_TP)->f = 0;
629 case_data_rw_idx (cc, ROC_FN)->f = 0;
630 case_data_rw_idx (cc, ROC_TN)->f = 0;
631 case_data_rw_idx (cc, ROC_FP)->f = 0;
633 casewriter_write (writer, cc);
638 Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will
639 be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
640 reader will be populated with its final number of cases.
641 However, on exit from this function, only the ROC_CUTPOINT entries will be set to their final
642 value. The other entries will be initialised to zero.
645 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
/* All reading is done from a clone of INPUT, which is destroyed once the
   scan below is finished. */
648 struct casereader *r = casereader_clone (input);
652 struct caseproto *proto = caseproto_create ();
/* Cutpoint cases are kept sorted on ROC_CUTPOINT, ascending. */
653 struct subcase ordering;
654 subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
656 proto = caseproto_add_width (proto, 0); /* cutpoint */
657 proto = caseproto_add_width (proto, 0); /* ROC_TP */
658 proto = caseproto_add_width (proto, 0); /* ROC_FN */
659 proto = caseproto_add_width (proto, 0); /* ROC_TN */
660 proto = caseproto_add_width (proto, 0); /* ROC_FP */
/* Per test variable: a sorting writer plus the running state used by the
   scan.  The initial value of rs[i].min is not visible in this listing. */
662 for (i = 0 ; i < roc->n_vars; ++i)
664 rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
665 rs[i].prev_result = SYSMIS;
666 rs[i].max = -DBL_MAX;
670 caseproto_unref (proto);
671 subcase_destroy (&ordering);
/* Scan every case once, handling each test variable in turn. */
674 for (; (c = casereader_read (r)) != NULL; case_unref (c))
676 for (i = 0 ; i < roc->n_vars; ++i)
678 const union value *v = case_data (c, roc->vars[i]);
679 const double result = v->f;
/* Values missing in the classes named by roc->exclude are tested for here;
   the statement governed by this test is not visible (presumably it skips
   the value). */
681 if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
684 minimize (&rs[i].min, result);
685 maximize (&rs[i].max, result);
/* Whenever the value differs from the previous one seen for this variable,
   add a cutpoint halfway between the two. */
687 if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
689 const double mean = (result + rs[i].prev_result) / 2.0;
690 append_cutpoint (rs[i].cutpoint_wtr, mean);
693 rs[i].prev_result = result;
696 casereader_destroy (r);
699 /* Append the min and max cutpoints */
/* These are one below the smallest and one above the largest value seen;
   the writer is then turned into the variable's cutpoint reader. */
700 for (i = 0 ; i < roc->n_vars; ++i)
702 append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
703 append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
705 rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
/* Run the ROC computations for one group of cases (READER) and hand the
   per-variable results to output_roc.
   NOTE(review): several lines of this function (braces, declarations, some
   arguments and conditions) are not visible in this listing; comments
   marked "presumably" depend on them. */
710 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
/* One roc_state per test variable. */
714 struct roc_state *rs = XCALLOC (roc->n_vars, struct roc_state);
716 struct casereader *negatives = NULL;
717 struct casereader *positives = NULL;
719 struct caseproto *n_proto = NULL;
721 struct subcase up_ordering;
722 struct subcase down_ordering;
724 struct casewriter *neg_wtr = NULL;
/* INPUT is READER passed through casereader_create_filter_missing, first
   for the test variables; the rest of the arguments, and those of the
   second call, are not visible here. */
726 struct casereader *input = casereader_create_filter_missing (reader,
727 roc->vars, roc->n_vars,
732 input = casereader_create_filter_missing (input,
/* NEG_WTR shares INPUT's prototype; it becomes the NEGATIVES reader inside
   the loop below. */
738 neg_wtr = autopaging_writer_create (casereader_get_proto (input));
740 prepare_cutpoints (roc, rs, input);
743 /* Separate the positive actual state cases from the negative ones */
745 casereader_create_filter_func (input,
/* N_PROTO: five numeric columns, addressed below as VALUE, N_POS_EQ,
   N_POS_GT, N_NEG_EQ and N_NEG_LT. */
751 n_proto = caseproto_create ();
753 n_proto = caseproto_add_width (n_proto, 0);
754 n_proto = caseproto_add_width (n_proto, 0);
755 n_proto = caseproto_add_width (n_proto, 0);
756 n_proto = caseproto_add_width (n_proto, 0);
757 n_proto = caseproto_add_width (n_proto, 0);
/* Both orderings sort on VALUE; they differ only in direction. */
759 subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
760 subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
/* Everything below is repeated for each test variable. */
762 for (i = 0 ; i < roc->n_vars; ++i)
764 struct casewriter *w = NULL;
765 struct casereader *r = NULL;
770 struct casereader *n_neg_reader ;
771 const struct variable *var = roc->vars[i];
773 struct casereader *neg ;
774 struct casereader *pos = casereader_clone (positives);
776 struct casereader *n_pos_reader =
777 process_positive_group (var, pos, dict, &rs[i]);
/* NEGATIVES is created from NEG_WTR only once, on the first pass through
   the loop; later passes just clone it. */
779 if (negatives == NULL)
781 negatives = casewriter_make_reader (neg_wtr);
784 neg = casereader_clone (negatives);
786 n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
788 /* Merge the n_pos and n_neg casereaders */
789 w = sort_create_writer (&up_ordering, n_proto);
790 for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
792 struct ccase *pos_case = case_create (n_proto);
794 const double jpos = case_data_idx (cpos, VALUE)->f;
/* Rows taken from the negative reader have no positive counts yet:
   N_POS_EQ is 0 and N_POS_GT is SYSMIS, a placeholder that a later pass
   replaces.  How this inner loop ends is not visible in this listing. */
796 while ((cneg = casereader_read (n_neg_reader)))
798 struct ccase *nc = case_create (n_proto);
800 const double jneg = case_data_idx (cneg, VALUE)->f;
802 case_data_rw_idx (nc, VALUE)->f = jneg;
803 case_data_rw_idx (nc, N_POS_EQ)->f = 0;
805 case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS;
807 *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
808 *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
810 casewriter_write (w, nc);
/* Likewise, rows taken from the positive reader get N_NEG_EQ = 0 and a
   SYSMIS placeholder in N_NEG_LT. */
817 case_data_rw_idx (pos_case, VALUE)->f = jpos;
818 *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
819 *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
820 case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0;
821 case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS;
823 casewriter_write (w, pos_case);
826 casereader_destroy (n_pos_reader);
827 casereader_destroy (n_neg_reader);
829 /* These aren't used anymore */
833 r = casewriter_make_reader (w);
835 /* Propagate the N_POS_GT values from the positive cases
836 to the negative ones */
/* R comes from a writer sorted ascending on VALUE.  Each SYSMIS placeholder
   takes the N_POS_GT of the preceding row, starting from n1; the rows are
   written to a writer that sorts descending for the next pass. */
838 double prev_pos_gt = rs[i].n1;
839 w = sort_create_writer (&down_ordering, n_proto);
841 for (; (c = casereader_read (r)); case_unref (c))
843 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
844 struct ccase *nc = case_clone (c);
846 if (n_pos_gt == SYSMIS)
848 n_pos_gt = prev_pos_gt;
849 case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt;
852 casewriter_write (w, nc);
853 prev_pos_gt = n_pos_gt;
856 casereader_destroy (r);
857 r = casewriter_make_reader (w);
860 /* Propagate the N_NEG_LT values from the negative cases
861 to the positive ones */
/* Mirror image of the previous pass: R is now in descending VALUE order,
   placeholders start from n2, and the output is sorted ascending again. */
863 double prev_neg_lt = rs[i].n2;
864 w = sort_create_writer (&up_ordering, n_proto);
866 for (; (c = casereader_read (r)); case_unref (c))
868 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
869 struct ccase *nc = case_clone (c);
871 if (n_neg_lt == SYSMIS)
873 n_neg_lt = prev_neg_lt;
874 case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt;
877 casewriter_write (w, nc);
878 prev_neg_lt = n_neg_lt;
881 casereader_destroy (r);
882 r = casewriter_make_reader (w);
/* Final pass, in ascending VALUE order.  Rows that share a VALUE with the
   previous row borrow that row's counts for whichever side (positive or
   negative) is zero in the current row, and the sums are updated only on
   the last row of each run of equal VALUEs. */
886 struct ccase *prev_case = NULL;
887 for (; (c = casereader_read (r)); case_unref (c))
889 struct ccase *next_case = casereader_peek (r, 0);
891 const double j = case_data_idx (c, VALUE)->f;
892 double n_pos_eq = case_data_idx (c, N_POS_EQ)->f;
893 double n_pos_gt = case_data_idx (c, N_POS_GT)->f;
894 double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f;
895 double n_neg_lt = case_data_idx (c, N_NEG_LT)->f;
897 if (prev_case && j == case_data_idx (prev_case, VALUE)->f)
899 if (0 == case_data_idx (c, N_POS_EQ)->f)
901 n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f;
902 n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f;
905 if (0 == case_data_idx (c, N_NEG_EQ)->f)
907 n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f;
908 n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f;
912 if (NULL == next_case || j != case_data_idx (next_case, VALUE)->f)
914 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
/* The left-hand sides of the next two expressions are not visible;
   presumably they are accumulated into q1hat and q2hat. */
917 n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
919 n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
923 case_unref (next_case);
924 case_unref (prev_case);
925 prev_case = case_clone (c);
927 casereader_destroy (r);
928 case_unref (prev_case);
/* Normalise the accumulated sum by the product of the two group weights. */
930 rs[i].auc /= rs[i].n1 * rs[i].n2;
/* NOTE(review): the condition guarding this complement is not visible;
   presumably it depends on roc->invert. */
932 rs[i].auc = 1 - rs[i].auc;
/* NOTE(review): the next two pairs of assignments are presumably
   alternative branches (closed-form values computed from auc versus
   normalised sums), selected by a condition that is not visible here. */
936 rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
937 rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
941 rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
942 rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
/* Release the shared readers and sort descriptions, then report. */
947 casereader_destroy (positives);
948 casereader_destroy (negatives);
950 caseproto_unref (n_proto);
951 subcase_destroy (&up_ordering);
952 subcase_destroy (&down_ordering);
954 output_roc (rs, roc);
/* The cutpoint readers are destroyed only after output_roc has run. */
956 for (i = 0 ; i < roc->n_vars; ++i)
957 casereader_destroy (rs[i].cutpoint_rdr);
/* Build and submit the "Area Under the Curve" pivot table: one row per test
   variable, with the area in column 0 and four further values (standard
   error, significance, lower and upper bound) in columns 1 to 4. */
963 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
965 struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
967 struct pivot_dimension *statistics = pivot_dimension_create (
968 table, PIVOT_AXIS_COLUMN, N_("Statistics"),
969 N_("Area"), PIVOT_RC_OTHER);
972 pivot_category_create_leaves (
974 N_("Std. Error"), PIVOT_RC_OTHER,
975 N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
/* The two bounds are grouped under a heading that includes the confidence
   level; the heading's arguments are partly on lines not visible here. */
976 struct pivot_category *interval = pivot_category_create_group__ (
978 pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
980 pivot_category_create_leaves (interval,
981 N_("Lower Bound"), PIVOT_RC_OTHER,
982 N_("Upper Bound"), PIVOT_RC_OTHER);
985 struct pivot_dimension *variables = pivot_dimension_create (
986 table, PIVOT_AXIS_ROW, N_("Variable under test"));
987 variables->root->show_label = true;
989 for (size_t i = 0 ; i < roc->n_vars ; ++i)
991 int var_idx = pivot_category_create_leaf (
992 variables->root, pivot_value_new_variable (roc->vars[i]));
994 pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
/* Quantity built from auc, q1hat, q2hat, n1 and n2 and divided by n1 * n2.
   NOTE(review): it is passed to gsl_cdf_gaussian_Qinv as sigma below, so a
   square root is presumably taken on a line not visible here. */
998 double se = (rs[i].auc * (1 - rs[i].auc)
999 + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
1000 + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
1001 se /= rs[i].n1 * rs[i].n2;
/* YY is the half-width used for both bounds.  NOTE(review): the whole tail
   probability 1 - CI/100 is given to the upper-tail inverse; for a
   two-sided interval one would expect half of it -- confirm the intent. */
1004 double ci = 1 - roc->ci / 100.0;
1005 double yy = gsl_cdf_gaussian_Qinv (ci, se);
/* Presumably the standard deviation of the area under the hypothesis that
   the true area is 0.5 (judging by the name) -- confirm. */
1007 double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1008 (12 * rs[i].n1 * rs[i].n2));
/* Twice the standard normal upper tail of the distance of the area from
   0.5; the rest of the expression is on a line not visible here. */
1009 double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
1011 double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
/* NOTE(review): this inner `i' shadows the outer loop index.  It is
   harmless here because ENTRIES and VAR_IDX were computed beforehand. */
1012 for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1013 pivot_table_put2 (table, i + 1, var_idx,
1014 pivot_value_new_number (entries[i]));
1018 pivot_table_submit (table);
/* Build and submit the "Case Summary" pivot table: columns Unweighted and
   Weighted, rows Positive and Negative, with the row dimension labelled by
   the state variable. */
1023 show_summary (const struct cmd_roc *roc)
1025 struct pivot_table *table = pivot_table_create (N_("Case Summary"));
1027 struct pivot_dimension *statistics = pivot_dimension_create (
1028 table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
1029 N_("Unweighted"), PIVOT_RC_INTEGER,
1030 N_("Weighted"), PIVOT_RC_OTHER);
1031 statistics->root->show_label = true;
1033 struct pivot_dimension *cases = pivot_dimension_create__ (
1034 table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
1035 cases->root->show_label = true;
1036 pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
/* Only the last two initialisers of ENTRIES are visible in this listing;
   each is presumably (stat_idx, case_idx, x), matching the fields read in
   the loop below -- confirm against the struct definition. */
1047 { 1, 0, roc->pos_weighted },
1048 { 1, 1, roc->neg_weighted },
1050 for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1052 const struct entry *e = &entries[i];
1053 pivot_table_put2 (table, e->stat_idx, e->case_idx,
1054 pivot_value_new_number (e->x));
1056 pivot_table_submit (table);
/* Build and submit the "Coordinates of the Curve" pivot table: for every
   test variable, one row per cutpoint case giving the cutpoint, the
   sensitivity and 1 - specificity. */
1060 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1062 struct pivot_table *table = pivot_table_create (
1063 N_("Coordinates of the Curve"));
1065 pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
1066 N_("Positive if greater than or equal to"),
1067 N_("Sensitivity"), N_("1 - Specificity"));
1069 struct pivot_dimension *coordinates = pivot_dimension_create (
1070 table, PIVOT_AXIS_ROW, N_("Coordinates"));
1071 coordinates->hide_all_labels = true;
1073 struct pivot_dimension *variables = pivot_dimension_create (
1074 table, PIVOT_AXIS_ROW, N_("Test variable"));
1075 variables->root->show_label = true;
1079 for (size_t i = 0; i < roc->n_vars; ++i)
/* Read from a clone of the variable's cutpoint reader; the clone is
   destroyed at the end of this iteration. */
1081 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1083 int var_idx = pivot_category_create_leaf (
1084 variables->root, pivot_value_new_variable (roc->vars[i]));
1088 for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
/* Sensitivity = TP / (TP + FN). */
1090 const double se = case_data_idx (cc, ROC_TP)->f /
1091 (case_data_idx (cc, ROC_TP)->f + case_data_idx (cc, ROC_FN)->f);
/* Specificity = TN / (TN + FP); the table reports 1 - specificity. */
1093 const double sp = case_data_idx (cc, ROC_TN)->f /
1094 (case_data_idx (cc, ROC_TN)->f + case_data_idx (cc, ROC_FP)->f);
/* Add a leaf to the Coordinates dimension when COORD_IDX reaches the number
   of leaves created so far.  The declarations and updates of COORD_IDX and
   N_COORDS other than the increment below are not visible here. */
1096 if (coord_idx >= n_coords)
1098 assert (coord_idx == n_coords);
1099 pivot_category_create_leaf (
1100 coordinates->root, pivot_value_new_integer (++n_coords));
1104 table, 0, coord_idx, var_idx,
1105 pivot_value_new_var_value (roc->vars[i],
1106 case_data_idx (cc, ROC_CUTPOINT)));
1108 pivot_table_put3 (table, 1, coord_idx, var_idx,
1109 pivot_value_new_number (se));
1110 pivot_table_put3 (table, 2, coord_idx, var_idx,
1111 pivot_value_new_number (1 - sp));
1115 casereader_destroy (r);
1118 pivot_table_submit (table);
/* Produce the output for one run of do_roc.  The visible statements build
   and submit the ROC chart and, when requested, the coordinates table;
   other output calls and the condition guarding the chart are presumably
   on lines not visible in this listing. */
1123 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1129 struct roc_chart *rc;
/* The chart gets one entry per test variable, each taken from that
   variable's cutpoint reader; roc->reference is passed to
   roc_chart_create. */
1132 rc = roc_chart_create (roc->reference);
1133 for (i = 0; i < roc->n_vars; i++)
1134 roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1135 rs[i].cutpoint_rdr);
1136 roc_chart_submit (rc);
1141 if (roc->print_coords)
1142 show_coords (rs, roc);