1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include "language/stats/roc.h"
21 #include <gsl/gsl_cdf.h>
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/casewriter.h"
26 #include "data/dataset.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/subcase.h"
30 #include "language/command.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/value-parser.h"
33 #include "language/lexer/variable-parser.h"
34 #include "libpspp/misc.h"
35 #include "math/sort.h"
36 #include "output/charts/roc-chart.h"
37 #include "output/pivot-table.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) msgid
46 const struct variable **vars;
47 const struct dictionary *dict;
49 const struct variable *state_var;
50 union value state_value;
51 size_t state_var_width;
53 /* Plot the roc curve */
55 /* Plot the reference line */
62 bool bi_neg_exp; /* True iff the bi-negative exponential critieria
64 enum mv_class exclude;
66 bool invert ; /* True iff a smaller test result variable indicates
75 static int run_roc (struct dataset *ds, struct cmd_roc *roc);
78 cmd_roc (struct lexer *lexer, struct dataset *ds)
81 const struct dictionary *dict = dataset_dict (ds);
86 roc.print_coords = false;
89 roc.reference = false;
91 roc.bi_neg_exp = false;
93 roc.pos = roc.pos_weighted = 0;
94 roc.neg = roc.neg_weighted = 0;
95 roc.dict = dataset_dict (ds);
97 roc.state_var_width = -1;
99 lex_match (lexer, T_SLASH);
100 if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
101 PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
104 if (! lex_force_match (lexer, T_BY))
109 roc.state_var = parse_variable (lexer, dict);
115 if (!lex_force_match (lexer, T_LPAREN))
120 roc.state_var_width = var_get_width (roc.state_var);
121 value_init (&roc.state_value, roc.state_var_width);
122 parse_value (lexer, &roc.state_value, roc.state_var);
125 if (!lex_force_match (lexer, T_RPAREN))
130 while (lex_token (lexer) != T_ENDCMD)
132 lex_match (lexer, T_SLASH);
133 if (lex_match_id (lexer, "MISSING"))
135 lex_match (lexer, T_EQUALS);
136 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
138 if (lex_match_id (lexer, "INCLUDE"))
140 roc.exclude = MV_SYSTEM;
142 else if (lex_match_id (lexer, "EXCLUDE"))
144 roc.exclude = MV_ANY;
148 lex_error (lexer, NULL);
153 else if (lex_match_id (lexer, "PLOT"))
155 lex_match (lexer, T_EQUALS);
156 if (lex_match_id (lexer, "CURVE"))
159 if (lex_match (lexer, T_LPAREN))
161 roc.reference = true;
162 if (! lex_force_match_id (lexer, "REFERENCE"))
164 if (! lex_force_match (lexer, T_RPAREN))
168 else if (lex_match_id (lexer, "NONE"))
174 lex_error (lexer, NULL);
178 else if (lex_match_id (lexer, "PRINT"))
180 lex_match (lexer, T_EQUALS);
181 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
183 if (lex_match_id (lexer, "SE"))
187 else if (lex_match_id (lexer, "COORDINATES"))
189 roc.print_coords = true;
193 lex_error (lexer, NULL);
198 else if (lex_match_id (lexer, "CRITERIA"))
200 lex_match (lexer, T_EQUALS);
201 while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
203 if (lex_match_id (lexer, "CUTOFF"))
205 if (! lex_force_match (lexer, T_LPAREN))
207 if (lex_match_id (lexer, "INCLUDE"))
209 roc.exclude = MV_SYSTEM;
211 else if (lex_match_id (lexer, "EXCLUDE"))
213 roc.exclude = MV_USER | MV_SYSTEM;
217 lex_error (lexer, NULL);
220 if (! lex_force_match (lexer, T_RPAREN))
223 else if (lex_match_id (lexer, "TESTPOS"))
225 if (! lex_force_match (lexer, T_LPAREN))
227 if (lex_match_id (lexer, "LARGE"))
231 else if (lex_match_id (lexer, "SMALL"))
237 lex_error (lexer, NULL);
240 if (! lex_force_match (lexer, T_RPAREN))
243 else if (lex_match_id (lexer, "CI"))
245 if (!lex_force_match (lexer, T_LPAREN))
247 if (! lex_force_num (lexer))
249 roc.ci = lex_number (lexer);
251 if (!lex_force_match (lexer, T_RPAREN))
254 else if (lex_match_id (lexer, "DISTRIBUTION"))
256 if (!lex_force_match (lexer, T_LPAREN))
258 if (lex_match_id (lexer, "FREE"))
260 roc.bi_neg_exp = false;
262 else if (lex_match_id (lexer, "NEGEXPO"))
264 roc.bi_neg_exp = true;
268 lex_error (lexer, NULL);
271 if (!lex_force_match (lexer, T_RPAREN))
276 lex_error (lexer, NULL);
283 lex_error (lexer, NULL);
288 if (! run_roc (ds, &roc))
292 value_destroy (&roc.state_value, roc.state_var_width);
298 value_destroy (&roc.state_value, roc.state_var_width);
307 do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict);
311 run_roc (struct dataset *ds, struct cmd_roc *roc)
313 struct dictionary *dict = dataset_dict (ds);
315 struct casereader *group;
317 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
318 while (casegrouper_get_next_group (grouper, &group))
320 do_roc (roc, group, dataset_dict (ds));
322 ok = casegrouper_destroy (grouper);
323 ok = proc_commit (ds) && ok;
330 dump_casereader (struct casereader *reader)
333 struct casereader *r = casereader_clone (reader);
335 for (; (c = casereader_read (r)); case_unref (c))
338 for (i = 0 ; i < case_get_value_cnt (c); ++i)
339 printf ("%g ", case_num_idx (c, i));
343 casereader_destroy (r);
349 Return true iff the state variable indicates that C has positive actual state.
351 As a side effect, this function also accumulates the roc->{pos,neg} and
352 roc->{pos,neg}_weighted counts.
355 match_positives (const struct ccase *c, void *aux)
357 struct cmd_roc *roc = aux;
358 const struct variable *wv = dict_get_weight (roc->dict);
359 const double weight = wv ? case_num (c, wv) : 1.0;
361 const bool positive =
362 (0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value,
363 var_get_width (roc->state_var)));
368 roc->pos_weighted += weight;
373 roc->neg_weighted += weight;
384 /* Some intermediate state for calculating the cutpoints and the
385 standard error values */
388 double auc; /* Area under the curve */
390 double n1; /* total weight of positives */
391 double n2; /* total weight of negatives */
393 /* intermediates for standard error */
397 /* intermediates for cutpoints */
398 struct casewriter *cutpoint_wtr;
399 struct casereader *cutpoint_rdr;
406 Return a new casereader based upon CUTPOINT_RDR.
407 The number of "positive" cases are placed into
408 the position TRUE_INDEX, and the number of "negative" cases
410 POS_COND and RESULT determine the semantics of what is
412 WEIGHT is the value of a single count.
414 static struct casereader *
415 accumulate_counts (struct casereader *input,
416 double result, double weight,
417 bool (*pos_cond) (double, double),
418 int true_index, int false_index)
420 const struct caseproto *proto = casereader_get_proto (input);
421 struct casewriter *w =
422 autopaging_writer_create (proto);
424 double prev_cp = SYSMIS;
426 for (; (cpc = casereader_read (input)); case_unref (cpc))
428 struct ccase *new_case;
429 const double cp = case_num_idx (cpc, ROC_CUTPOINT);
431 assert (cp != SYSMIS);
433 /* We don't want duplicates here */
437 new_case = case_clone (cpc);
439 int index = pos_cond (result, cp) ? true_index : false_index;
440 *case_num_rw_idx (new_case, index) += weight;
444 casewriter_write (w, new_case);
446 casereader_destroy (input);
448 return casewriter_make_reader (w);
453 static void output_roc (struct roc_state *rs, const struct cmd_roc *roc);
456 This function does 3 things:
458 1. Counts the number of cases which are equal to every other case in READER,
459 and those cases for which the relationship between it and every other case
460 satisfies PRED (normally either > or <). VAR is the variable defining a case's value
463 2. Counts the number of true and false cases in reader, and populates
464 CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices
465 which receive these values. POS_COND is the condition defining true
468 3. CC is filled with the cumulative weight of all cases of READER.
470 static struct casereader *
471 process_group (const struct variable *var, struct casereader *reader,
472 bool (*pred) (double, double),
473 const struct dictionary *dict,
475 struct casereader **cutpoint_rdr,
476 bool (*pos_cond) (double, double),
480 const struct variable *w = dict_get_weight (dict);
482 struct casereader *r1 =
483 casereader_create_distinct (sort_execute_1var (reader, var), var, w);
485 const int weight_idx = w ? var_get_case_index (w) :
486 caseproto_get_n_widths (casereader_get_proto (r1)) - 1;
490 struct casereader *rclone = casereader_clone (r1);
491 struct casewriter *wtr;
492 struct caseproto *proto = caseproto_create ();
494 proto = caseproto_add_width (proto, 0);
495 proto = caseproto_add_width (proto, 0);
496 proto = caseproto_add_width (proto, 0);
498 wtr = autopaging_writer_create (proto);
502 for (; (c1 = casereader_read (r1)); case_unref (c1))
504 struct ccase *new_case = case_create (proto);
506 struct casereader *r2 = casereader_clone (rclone);
508 const double weight1 = case_num_idx (c1, weight_idx);
509 const double d1 = case_num (c1, var);
513 *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1,
515 true_index, false_index);
519 for (; (c2 = casereader_read (r2)); case_unref (c2))
521 const double d2 = case_num (c2, var);
522 const double weight2 = case_num_idx (c2, weight_idx);
529 else if (pred (d2, d1))
535 *case_num_rw_idx (new_case, VALUE) = d1;
536 *case_num_rw_idx (new_case, N_EQ) = n_eq;
537 *case_num_rw_idx (new_case, N_PRED) = n_pred;
539 casewriter_write (wtr, new_case);
541 casereader_destroy (r2);
545 casereader_destroy (r1);
546 casereader_destroy (rclone);
548 caseproto_unref (proto);
550 return casewriter_make_reader (wtr);
/* Further column indices into the five-column "n_proto" cases that do_roc()
   builds when merging the per-value counts of the positive and negative
   groups.  Column 0 is expected to hold the test value itself (VALUE), so
   these start at 1.

   Rows that originate from the negative group carry SYSMIS in N_POS_GT, and
   rows from the positive group carry SYSMIS in N_NEG_LT, until the
   propagation passes in do_roc() fill those columns in. */
