X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=blobdiff_plain;f=src%2Flanguage%2Fstats%2Froc.c;h=1f2691d8919e01eac157688d053f218c426581af;hb=204a1ee35aebcc2cf955017070c1a3638cdaee22;hp=024c9f85297c5bf3fc15fae0f7af95fac121b09e;hpb=3d8d78ad9ca206b6489cc3944c985c8ba89e4b1e;p=pspp diff --git a/src/language/stats/roc.c b/src/language/stats/roc.c index 024c9f8529..1f2691d891 100644 --- a/src/language/stats/roc.c +++ b/src/language/stats/roc.c @@ -16,6 +16,8 @@ #include +#include + #include #include #include @@ -36,8 +38,8 @@ #include #include -#include -#include +#include +#include #include "gettext.h" #define _(msgid) gettext (msgid) @@ -98,26 +100,27 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars, PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC)) - goto error;; + goto error; if ( ! lex_force_match (lexer, T_BY)) { - goto error;; + goto error; } roc.state_var = parse_variable (lexer, dict); if ( !lex_force_match (lexer, '(')) { - goto error;; + goto error; } + value_init (&roc.state_value, var_get_width (roc.state_var)); parse_value (lexer, &roc.state_value, var_get_width (roc.state_var)); if ( !lex_force_match (lexer, ')')) { - goto error;; + goto error; } @@ -140,7 +143,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } } } @@ -164,7 +167,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } } else if (lex_match_id (lexer, "PRINT")) @@ -183,7 +186,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } } } @@ -206,7 +209,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } lex_force_match (lexer, ')'); } @@ -224,7 +227,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } lex_force_match (lexer, ')'); } @@ -250,14 +253,14 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } lex_force_match (lexer, ')'); } else { lex_error (lexer, NULL); - goto error;; + goto error; } } } @@ -269,11 +272,14 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) } if ( ! run_roc (ds, &roc)) - goto error;; + goto error; + value_destroy (&roc.state_value, var_get_width (roc.state_var)); + free (roc.vars); return CMD_SUCCESS; error: + value_destroy (&roc.state_value, var_get_width (roc.state_var)); free (roc.vars); return CMD_FAILURE; } @@ -365,14 +371,16 @@ match_positives (const struct ccase *c, void *aux) standard error values */ struct roc_state { - double auc; + double auc; /* Area under the curve */ - double n1; - double n2; + double n1; /* total weight of positives */ + double n2; /* total weight of negatives */ - double q1hat; + /* intermediates for standard error */ + double q1hat; double q2hat; + /* intermediates for cutpoints */ struct casewriter *cutpoint_wtr; struct casereader *cutpoint_rdr; double prev_result; @@ -380,13 +388,6 @@ struct roc_state double max; }; -#define CUTPOINT 0 -#define TP 1 -#define FN 2 -#define TN 3 -#define FP 4 - - /* Return a new casereader based upon CUTPOINT_RDR. The number of "positive" cases are placed into @@ -412,7 +413,7 @@ accumulate_counts (struct casereader *cutpoint_rdr, for ( ; (cpc = casereader_read (r) ); case_unref (cpc)) { struct ccase *new_case; - const double cp = case_data_idx (cpc, CUTPOINT)->f; + const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f; assert (cp != SYSMIS); @@ -440,7 +441,21 @@ accumulate_counts (struct casereader *cutpoint_rdr, static void output_roc (struct roc_state *rs, const struct cmd_roc *roc); +/* + This function does 3 things: + + 1. Counts the number of cases which are equal to every other case in READER, + and those cases for which the relationship between it and every other case + satifies PRED (normally either > or <). VAR is variable defining a case's value + for this purpose. + 2. Counts the number of true and false cases in reader, and populates + CUTPOINT_RDR accordingly. TRUE_INDEX and FALSE_INDEX are the indices + which receive these values. POS_COND is the condition defining true + and false. + + 3. CC is filled with the cumulative weight of all cases of READER. +*/ static struct casereader * process_group (const struct variable *var, struct casereader *reader, bool (*pred) (double, double), @@ -562,7 +577,7 @@ process_positive_group (const struct variable *var, struct casereader *reader, return process_group (var, reader, gt, dict, &rs->n1, &rs->cutpoint_rdr, ge, - TP, FN); + ROC_TP, ROC_FN); } /* @@ -580,7 +595,7 @@ process_negative_group (const struct variable *var, struct casereader *reader, return process_group (var, reader, lt, dict, &rs->n2, &rs->cutpoint_rdr, lt, - TN, FP); + ROC_TN, ROC_FP); } @@ -591,11 +606,11 @@ append_cutpoint (struct casewriter *writer, double cutpoint) { struct ccase *cc = case_create (casewriter_get_proto (writer)); - case_data_rw_idx (cc, CUTPOINT)->f = cutpoint; - case_data_rw_idx (cc, TP)->f = 0; - case_data_rw_idx (cc, FN)->f = 0; - case_data_rw_idx (cc, TN)->f = 0; - case_data_rw_idx (cc, FP)->f = 0; + case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint; + case_data_rw_idx (cc, ROC_TP)->f = 0; + case_data_rw_idx (cc, ROC_FN)->f = 0; + case_data_rw_idx (cc, ROC_TN)->f = 0; + case_data_rw_idx (cc, ROC_FP)->f = 0; casewriter_write (writer, cc); } @@ -603,9 +618,9 @@ append_cutpoint (struct casewriter *writer, double cutpoint) /* Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will - be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the + be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the reader will be populated with its final number of cases. - However on exit from this function, only CUTPOINT entries will be set to their final + However on exit from this function, only ROC_CUTPOINT entries will be set to their final value. The other entries will be initialised to zero. */ static void @@ -617,13 +632,13 @@ prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader struct caseproto *proto = caseproto_create (); struct subcase ordering; - subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND); + subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND); proto = caseproto_add_width (proto, 0); /* cutpoint */ - proto = caseproto_add_width (proto, 0); /* TP */ - proto = caseproto_add_width (proto, 0); /* FN */ - proto = caseproto_add_width (proto, 0); /* TN */ - proto = caseproto_add_width (proto, 0); /* FP */ + proto = caseproto_add_width (proto, 0); /* ROC_TP */ + proto = caseproto_add_width (proto, 0); /* ROC_FN */ + proto = caseproto_add_width (proto, 0); /* ROC_TN */ + proto = caseproto_add_width (proto, 0); /* ROC_FP */ for (i = 0 ; i < roc->n_vars; ++i) { @@ -915,7 +930,7 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) const int n_fields = roc->print_se ? 5 : 1; const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields; const int n_rows = 2 + roc->n_vars; - struct tab_table *tbl = tab_create (n_cols, n_rows, 0); + struct tab_table *tbl = tab_create (n_cols, n_rows); if ( roc->n_vars > 1) tab_title (tbl, _("Area Under the Curve")); @@ -924,7 +939,7 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) tab_headers (tbl, n_cols - n_fields, 0, 1, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); + tab_dim (tbl, tab_natural_dimensions, NULL, NULL); tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area")); @@ -945,9 +960,9 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound")); tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound")); - tab_joint_text (tbl, n_cols - 2, 0, 4, 0, - TAT_TITLE | TAB_CENTER | TAT_PRINTF, - _("Asymp. %g%% Confidence Interval"), roc->ci); + tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0, + TAT_TITLE | TAB_CENTER, + _("Asymp. %g%% Confidence Interval"), roc->ci); tab_vline (tbl, 0, n_cols - 1, 0, 0); tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1); } @@ -1010,13 +1025,13 @@ show_summary (const struct cmd_roc *roc) { const int n_cols = 3; const int n_rows = 4; - struct tab_table *tbl = tab_create (n_cols, n_rows, 0); + struct tab_table *tbl = tab_create (n_cols, n_rows); tab_title (tbl, _("Case Summary")); tab_headers (tbl, 1, 0, 2, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); + tab_dim (tbl, tab_natural_dimensions, NULL, NULL); tab_box (tbl, TAL_2, TAL_2, @@ -1068,7 +1083,7 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) for (i = 0; i < roc->n_vars; ++i) n_rows += casereader_count_cases (rs[i].cutpoint_rdr); - tbl = tab_create (n_cols, n_rows, 0); + tbl = tab_create (n_cols, n_rows); if ( roc->n_vars > 1) tab_title (tbl, _("Coordinates of the Curve")); @@ -1078,7 +1093,7 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) tab_headers (tbl, 1, 0, 1, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); + tab_dim (tbl, tab_natural_dimensions, NULL, NULL); tab_hline (tbl, TAL_2, 0, n_cols - 1, 1); @@ -1114,21 +1129,21 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) for (; (cc = casereader_read (r)) != NULL; case_unref (cc), x++) { - const double se = case_data_idx (cc, TP)->f / + const double se = case_data_idx (cc, ROC_TP)->f / ( - case_data_idx (cc, TP)->f + case_data_idx (cc, ROC_TP)->f + - case_data_idx (cc, FN)->f + case_data_idx (cc, ROC_FN)->f ); - const double sp = case_data_idx (cc, TN)->f / + const double sp = case_data_idx (cc, ROC_TN)->f / ( - case_data_idx (cc, TN)->f + case_data_idx (cc, ROC_TN)->f + - case_data_idx (cc, FP)->f + case_data_idx (cc, ROC_FP)->f ); - tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f, + tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f, var_get_print_format (roc->vars[i])); tab_double (tbl, n_cols - 2, x, 0, se, NULL); @@ -1142,68 +1157,25 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) } -static void -draw_roc (struct roc_state *rs, const struct cmd_roc *roc) -{ - int i; - - struct chart *roc_chart = chart_create (); - - chart_write_title (roc_chart, _("ROC Curve")); - chart_write_xlabel (roc_chart, _("1 - Specificity")); - chart_write_ylabel (roc_chart, _("Sensitivity")); - - chart_write_xscale (roc_chart, 0, 1, 5); - chart_write_yscale (roc_chart, 0, 1, 5); - - if ( roc->reference ) - { - chart_line (roc_chart, 1.0, 0, - 0.0, 1.0, - CHART_DIM_X); - } - - for (i = 0; i < roc->n_vars; ++i) - { - struct ccase *cc; - struct casereader *r = casereader_clone (rs[i].cutpoint_rdr); - - chart_vector_start (roc_chart, var_get_name (roc->vars[i])); - for (; (cc = casereader_read (r)) != NULL; - case_unref (cc)) - { - double se = case_data_idx (cc, TP)->f; - double sp = case_data_idx (cc, TN)->f; - - se /= case_data_idx (cc, FN)->f + - case_data_idx (cc, TP)->f ; - - sp /= case_data_idx (cc, TN)->f + - case_data_idx (cc, FP)->f ; - - chart_vector (roc_chart, 1 - sp, se); - } - chart_vector_end (roc_chart); - casereader_destroy (r); - } - - chart_write_legend (roc_chart); - - chart_submit (roc_chart); -} - - static void output_roc (struct roc_state *rs, const struct cmd_roc *roc) { show_summary (roc); if ( roc->curve ) - draw_roc (rs, roc); + { + struct roc_chart *rc; + size_t i; + + rc = roc_chart_create (roc->reference); + for (i = 0; i < roc->n_vars; i++) + roc_chart_add_var (rc, var_get_name (roc->vars[i]), + rs[i].cutpoint_rdr); + chart_submit (roc_chart_get_chart (rc)); + } show_auc (rs, roc); - if ( roc->print_coords ) show_coords (rs, roc); }