From: John Darrington Date: Wed, 10 Jun 2009 13:14:01 +0000 (+0800) Subject: Added basic calculation and display of area under the curve X-Git-Tag: build37~50^2~23 X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3542b2c285a83fb43a4b7031c477ab938eaf7c1c;p=pspp-builds.git Added basic calculation and display of area under the curve --- diff --git a/src/language/stats/roc.c b/src/language/stats/roc.c index 81f06293..46d2da9e 100644 --- a/src/language/stats/roc.c +++ b/src/language/stats/roc.c @@ -24,6 +24,14 @@ #include #include +#include +#include +#include + +#include + +#include +#include #include "gettext.h" #define _(msgid) gettext (msgid) @@ -52,7 +60,6 @@ struct cmd_roc bool invert ; /* True iff a smaller test result variable indicates a positive result */ - }; static int run_roc (struct dataset *ds, struct cmd_roc *roc); @@ -269,7 +276,6 @@ run_roc (struct dataset *ds, struct cmd_roc *roc) while (casegrouper_get_next_group (grouper, &group)) { do_roc (roc, group, dataset_dict (ds)); - casereader_destroy (group); } ok = casegrouper_destroy (grouper); ok = proc_commit (ds) && ok; @@ -279,7 +285,348 @@ run_roc (struct dataset *ds, struct cmd_roc *roc) static void -do_roc (struct cmd_roc *roc, struct casereader *group, struct dictionary *dict) +dump_casereader (struct casereader *reader) +{ + struct ccase *c; + struct casereader *r = casereader_clone (reader); + + for ( ; (c = casereader_read (r) ); case_unref (c)) + { + int i; + for (i = 0 ; i < case_get_value_cnt (c); ++i) + { + printf ("%g ", case_data_idx (c, i)->f); + } + printf ("\n"); + } + + casereader_destroy (r); +} + +static bool +match_positives (const struct ccase *c, void *aux) { + struct cmd_roc *roc = aux; + + return 0 == value_compare_3way (case_data (c, roc->state_var), + &roc->state_value, + var_get_width (roc->state_var)); } + +#define VALUE 0 +#define N_EQ 1 +#define N_PRED 2 + +struct roc_state +{ + double auc; + + double n1; + double n2; +}; + + +static void output_roc (struct roc_state *rs, const struct cmd_roc *roc); + + + +static struct casereader * +process_group (const struct variable *var, struct casereader *reader, + bool (*pred) (double, double), + const struct dictionary *dict, + double *cc) +{ + const struct variable *w = dict_get_weight (dict); + const int weight_idx = w ? var_get_case_index (w) : + caseproto_get_n_widths (casereader_get_proto (reader)) - 1; + + struct casereader *r1 = + casereader_create_distinct (sort_execute_1var (reader, var), var, w); + + struct ccase *c1; + + struct casereader *rclone = casereader_clone (r1); + struct casewriter *wtr; + struct caseproto *proto = caseproto_create (); + + proto = caseproto_add_width (proto, 0); + proto = caseproto_add_width (proto, 0); + proto = caseproto_add_width (proto, 0); + + wtr = autopaging_writer_create (proto); + + *cc = 0; + + for ( ; (c1 = casereader_read (r1) ); case_unref (c1)) + { + struct ccase *c2; + struct casereader *r2 = casereader_clone (rclone); + + const double weight1 = case_data_idx (c1, weight_idx)->f; + const double d1 = case_data (c1, var)->f; + double n_eq = 0.0; + double n_pred = 0.0; + + + struct ccase *new_case = case_create (proto); + + *cc += weight1; + + for ( ; (c2 = casereader_read (r2) ); case_unref (c2)) + { + const double d2 = case_data (c2, var)->f; + const double weight2 = case_data_idx (c2, weight_idx)->f; + + if ( d1 == d2 ) + { + n_eq += weight2; + continue; + } + else if ( pred (d2, d1)) + { + n_pred += weight2; + } + } + + case_data_rw_idx (new_case, VALUE)->f = d1; + case_data_rw_idx (new_case, N_EQ)->f = n_eq; + case_data_rw_idx (new_case, N_PRED)->f = n_pred; + + casewriter_write (wtr, new_case); + + casereader_destroy (r2); + } + + casereader_destroy (r1); + casereader_destroy (rclone); + + return casewriter_make_reader (wtr); +} + +static bool +gt (double d1, double d2) +{ + return d1 > d2; +} + +static bool +lt (double d1, double d2) +{ + return d1 < d2; +} + + +static void +do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict) +{ + int i; + + struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs); + + const struct caseproto *proto = casereader_get_proto (input); + + struct casewriter *neg_wtr = autopaging_writer_create (proto); + + struct casereader *negatives = NULL; + + struct casereader *positives = + casereader_create_filter_func (input, + match_positives, + NULL, + roc, + neg_wtr); + + + for (i = 0 ; i < roc->n_vars; ++i) + { + double q1hat = 0; + double q2hat = 0; + + struct ccase *cpos; + struct casereader *n_neg ; + const struct variable *var = roc->vars[i]; + + struct casereader *neg ; + struct casereader *pos = casereader_clone (positives); + + struct casereader *n_pos = process_group (var, pos, gt, dict, &rs[i].n1); + + if ( negatives == NULL) + { + negatives = casewriter_make_reader (neg_wtr); + } + + neg = casereader_clone (negatives); + + n_neg = process_group (var, neg, lt, dict, &rs[i].n2); + + /* Simple join on VALUE */ + for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos)) + { + struct ccase *cneg = NULL; + double dneg = -DBL_MAX; + const double dpos = case_data_idx (cpos, VALUE)->f; + while (dneg < dpos) + { + if ( cneg ) + case_unref (cneg); + + cneg = casereader_read (n_neg); + dneg = case_data_idx (cneg, VALUE)->f; + } + + if ( dpos == dneg ) + { + double n_pos_eq = case_data_idx (cpos, N_EQ)->f; + double n_neg_eq = case_data_idx (cneg, N_EQ)->f; + double n_pos_gt = case_data_idx (cpos, N_PRED)->f; + double n_neg_lt = case_data_idx (cneg, N_PRED)->f; + + rs[i].auc += n_pos_gt * n_neg_eq + (n_pos_eq * n_neg_eq) / 2.0; + q1hat += n_neg_eq * ( pow2 (n_pos_gt) + n_pos_gt * n_pos_eq + pow2 (n_pos_eq) / 3.0); + q2hat += n_pos_eq * ( pow2 (n_neg_lt) + n_neg_lt * n_neg_eq + pow2 (n_neg_eq) / 3.0); + } + + if ( cneg ) + case_unref (cneg); + } + + rs[i].auc /= rs[i].n1 * rs[i].n2; + if ( roc->invert ) + rs[i].auc = 1 - rs[i].auc; + + + } + + casereader_destroy (positives); + casereader_destroy (negatives); + + output_roc (rs, roc); + + free (rs); +} + + + + +static void +show_auc (struct roc_state *rs, const struct cmd_roc *roc) +{ + int i; + const int n_fields = roc->print_se ? 5 : 1; + const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields; + const int n_rows = 2 + roc->n_vars; + struct tab_table *tbl = tab_create (n_cols, n_rows, 0); + + if ( roc->n_vars > 1) + tab_title (tbl, _("Area Under the Curve")); + else + tab_title (tbl, _("Area Under the Curve (%s)"), var_to_string (roc->vars[0])); + + tab_headers (tbl, n_cols - n_fields, 0, 1, 0); + + tab_dim (tbl, tab_natural_dimensions, NULL); + + tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area")); + + tab_hline (tbl, TAL_2, 0, n_cols - 1, 2); + + tab_box (tbl, + TAL_2, TAL_2, + -1, TAL_1, + 0, 0, + n_cols - 1, + n_rows - 1); + + if ( roc->print_se ) + { + tab_text (tbl, n_cols - 4, 1, TAT_TITLE, _("Std. Error")); + tab_text (tbl, n_cols - 3, 1, TAT_TITLE, _("Asymptotic Sig.")); + + tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound")); + tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound")); + + tab_joint_text (tbl, n_cols - 2, 0, 4, 0, + TAT_TITLE | TAB_CENTER | TAT_PRINTF, + _("Asymp. %g%% Confidence Interval"), roc->ci); + tab_vline (tbl, 0, n_cols - 1, 0, 0); + tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1); + } + + if ( roc->n_vars > 1) + tab_text (tbl, 0, 1, TAT_TITLE, _("Variable under test")); + + if ( roc->n_vars > 1) + tab_vline (tbl, TAL_2, 1, 0, n_rows - 1); + + + for ( i = 0 ; i < roc->n_vars ; ++i ) + { + tab_text (tbl, 0, 2 + i, TAT_TITLE, var_to_string (roc->vars[i])); + + tab_double (tbl, n_cols - n_fields, 2 + i, 0, rs[i].auc, NULL); + + if ( roc->print_se ) + { + + double se ; + const double sd_0_5 = sqrt ((rs[i].n1 + rs[i].n2 + 1) / + (12 * rs[i].n1 * rs[i].n2)); + double ci ; + double yy ; + + double q1 = rs[i].auc / ( 2 - rs[i].auc); + double q2 = 2 * pow2 (rs[i].auc) / ( 1 + rs[i].auc); + + se = rs[i].auc * (1 - rs[i].auc) + (rs[i].n1 - 1) * (q1 - pow2 (rs[i].auc)) + + (rs[i].n2 - 1) * (q2 - pow2 (rs[i].auc)); + + se /= rs[i].n1 * rs[i].n2; + + se = sqrt (se); + + tab_double (tbl, n_cols - 4, 2 + i, 0, + se, + NULL); + + ci = 1 - roc->ci / 100.0; + yy = gsl_cdf_gaussian_Qinv (ci, se) ; + + tab_double (tbl, n_cols - 2, 2 + i, 0, + rs[i].auc - yy, + NULL); + + tab_double (tbl, n_cols - 1, 2 + i, 0, + rs[i].auc + yy, + NULL); + + tab_double (tbl, n_cols - 3, 2 + i, 0, + 2.0 * gsl_cdf_ugaussian_Q (fabs ((rs[i].auc - 0.5 ) / sd_0_5)), + NULL); + } + } + + tab_submit (tbl); +} + + + + +static void +output_roc (struct roc_state *rs, const struct cmd_roc *roc) +{ +#if 0 + show_summary (roc); + + if ( roc->curve ) + draw_roc (rs, roc); +#endif + + show_auc (rs, roc); + +#if 0 + if ( roc->print_coords ) + show_coords (rs, roc); +#endif +}