From 9a6ded3d54197e14f7f1a9d49a674df7591a3d71 Mon Sep 17 00:00:00 2001 From: John Darrington Date: Thu, 11 Jun 2009 14:22:22 +0800 Subject: [PATCH] Added code to generate the ROC cutpoint tables. --- src/language/stats/roc.c | 294 ++++++++++++++++++++++++++++++++++++--- tests/command/roc.sh | 3 +- 2 files changed, 280 insertions(+), 17 deletions(-) diff --git a/src/language/stats/roc.c b/src/language/stats/roc.c index b3841634..21b70318 100644 --- a/src/language/stats/roc.c +++ b/src/language/stats/roc.c @@ -28,6 +28,8 @@ #include #include #include +#include + #include @@ -328,26 +330,89 @@ struct roc_state double q1hat; double q2hat; + + struct casewriter *cutpoint_wtr; + struct casereader *cutpoint_rdr; + double prev_result; + double min; + double max; }; -static void output_roc (struct roc_state *rs, const struct cmd_roc *roc); +#define CUTPOINT 0 +#define TP 1 +#define FN 2 +#define TN 3 +#define FP 4 + + +static struct casereader * +accumulate_counts (struct casereader *cutpoint_rdr, + double result, double weight, + bool (*pos_cond) (double, double), + int true_index, int false_index) +{ + const struct caseproto *proto = casereader_get_proto (cutpoint_rdr); + struct casewriter *w = + autopaging_writer_create (proto); + struct casereader *r = casereader_clone (cutpoint_rdr); + struct ccase *cpc; + double prev_cp = SYSMIS; + + + for ( ; (cpc = casereader_read (r) ); case_unref (cpc)) + { + struct ccase *new_case; + const double cp = case_data_idx (cpc, CUTPOINT)->f; + + /* We don't want duplicates here */ + if ( cp == prev_cp ) + continue; + + new_case = case_clone (cpc); + + if ( pos_cond (result, cp)) + { + case_data_rw_idx (new_case, true_index)->f += weight; + } + else + { + case_data_rw_idx (new_case, false_index)->f += weight; + } + + prev_cp = cp; + + casewriter_write (w, new_case); + } + casereader_destroy (r); + + return casewriter_make_reader (w); +} + + + +static void output_roc (struct roc_state *rs, const struct cmd_roc *roc); static struct casereader * process_group (const struct variable *var, struct casereader *reader, bool (*pred) (double, double), const struct dictionary *dict, - double *cc) + double *cc, + struct casereader **cutpoint_rdr, + bool (*pos_cond) (double, double), + int true_index, + int false_index + ) { const struct variable *w = dict_get_weight (dict); - const int weight_idx = w ? var_get_case_index (w) : - caseproto_get_n_widths (casereader_get_proto (reader)) - 1; - struct casereader *r1 = casereader_create_distinct (sort_execute_1var (reader, var), var, w); + const int weight_idx = w ? var_get_case_index (w) : + caseproto_get_n_widths (casereader_get_proto (r1)) - 1; + struct ccase *c1; struct casereader *rclone = casereader_clone (r1); @@ -361,7 +426,7 @@ process_group (const struct variable *var, struct casereader *reader, wtr = autopaging_writer_create (proto); *cc = 0; - + for ( ; (c1 = casereader_read (r1) ); case_unref (c1)) { struct ccase *c2; @@ -372,6 +437,9 @@ process_group (const struct variable *var, struct casereader *reader, double n_eq = 0.0; double n_pred = 0.0; + *cutpoint_rdr = accumulate_counts (*cutpoint_rdr, d1, weight1, + pos_cond, + true_index, false_index); struct ccase *new_case = case_create (proto); @@ -414,12 +482,60 @@ gt (double d1, double d2) return d1 > d2; } + +static bool +ge (double d1, double d2) +{ + return d1 > d2; +} + static bool lt (double d1, double d2) { return d1 < d2; } +static struct casereader * +process_positive_group (const struct variable *var, struct casereader *reader, + const struct dictionary *dict, + struct roc_state *rs) +{ + return process_group (var, reader, gt, dict, &rs->n1, + &rs->cutpoint_rdr, + ge, + TP, FN); +} + + +static struct casereader * +process_negative_group (const struct variable *var, struct casereader *reader, + const struct dictionary *dict, + struct roc_state *rs) +{ + return process_group (var, reader, lt, dict, &rs->n2, + &rs->cutpoint_rdr, + lt, + TN, FP); +} + + + + +static void +append_cutpoint (struct casewriter *writer, double cutpoint) +{ + struct ccase *cc = case_create (casewriter_get_proto (writer)); + + case_data_rw_idx (cc, CUTPOINT)->f = cutpoint; + case_data_rw_idx (cc, TP)->f = 0; + case_data_rw_idx (cc, FN)->f = 0; + case_data_rw_idx (cc, TN)->f = 0; + case_data_rw_idx (cc, FP)->f = 0; + + + casewriter_write (writer, cc); +} + static void do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict) @@ -428,13 +544,71 @@ do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict) struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs); - const struct caseproto *proto = casereader_get_proto (input); - - struct casewriter *neg_wtr = autopaging_writer_create (proto); + struct casewriter *neg_wtr = autopaging_writer_create (casereader_get_proto (input)); struct casereader *negatives = NULL; + struct casereader *positives = NULL; + + + /* Prepare the cutpoints */ + { + struct casereader *r = casereader_clone (input); + struct ccase *c; + struct caseproto *proto = caseproto_create (); - struct casereader *positives = + struct subcase ordering; + struct variable *iv = var_create_internal (CUTPOINT); + subcase_init_var (&ordering, iv, SC_ASCEND); + + + proto = caseproto_add_width (proto, 0); /* cutpoint */ + proto = caseproto_add_width (proto, 0); /* TP */ + proto = caseproto_add_width (proto, 0); /* FN */ + proto = caseproto_add_width (proto, 0); /* TN */ + proto = caseproto_add_width (proto, 0); /* FP */ + + + for (i = 0 ; i < roc->n_vars; ++i) + { + rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto); + rs[i].prev_result = SYSMIS; + rs[i].max = -DBL_MAX; + rs[i].min = DBL_MAX; + } + + for (; (c = casereader_read (r)) != NULL; case_unref (c)) + { + const double weight = dict_get_case_weight (dict, c, NULL); + for (i = 0 ; i < roc->n_vars; ++i) + { + const double result = case_data (c, roc->vars[i])->f; + + minimize (&rs[i].min, result); + maximize (&rs[i].max, result); + + if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result ) + { + const double mean = (result + rs[i].prev_result ) / 2.0; + append_cutpoint (rs[i].cutpoint_wtr, mean); + } + + rs[i].prev_result = result; + } + } + casereader_destroy (r); + + + /* Append the min and max cutpoints */ + for (i = 0 ; i < roc->n_vars; ++i) + { + append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1); + append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1); + + rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr); + } + } + + positives = casereader_create_filter_func (input, match_positives, NULL, @@ -451,16 +625,18 @@ do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict) struct casereader *neg ; struct casereader *pos = casereader_clone (positives); - struct casereader *n_pos = process_group (var, pos, gt, dict, &rs[i].n1); + struct casereader *n_pos = + process_positive_group (var, pos, dict, &rs[i]); if ( negatives == NULL) { negatives = casewriter_make_reader (neg_wtr); } - + neg = casereader_clone (negatives); - n_neg = process_group (var, neg, lt, dict, &rs[i].n2); + n_neg = process_negative_group (var, neg, dict, &rs[i]); + /* Simple join on VALUE */ for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos)) @@ -583,7 +759,6 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) if ( roc->print_se ) { - double se ; const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) / (12 * rs[i].n1 * rs[i].n2)); @@ -675,6 +850,93 @@ show_summary (const struct cmd_roc *roc) } +static void +show_coords (struct roc_state *rs, const struct cmd_roc *roc) +{ + int x = 1; + int i; + const int n_cols = roc->n_vars > 1 ? 4 : 3; + int n_rows = 1; + struct tab_table *tbl ; + + for (i = 0; i < roc->n_vars; ++i) + n_rows += casereader_count_cases (rs[i].cutpoint_rdr); + + tbl = tab_create (n_cols, n_rows, 0); + + if ( roc->n_vars > 1) + tab_title (tbl, _("Coordinates of the Curve")); + else + tab_title (tbl, _("Coordinates of the Curve (%s)"), var_to_string (roc->vars[0])); + + + tab_headers (tbl, 1, 0, 1, 0); + + tab_dim (tbl, tab_natural_dimensions, NULL); + + tab_hline (tbl, TAL_2, 0, n_cols - 1, 1); + + if ( roc->n_vars > 1) + tab_text (tbl, 0, 0, TAT_TITLE, _("Test variable")); + + tab_text (tbl, n_cols - 3, 0, TAT_TITLE, _("Positive if greater than or equal to")); + tab_text (tbl, n_cols - 2, 0, TAT_TITLE, _("Sensitivity")); + tab_text (tbl, n_cols - 1, 0, TAT_TITLE, _("1 - Specificity")); + + tab_box (tbl, + TAL_2, TAL_2, + -1, TAL_1, + 0, 0, + n_cols - 1, + n_rows - 1); + + if ( roc->n_vars > 1) + tab_vline (tbl, TAL_2, 1, 0, n_rows - 1); + + for (i = 0; i < roc->n_vars; ++i) + { + struct ccase *cc; + struct casereader *r = casereader_clone (rs[i].cutpoint_rdr); + + if ( roc->n_vars > 1) + tab_text (tbl, 0, x, TAT_TITLE, var_to_string (roc->vars[i])); + + if ( i > 0) + tab_hline (tbl, TAL_1, 0, n_cols - 1, x); + + + for (; (cc = casereader_read (r)) != NULL; + case_unref (cc), x++) + { + const double se = case_data_idx (cc, TP)->f / + ( + case_data_idx (cc, TP)->f + + + case_data_idx (cc, FN)->f + ); + + const double sp = case_data_idx (cc, TN)->f / + ( + case_data_idx (cc, TN)->f + + + case_data_idx (cc, FP)->f + ); + + tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f, + var_get_print_format (roc->vars[i])); + + tab_double (tbl, n_cols - 2, x, 0, se, NULL); + tab_double (tbl, n_cols - 1, x, 0, 1 - sp, NULL); + } + + casereader_destroy (r); + } + + tab_submit (tbl); +} + + + static void output_roc (struct roc_state *rs, const struct cmd_roc *roc) { @@ -688,8 +950,8 @@ output_roc (struct roc_state *rs, const struct cmd_roc *roc) show_auc (rs, roc); -#if 0 + if ( roc->print_coords ) show_coords (rs, roc); -#endif } + diff --git a/tests/command/roc.sh b/tests/command/roc.sh index df3312be..263d399e 100755 --- a/tests/command/roc.sh +++ b/tests/command/roc.sh @@ -95,6 +95,7 @@ activity="run program" $SUPERVISOR $PSPP --testing-mode $TESTFILE if [ $? -ne 0 ] ; then no_result ; fi + activity="compare results" perl -pi -e 's/^\s*$//g' $TEMPDIR/pspp.list diff -b $TEMPDIR/pspp.list - << EOF @@ -159,7 +160,7 @@ See pspp-1.png for a chart. # y# .000| 1.000| 1.000# # # 1.500| .960| .900# # # 2.500| .680| .340# -# # 3.000| .600| .300# +# # 3.000| .600| .340# # # 3.500| .600| .300# # # 4.500| .200| .020# # # 6.000| .000| .000# -- 2.30.2