enum
  {
    N_POS_EQ = 1,   /* Number of positive cases whose value equals VALUE. */
    N_POS_GT = 2,   /* Number of positive cases whose value exceeds VALUE. */
    N_NEG_EQ = 3,   /* Number of negative cases whose value equals VALUE. */
    N_NEG_LT = 4    /* Number of negative cases whose value is below VALUE. */
  };
560 gt (double d1, double d2)
567 ge (double d1, double d2)
573 lt (double d1, double d2)
580 Return a casereader with width 3,
581 populated with cases based upon READER.
582 The cases will have the values:
583 (N, number of cases equal to N, number of cases greater than N)
584 As a side effect, update RS->n1 with the number of positive cases.
586 static struct casereader *
587 process_positive_group (const struct variable *var, struct casereader *reader,
588 const struct dictionary *dict,
589 struct roc_state *rs)
591 return process_group (var, reader, gt, dict, &rs->n1,
598 Return a casereader with width 3,
599 populated with cases based upon READER.
600 The cases will have the values:
601 (N, number of cases equal to N, number of cases less than N)
602 As a side effect, update RS->n2 with the number of negative cases.
604 static struct casereader *
605 process_negative_group (const struct variable *var, struct casereader *reader,
606 const struct dictionary *dict,
607 struct roc_state *rs)
609 return process_group (var, reader, lt, dict, &rs->n2,
619 append_cutpoint (struct casewriter *writer, double cutpoint)
621 struct ccase *cc = case_create (casewriter_get_proto (writer));
623 *case_num_rw_idx (cc, ROC_CUTPOINT) = cutpoint;
624 *case_num_rw_idx (cc, ROC_TP) = 0;
625 *case_num_rw_idx (cc, ROC_FN) = 0;
626 *case_num_rw_idx (cc, ROC_TN) = 0;
627 *case_num_rw_idx (cc, ROC_FP) = 0;
629 casewriter_write (writer, cc);
634 Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will
635 be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the
636 reader will be populated with its final number of cases.
637 However on exit from this function, only ROC_CUTPOINT entries will be set to their final
638 value. The other entries will be initialised to zero.
641 prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
644 struct casereader *r = casereader_clone (input);
648 struct caseproto *proto = caseproto_create ();
649 struct subcase ordering;
650 subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND);
652 proto = caseproto_add_width (proto, 0); /* cutpoint */
653 proto = caseproto_add_width (proto, 0); /* ROC_TP */
654 proto = caseproto_add_width (proto, 0); /* ROC_FN */
655 proto = caseproto_add_width (proto, 0); /* ROC_TN */
656 proto = caseproto_add_width (proto, 0); /* ROC_FP */
658 for (i = 0 ; i < roc->n_vars; ++i)
660 rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
661 rs[i].prev_result = SYSMIS;
662 rs[i].max = -DBL_MAX;
666 caseproto_unref (proto);
667 subcase_destroy (&ordering);
670 for (; (c = casereader_read (r)) != NULL; case_unref (c))
672 for (i = 0 ; i < roc->n_vars; ++i)
674 const union value *v = case_data (c, roc->vars[i]);
675 const double result = v->f;
677 if (mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
680 minimize (&rs[i].min, result);
681 maximize (&rs[i].max, result);
683 if (rs[i].prev_result != SYSMIS && rs[i].prev_result != result)
685 const double mean = (result + rs[i].prev_result) / 2.0;
686 append_cutpoint (rs[i].cutpoint_wtr, mean);
689 rs[i].prev_result = result;
692 casereader_destroy (r);
695 /* Append the min and max cutpoints */
696 for (i = 0 ; i < roc->n_vars; ++i)
698 append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
699 append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
701 rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
706 do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
710 struct roc_state *rs = XCALLOC (roc->n_vars, struct roc_state);
712 struct casereader *negatives = NULL;
713 struct casereader *positives = NULL;
715 struct caseproto *n_proto = NULL;
717 struct subcase up_ordering;
718 struct subcase down_ordering;
720 struct casewriter *neg_wtr = NULL;
722 struct casereader *input = casereader_create_filter_missing (reader,
723 roc->vars, roc->n_vars,
728 input = casereader_create_filter_missing (input,
734 neg_wtr = autopaging_writer_create (casereader_get_proto (input));
736 prepare_cutpoints (roc, rs, input);
739 /* Separate the positive actual state cases from the negative ones */
741 casereader_create_filter_func (input,
747 n_proto = caseproto_create ();
749 n_proto = caseproto_add_width (n_proto, 0);
750 n_proto = caseproto_add_width (n_proto, 0);
751 n_proto = caseproto_add_width (n_proto, 0);
752 n_proto = caseproto_add_width (n_proto, 0);
753 n_proto = caseproto_add_width (n_proto, 0);
755 subcase_init (&up_ordering, VALUE, 0, SC_ASCEND);
756 subcase_init (&down_ordering, VALUE, 0, SC_DESCEND);
758 for (i = 0 ; i < roc->n_vars; ++i)
760 struct casewriter *w = NULL;
761 struct casereader *r = NULL;
766 struct casereader *n_neg_reader ;
767 const struct variable *var = roc->vars[i];
769 struct casereader *neg ;
770 struct casereader *pos = casereader_clone (positives);
772 struct casereader *n_pos_reader =
773 process_positive_group (var, pos, dict, &rs[i]);
775 if (negatives == NULL)
777 negatives = casewriter_make_reader (neg_wtr);
780 neg = casereader_clone (negatives);
782 n_neg_reader = process_negative_group (var, neg, dict, &rs[i]);
784 /* Merge the n_pos and n_neg casereaders */
785 w = sort_create_writer (&up_ordering, n_proto);
786 for (; (cpos = casereader_read (n_pos_reader)); case_unref (cpos))
788 struct ccase *pos_case = case_create (n_proto);
790 const double jpos = case_num_idx (cpos, VALUE);
792 while ((cneg = casereader_read (n_neg_reader)))
794 struct ccase *nc = case_create (n_proto);
796 const double jneg = case_num_idx (cneg, VALUE);
798 *case_num_rw_idx (nc, VALUE) = jneg;
799 *case_num_rw_idx (nc, N_POS_EQ) = 0;
801 *case_num_rw_idx (nc, N_POS_GT) = SYSMIS;
803 *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ);
804 *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED);
806 casewriter_write (w, nc);
813 *case_num_rw_idx (pos_case, VALUE) = jpos;
814 *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ);
815 *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED);
816 *case_num_rw_idx (pos_case, N_NEG_EQ) = 0;
817 *case_num_rw_idx (pos_case, N_NEG_LT) = SYSMIS;
819 casewriter_write (w, pos_case);
822 casereader_destroy (n_pos_reader);
823 casereader_destroy (n_neg_reader);
825 /* These aren't used anymore */
829 r = casewriter_make_reader (w);
831 /* Propagate the N_POS_GT values from the positive cases
832 to the negative ones */
834 double prev_pos_gt = rs[i].n1;
835 w = sort_create_writer (&down_ordering, n_proto);
837 for (; (c = casereader_read (r)); case_unref (c))
839 double n_pos_gt = case_num_idx (c, N_POS_GT);
840 struct ccase *nc = case_clone (c);
842 if (n_pos_gt == SYSMIS)
844 n_pos_gt = prev_pos_gt;
845 *case_num_rw_idx (nc, N_POS_GT) = n_pos_gt;
848 casewriter_write (w, nc);
849 prev_pos_gt = n_pos_gt;
852 casereader_destroy (r);
853 r = casewriter_make_reader (w);
856 /* Propagate the N_NEG_LT values from the negative cases
857 to the positive ones */
859 double prev_neg_lt = rs[i].n2;
860 w = sort_create_writer (&up_ordering, n_proto);
862 for (; (c = casereader_read (r)); case_unref (c))
864 double n_neg_lt = case_num_idx (c, N_NEG_LT);
865 struct ccase *nc = case_clone (c);
867 if (n_neg_lt == SYSMIS)
869 n_neg_lt = prev_neg_lt;
870 *case_num_rw_idx (nc, N_NEG_LT) = n_neg_lt;
873 casewriter_write (w, nc);
874 prev_neg_lt = n_neg_lt;
877 casereader_destroy (r);
878 r = casewriter_make_reader (w);
882 struct ccase *prev_case = NULL;
883 for (; (c = casereader_read (r)); case_unref (c))
885 struct ccase *next_case = casereader_peek (r, 0);
887 const double j = case_num_idx (c, VALUE);
888 double n_pos_eq = case_num_idx (c, N_POS_EQ);
889 double n_pos_gt = case_num_idx (c, N_POS_GT);
890 double n_neg_eq = case_num_idx (c, N_NEG_EQ);
891 double n_neg_lt = case_num_idx (c, N_NEG_LT);
893 if (prev_case && j == case_num_idx (prev_case, VALUE))
895 if (0 == case_num_idx (c, N_POS_EQ))
897 n_pos_eq = case_num_idx (prev_case, N_POS_EQ);
898 n_pos_gt = case_num_idx (prev_case, N_POS_GT);
901 if (0 == case_num_idx (c, N_NEG_EQ))
903 n_neg_eq = case_num_idx (prev_case, N_NEG_EQ);
904 n_neg_lt = case_num_idx (prev_case, N_NEG_LT);
908 if (NULL == next_case || j != case_num_idx (next_case, VALUE))
910 rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0;
913 n_neg_eq * (pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0);
915 n_pos_eq * (pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0);
919 case_unref (next_case);
920 case_unref (prev_case);
921 prev_case = case_clone (c);
923 casereader_destroy (r);
924 case_unref (prev_case);
926 rs[i].auc /= rs[i].n1 * rs[i].n2;
928 rs[i].auc = 1 - rs[i].auc;
932 rs[i].q1hat = rs[i].auc / (2 - rs[i].auc);
933 rs[i].q2hat = 2 * pow2 (rs[i].auc) / (1 + rs[i].auc);
937 rs[i].q1hat /= rs[i].n2 * pow2 (rs[i].n1);
938 rs[i].q2hat /= rs[i].n1 * pow2 (rs[i].n2);
943 casereader_destroy (positives);
944 casereader_destroy (negatives);
946 caseproto_unref (n_proto);
947 subcase_destroy (&up_ordering);
948 subcase_destroy (&down_ordering);
950 output_roc (rs, roc);
952 for (i = 0 ; i < roc->n_vars; ++i)
953 casereader_destroy (rs[i].cutpoint_rdr);
959 show_auc (struct roc_state *rs, const struct cmd_roc *roc)
961 struct pivot_table *table = pivot_table_create (N_("Area Under the Curve"));
963 struct pivot_dimension *statistics = pivot_dimension_create (
964 table, PIVOT_AXIS_COLUMN, N_("Statistics"),
965 N_("Area"), PIVOT_RC_OTHER);
968 pivot_category_create_leaves (
970 N_("Std. Error"), PIVOT_RC_OTHER,
971 N_("Asymptotic Sig."), PIVOT_RC_SIGNIFICANCE);
972 struct pivot_category *interval = pivot_category_create_group__ (
974 pivot_value_new_text_format (N_("Asymp. %g%% Confidence Interval"),
976 pivot_category_create_leaves (interval,
977 N_("Lower Bound"), PIVOT_RC_OTHER,
978 N_("Upper Bound"), PIVOT_RC_OTHER);
981 struct pivot_dimension *variables = pivot_dimension_create (
982 table, PIVOT_AXIS_ROW, N_("Variable under test"));
983 variables->root->show_label = true;
985 for (size_t i = 0 ; i < roc->n_vars ; ++i)
987 int var_idx = pivot_category_create_leaf (
988 variables->root, pivot_value_new_variable (roc->vars[i]));
990 pivot_table_put2 (table, 0, var_idx, pivot_value_new_number (rs[i].auc));
994 double se = (rs[i].auc * (1 - rs[i].auc)
995 + (rs[i].n1 - 1) * (rs[i].q1hat - pow2 (rs[i].auc))
996 + (rs[i].n2 - 1) * (rs[i].q2hat - pow2 (rs[i].auc)));
997 se /= rs[i].n1 * rs[i].n2;
1000 double ci = 1 - roc->ci / 100.0;
1001 double yy = gsl_cdf_gaussian_Qinv (ci, se);
1003 double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) /
1004 (12 * rs[i].n1 * rs[i].n2));
1005 double sig = 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5)
1007 double entries[] = { se, sig, rs[i].auc - yy, rs[i].auc + yy };
1008 for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1009 pivot_table_put2 (table, i + 1, var_idx,
1010 pivot_value_new_number (entries[i]));
1014 pivot_table_submit (table);
1019 show_summary (const struct cmd_roc *roc)
1021 struct pivot_table *table = pivot_table_create (N_("Case Summary"));
1023 struct pivot_dimension *statistics = pivot_dimension_create (
1024 table, PIVOT_AXIS_COLUMN, N_("Valid N (listwise)"),
1025 N_("Unweighted"), PIVOT_RC_INTEGER,
1026 N_("Weighted"), PIVOT_RC_OTHER);
1027 statistics->root->show_label = true;
1029 struct pivot_dimension *cases = pivot_dimension_create__ (
1030 table, PIVOT_AXIS_ROW, pivot_value_new_variable (roc->state_var));
1031 cases->root->show_label = true;
1032 pivot_category_create_leaves (cases->root, N_("Positive"), N_("Negative"));
1043 { 1, 0, roc->pos_weighted },
1044 { 1, 1, roc->neg_weighted },
1046 for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
1048 const struct entry *e = &entries[i];
1049 pivot_table_put2 (table, e->stat_idx, e->case_idx,
1050 pivot_value_new_number (e->x));
1052 pivot_table_submit (table);
1056 show_coords (struct roc_state *rs, const struct cmd_roc *roc)
1058 struct pivot_table *table = pivot_table_create (
1059 N_("Coordinates of the Curve"));
1061 pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
1062 N_("Positive if greater than or equal to"),
1063 N_("Sensitivity"), N_("1 - Specificity"));
1065 struct pivot_dimension *coordinates = pivot_dimension_create (
1066 table, PIVOT_AXIS_ROW, N_("Coordinates"));
1067 coordinates->hide_all_labels = true;
1069 struct pivot_dimension *variables = pivot_dimension_create (
1070 table, PIVOT_AXIS_ROW, N_("Test variable"));
1071 variables->root->show_label = true;
1075 for (size_t i = 0; i < roc->n_vars; ++i)
1077 struct casereader *r = casereader_clone (rs[i].cutpoint_rdr);
1079 int var_idx = pivot_category_create_leaf (
1080 variables->root, pivot_value_new_variable (roc->vars[i]));
1084 for (; (cc = casereader_read (r)) != NULL; case_unref (cc))
1086 const double se = case_num_idx (cc, ROC_TP) /
1087 (case_num_idx (cc, ROC_TP) + case_num_idx (cc, ROC_FN));
1089 const double sp = case_num_idx (cc, ROC_TN) /
1090 (case_num_idx (cc, ROC_TN) + case_num_idx (cc, ROC_FP));
1092 if (coord_idx >= n_coords)
1094 assert (coord_idx == n_coords);
1095 pivot_category_create_leaf (
1096 coordinates->root, pivot_value_new_integer (++n_coords));
1100 table, 0, coord_idx, var_idx,
1101 pivot_value_new_var_value (roc->vars[i],
1102 case_data_idx (cc, ROC_CUTPOINT)));
1104 pivot_table_put3 (table, 1, coord_idx, var_idx,
1105 pivot_value_new_number (se));
1106 pivot_table_put3 (table, 2, coord_idx, var_idx,
1107 pivot_value_new_number (1 - sp));
1111 casereader_destroy (r);
1114 pivot_table_submit (table);
1119 output_roc (struct roc_state *rs, const struct cmd_roc *roc)
1125 struct roc_chart *rc;
1128 rc = roc_chart_create (roc->reference);
1129 for (i = 0; i < roc->n_vars; i++)
1130 roc_chart_add_var (rc, var_get_name (roc->vars[i]),
1131 rs[i].cutpoint_rdr);
1132 roc_chart_submit (rc);
1137 if (roc->print_coords)
1138 show_coords (rs, roc);