X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=blobdiff_plain;f=src%2Flanguage%2Fstats%2Froc.c;h=d1461a599f0724471a37aaf41a1120c336670f52;hb=6ccced652fb274a08361fb844fb1c435a4c654d1;hp=d266d7cf7684d311db10a36ba788852a19470f4a;hpb=b883f96966eaf08620aae7269690875d3db11054;p=pspp-builds.git diff --git a/src/language/stats/roc.c b/src/language/stats/roc.c index d266d7cf..d1461a59 100644 --- a/src/language/stats/roc.c +++ b/src/language/stats/roc.c @@ -16,28 +16,26 @@ #include -#include -#include -#include -#include -#include +#include #include #include #include #include #include -#include +#include #include - - +#include +#include +#include +#include #include +#include +#include +#include +#include #include -#include - -#include -#include #include "gettext.h" #define _(msgid) gettext (msgid) @@ -49,7 +47,7 @@ struct cmd_roc const struct variable **vars; const struct dictionary *dict; - const struct variable *state_var ; + const struct variable *state_var; union value state_value; /* Plot the roc curve */ @@ -95,29 +93,32 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) roc.pos = roc.pos_weighted = 0; roc.neg = roc.neg_weighted = 0; roc.dict = dataset_dict (ds); + roc.state_var = NULL; + lex_match (lexer, '/'); if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars, PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC)) - goto error;; + goto error; if ( ! lex_force_match (lexer, T_BY)) { - goto error;; + goto error; } roc.state_var = parse_variable (lexer, dict); if ( !lex_force_match (lexer, '(')) { - goto error;; + goto error; } + value_init (&roc.state_value, var_get_width (roc.state_var)); parse_value (lexer, &roc.state_value, var_get_width (roc.state_var)); if ( !lex_force_match (lexer, ')')) { - goto error;; + goto error; } @@ -140,7 +141,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } } } @@ -164,7 +165,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } } else if (lex_match_id (lexer, "PRINT")) @@ -183,7 +184,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } } } @@ -206,7 +207,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } lex_force_match (lexer, ')'); } @@ -224,7 +225,7 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } lex_force_match (lexer, ')'); } @@ -250,14 +251,14 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) else { lex_error (lexer, NULL); - goto error;; + goto error; } lex_force_match (lexer, ')'); } else { lex_error (lexer, NULL); - goto error;; + goto error; } } } @@ -269,11 +270,15 @@ cmd_roc (struct lexer *lexer, struct dataset *ds) } if ( ! run_roc (ds, &roc)) - goto error;; + goto error; + value_destroy (&roc.state_value, var_get_width (roc.state_var)); + free (roc.vars); return CMD_SUCCESS; error: + if ( roc.state_var) + value_destroy (&roc.state_value, var_get_width (roc.state_var)); free (roc.vars); return CMD_FAILURE; } @@ -382,13 +387,6 @@ struct roc_state double max; }; -#define CUTPOINT 0 -#define TP 1 -#define FN 2 -#define TN 3 -#define FP 4 - - /* Return a new casereader based upon CUTPOINT_RDR. The number of "positive" cases are placed into @@ -414,7 +412,7 @@ accumulate_counts (struct casereader *cutpoint_rdr, for ( ; (cpc = casereader_read (r) ); case_unref (cpc)) { struct ccase *new_case; - const double cp = case_data_idx (cpc, CUTPOINT)->f; + const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f; assert (cp != SYSMIS); @@ -578,7 +576,7 @@ process_positive_group (const struct variable *var, struct casereader *reader, return process_group (var, reader, gt, dict, &rs->n1, &rs->cutpoint_rdr, ge, - TP, FN); + ROC_TP, ROC_FN); } /* @@ -596,7 +594,7 @@ process_negative_group (const struct variable *var, struct casereader *reader, return process_group (var, reader, lt, dict, &rs->n2, &rs->cutpoint_rdr, lt, - TN, FP); + ROC_TN, ROC_FP); } @@ -607,11 +605,11 @@ append_cutpoint (struct casewriter *writer, double cutpoint) { struct ccase *cc = case_create (casewriter_get_proto (writer)); - case_data_rw_idx (cc, CUTPOINT)->f = cutpoint; - case_data_rw_idx (cc, TP)->f = 0; - case_data_rw_idx (cc, FN)->f = 0; - case_data_rw_idx (cc, TN)->f = 0; - case_data_rw_idx (cc, FP)->f = 0; + case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint; + case_data_rw_idx (cc, ROC_TP)->f = 0; + case_data_rw_idx (cc, ROC_FN)->f = 0; + case_data_rw_idx (cc, ROC_TN)->f = 0; + case_data_rw_idx (cc, ROC_FP)->f = 0; casewriter_write (writer, cc); } @@ -619,9 +617,9 @@ append_cutpoint (struct casewriter *writer, double cutpoint) /* Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will - be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the + be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the reader will be populated with its final number of cases. - However on exit from this function, only CUTPOINT entries will be set to their final + However on exit from this function, only ROC_CUTPOINT entries will be set to their final value. The other entries will be initialised to zero. */ static void @@ -633,13 +631,13 @@ prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader struct caseproto *proto = caseproto_create (); struct subcase ordering; - subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND); + subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND); proto = caseproto_add_width (proto, 0); /* cutpoint */ - proto = caseproto_add_width (proto, 0); /* TP */ - proto = caseproto_add_width (proto, 0); /* FN */ - proto = caseproto_add_width (proto, 0); /* TN */ - proto = caseproto_add_width (proto, 0); /* FP */ + proto = caseproto_add_width (proto, 0); /* ROC_TP */ + proto = caseproto_add_width (proto, 0); /* ROC_FN */ + proto = caseproto_add_width (proto, 0); /* ROC_TN */ + proto = caseproto_add_width (proto, 0); /* ROC_FP */ for (i = 0 ; i < roc->n_vars; ++i) { @@ -931,7 +929,7 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) const int n_fields = roc->print_se ? 5 : 1; const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields; const int n_rows = 2 + roc->n_vars; - struct tab_table *tbl = tab_create (n_cols, n_rows, 0); + struct tab_table *tbl = tab_create (n_cols, n_rows); if ( roc->n_vars > 1) tab_title (tbl, _("Area Under the Curve")); @@ -940,7 +938,6 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) tab_headers (tbl, n_cols - n_fields, 0, 1, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area")); @@ -961,9 +958,9 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) tab_text (tbl, n_cols - 2, 1, TAT_TITLE, _("Lower Bound")); tab_text (tbl, n_cols - 1, 1, TAT_TITLE, _("Upper Bound")); - tab_joint_text (tbl, n_cols - 2, 0, 4, 0, - TAT_TITLE | TAB_CENTER | TAT_PRINTF, - _("Asymp. %g%% Confidence Interval"), roc->ci); + tab_joint_text_format (tbl, n_cols - 2, 0, 4, 0, + TAT_TITLE | TAB_CENTER, + _("Asymp. %g%% Confidence Interval"), roc->ci); tab_vline (tbl, 0, n_cols - 1, 0, 0); tab_hline (tbl, TAL_1, n_cols - 2, n_cols - 1, 1); } @@ -1026,14 +1023,12 @@ show_summary (const struct cmd_roc *roc) { const int n_cols = 3; const int n_rows = 4; - struct tab_table *tbl = tab_create (n_cols, n_rows, 0); + struct tab_table *tbl = tab_create (n_cols, n_rows); tab_title (tbl, _("Case Summary")); tab_headers (tbl, 1, 0, 2, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); - tab_box (tbl, TAL_2, TAL_2, -1, -1, @@ -1084,7 +1079,7 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) for (i = 0; i < roc->n_vars; ++i) n_rows += casereader_count_cases (rs[i].cutpoint_rdr); - tbl = tab_create (n_cols, n_rows, 0); + tbl = tab_create (n_cols, n_rows); if ( roc->n_vars > 1) tab_title (tbl, _("Coordinates of the Curve")); @@ -1094,8 +1089,6 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) tab_headers (tbl, 1, 0, 1, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); - tab_hline (tbl, TAL_2, 0, n_cols - 1, 1); if ( roc->n_vars > 1) @@ -1130,21 +1123,21 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) for (; (cc = casereader_read (r)) != NULL; case_unref (cc), x++) { - const double se = case_data_idx (cc, TP)->f / + const double se = case_data_idx (cc, ROC_TP)->f / ( - case_data_idx (cc, TP)->f + case_data_idx (cc, ROC_TP)->f + - case_data_idx (cc, FN)->f + case_data_idx (cc, ROC_FN)->f ); - const double sp = case_data_idx (cc, TN)->f / + const double sp = case_data_idx (cc, ROC_TN)->f / ( - case_data_idx (cc, TN)->f + case_data_idx (cc, ROC_TN)->f + - case_data_idx (cc, FP)->f + case_data_idx (cc, ROC_FP)->f ); - tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f, + tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f, var_get_print_format (roc->vars[i])); tab_double (tbl, n_cols - 2, x, 0, se, NULL); @@ -1158,68 +1151,25 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) } -static void -draw_roc (struct roc_state *rs, const struct cmd_roc *roc) -{ - int i; - - struct chart *roc_chart = chart_create (); - - chart_write_title (roc_chart, _("ROC Curve")); - chart_write_xlabel (roc_chart, _("1 - Specificity")); - chart_write_ylabel (roc_chart, _("Sensitivity")); - - chart_write_xscale (roc_chart, 0, 1, 5); - chart_write_yscale (roc_chart, 0, 1, 5); - - if ( roc->reference ) - { - chart_line (roc_chart, 1.0, 0, - 0.0, 1.0, - CHART_DIM_X); - } - - for (i = 0; i < roc->n_vars; ++i) - { - struct ccase *cc; - struct casereader *r = casereader_clone (rs[i].cutpoint_rdr); - - chart_vector_start (roc_chart, var_get_name (roc->vars[i])); - for (; (cc = casereader_read (r)) != NULL; - case_unref (cc)) - { - double se = case_data_idx (cc, TP)->f; - double sp = case_data_idx (cc, TN)->f; - - se /= case_data_idx (cc, FN)->f + - case_data_idx (cc, TP)->f ; - - sp /= case_data_idx (cc, TN)->f + - case_data_idx (cc, FP)->f ; - - chart_vector (roc_chart, 1 - sp, se); - } - chart_vector_end (roc_chart); - casereader_destroy (r); - } - - chart_write_legend (roc_chart); - - chart_submit (roc_chart); -} - - static void output_roc (struct roc_state *rs, const struct cmd_roc *roc) { show_summary (roc); if ( roc->curve ) - draw_roc (rs, roc); + { + struct roc_chart *rc; + size_t i; + + rc = roc_chart_create (roc->reference); + for (i = 0; i < roc->n_vars; i++) + roc_chart_add_var (rc, var_get_name (roc->vars[i]), + rs[i].cutpoint_rdr); + roc_chart_submit (rc); + } show_auc (rs, roc); - if ( roc->print_coords ) show_coords (rs, roc); }