X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=blobdiff_plain;f=src%2Flanguage%2Fstats%2Froc.c;h=1d21a4f6a1221ed6b83f3c2bbb8d4e25c457c95a;hb=cb72db62c20ecab427229110820c5b053d0663c4;hp=1d61a55c57a3cad25c834c204c3571632d9df778;hpb=c2f0df181038fe9975d642096e65ea48ca491acd;p=pspp diff --git a/src/language/stats/roc.c b/src/language/stats/roc.c index 1d61a55c57..1d21a4f6a1 100644 --- a/src/language/stats/roc.c +++ b/src/language/stats/roc.c @@ -16,6 +16,8 @@ #include +#include + #include #include #include @@ -36,8 +38,8 @@ #include #include -#include -#include +#include +#include #include "gettext.h" #define _(msgid) gettext (msgid) @@ -383,13 +385,6 @@ struct roc_state double max; }; -#define CUTPOINT 0 -#define TP 1 -#define FN 2 -#define TN 3 -#define FP 4 - - /* Return a new casereader based upon CUTPOINT_RDR. The number of "positive" cases are placed into @@ -415,7 +410,7 @@ accumulate_counts (struct casereader *cutpoint_rdr, for ( ; (cpc = casereader_read (r) ); case_unref (cpc)) { struct ccase *new_case; - const double cp = case_data_idx (cpc, CUTPOINT)->f; + const double cp = case_data_idx (cpc, ROC_CUTPOINT)->f; assert (cp != SYSMIS); @@ -579,7 +574,7 @@ process_positive_group (const struct variable *var, struct casereader *reader, return process_group (var, reader, gt, dict, &rs->n1, &rs->cutpoint_rdr, ge, - TP, FN); + ROC_TP, ROC_FN); } /* @@ -597,7 +592,7 @@ process_negative_group (const struct variable *var, struct casereader *reader, return process_group (var, reader, lt, dict, &rs->n2, &rs->cutpoint_rdr, lt, - TN, FP); + ROC_TN, ROC_FP); } @@ -608,11 +603,11 @@ append_cutpoint (struct casewriter *writer, double cutpoint) { struct ccase *cc = case_create (casewriter_get_proto (writer)); - case_data_rw_idx (cc, CUTPOINT)->f = cutpoint; - case_data_rw_idx (cc, TP)->f = 0; - case_data_rw_idx (cc, FN)->f = 0; - case_data_rw_idx (cc, TN)->f = 0; - case_data_rw_idx (cc, FP)->f = 0; + case_data_rw_idx (cc, ROC_CUTPOINT)->f = cutpoint; + case_data_rw_idx (cc, ROC_TP)->f = 0; + case_data_rw_idx (cc, ROC_FN)->f = 0; + case_data_rw_idx (cc, ROC_TN)->f = 0; + case_data_rw_idx (cc, ROC_FP)->f = 0; casewriter_write (writer, cc); } @@ -620,9 +615,9 @@ append_cutpoint (struct casewriter *writer, double cutpoint) /* Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will - be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the + be created with width 5, ready to take the values (cutpoint, ROC_TP, ROC_FN, ROC_TN, ROC_FP), and the reader will be populated with its final number of cases. - However on exit from this function, only CUTPOINT entries will be set to their final + However on exit from this function, only ROC_CUTPOINT entries will be set to their final value. The other entries will be initialised to zero. */ static void @@ -634,13 +629,13 @@ prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader struct caseproto *proto = caseproto_create (); struct subcase ordering; - subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND); + subcase_init (&ordering, ROC_CUTPOINT, 0, SC_ASCEND); proto = caseproto_add_width (proto, 0); /* cutpoint */ - proto = caseproto_add_width (proto, 0); /* TP */ - proto = caseproto_add_width (proto, 0); /* FN */ - proto = caseproto_add_width (proto, 0); /* TN */ - proto = caseproto_add_width (proto, 0); /* FP */ + proto = caseproto_add_width (proto, 0); /* ROC_TP */ + proto = caseproto_add_width (proto, 0); /* ROC_FN */ + proto = caseproto_add_width (proto, 0); /* ROC_TN */ + proto = caseproto_add_width (proto, 0); /* ROC_FP */ for (i = 0 ; i < roc->n_vars; ++i) { @@ -932,7 +927,7 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) const int n_fields = roc->print_se ? 5 : 1; const int n_cols = roc->n_vars > 1 ? n_fields + 1: n_fields; const int n_rows = 2 + roc->n_vars; - struct tab_table *tbl = tab_create (n_cols, n_rows, 0); + struct tab_table *tbl = tab_create (n_cols, n_rows); if ( roc->n_vars > 1) tab_title (tbl, _("Area Under the Curve")); @@ -941,7 +936,7 @@ show_auc (struct roc_state *rs, const struct cmd_roc *roc) tab_headers (tbl, n_cols - n_fields, 0, 1, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); + tab_dim (tbl, tab_natural_dimensions, NULL, NULL); tab_text (tbl, n_cols - n_fields, 1, TAT_TITLE, _("Area")); @@ -1027,13 +1022,13 @@ show_summary (const struct cmd_roc *roc) { const int n_cols = 3; const int n_rows = 4; - struct tab_table *tbl = tab_create (n_cols, n_rows, 0); + struct tab_table *tbl = tab_create (n_cols, n_rows); tab_title (tbl, _("Case Summary")); tab_headers (tbl, 1, 0, 2, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); + tab_dim (tbl, tab_natural_dimensions, NULL, NULL); tab_box (tbl, TAL_2, TAL_2, @@ -1085,7 +1080,7 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) for (i = 0; i < roc->n_vars; ++i) n_rows += casereader_count_cases (rs[i].cutpoint_rdr); - tbl = tab_create (n_cols, n_rows, 0); + tbl = tab_create (n_cols, n_rows); if ( roc->n_vars > 1) tab_title (tbl, _("Coordinates of the Curve")); @@ -1095,7 +1090,7 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) tab_headers (tbl, 1, 0, 1, 0); - tab_dim (tbl, tab_natural_dimensions, NULL); + tab_dim (tbl, tab_natural_dimensions, NULL, NULL); tab_hline (tbl, TAL_2, 0, n_cols - 1, 1); @@ -1131,21 +1126,21 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) for (; (cc = casereader_read (r)) != NULL; case_unref (cc), x++) { - const double se = case_data_idx (cc, TP)->f / + const double se = case_data_idx (cc, ROC_TP)->f / ( - case_data_idx (cc, TP)->f + case_data_idx (cc, ROC_TP)->f + - case_data_idx (cc, FN)->f + case_data_idx (cc, ROC_FN)->f ); - const double sp = case_data_idx (cc, TN)->f / + const double sp = case_data_idx (cc, ROC_TN)->f / ( - case_data_idx (cc, TN)->f + case_data_idx (cc, ROC_TN)->f + - case_data_idx (cc, FP)->f + case_data_idx (cc, ROC_FP)->f ); - tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, CUTPOINT)->f, + tab_double (tbl, n_cols - 3, x, 0, case_data_idx (cc, ROC_CUTPOINT)->f, var_get_print_format (roc->vars[i])); tab_double (tbl, n_cols - 2, x, 0, se, NULL); @@ -1159,68 +1154,25 @@ show_coords (struct roc_state *rs, const struct cmd_roc *roc) } -static void -draw_roc (struct roc_state *rs, const struct cmd_roc *roc) -{ - int i; - - struct chart *roc_chart = chart_create (); - - chart_write_title (roc_chart, _("ROC Curve")); - chart_write_xlabel (roc_chart, _("1 - Specificity")); - chart_write_ylabel (roc_chart, _("Sensitivity")); - - chart_write_xscale (roc_chart, 0, 1, 5); - chart_write_yscale (roc_chart, 0, 1, 5); - - if ( roc->reference ) - { - chart_line (roc_chart, 1.0, 0, - 0.0, 1.0, - CHART_DIM_X); - } - - for (i = 0; i < roc->n_vars; ++i) - { - struct ccase *cc; - struct casereader *r = casereader_clone (rs[i].cutpoint_rdr); - - chart_vector_start (roc_chart, var_get_name (roc->vars[i])); - for (; (cc = casereader_read (r)) != NULL; - case_unref (cc)) - { - double se = case_data_idx (cc, TP)->f; - double sp = case_data_idx (cc, TN)->f; - - se /= case_data_idx (cc, FN)->f + - case_data_idx (cc, TP)->f ; - - sp /= case_data_idx (cc, TN)->f + - case_data_idx (cc, FP)->f ; - - chart_vector (roc_chart, 1 - sp, se); - } - chart_vector_end (roc_chart); - casereader_destroy (r); - } - - chart_write_legend (roc_chart); - - chart_submit (roc_chart); -} - - static void output_roc (struct roc_state *rs, const struct cmd_roc *roc) { show_summary (roc); if ( roc->curve ) - draw_roc (rs, roc); + { + struct roc_chart *rc; + size_t i; + + rc = roc_chart_create (roc->reference); + for (i = 0; i < roc->n_vars; i++) + roc_chart_add_var (rc, var_get_name (roc->vars[i]), + rs[i].cutpoint_rdr); + chart_submit (roc_chart_get_chart (rc)); + } show_auc (rs, roc); - if ( roc->print_coords ) show_coords (rs, roc); }