From 3d8d78ad9ca206b6489cc3944c985c8ba89e4b1e Mon Sep 17 00:00:00 2001 From: John Darrington Date: Sun, 19 Jul 2009 19:27:51 +0200 Subject: [PATCH] Add some comments and macros to make the code more readable --- src/language/stats/roc.c | 122 +++++++++++++++++++++++++++------------ 1 file changed, 86 insertions(+), 36 deletions(-) diff --git a/src/language/stats/roc.c b/src/language/stats/roc.c index 866eba83..024c9f85 100644 --- a/src/language/stats/roc.c +++ b/src/language/stats/roc.c @@ -324,6 +324,13 @@ dump_casereader (struct casereader *reader) } #endif + +/* + Return true iff the state variable indicates that C has positive actual state. + + As a side effect, this function also accumulates the roc->{pos,neg} and + roc->{pos,neg}_weighted counts. + */ static bool match_positives (const struct ccase *c, void *aux) { @@ -331,9 +338,9 @@ match_positives (const struct ccase *c, void *aux) const struct variable *wv = dict_get_weight (roc->dict); const double weight = wv ? case_data (c, wv)->f : 1.0; - bool positive = ( 0 == value_compare_3way (case_data (c, roc->state_var), - &roc->state_value, - var_get_width (roc->state_var))); + const bool positive = + ( 0 == value_compare_3way (case_data (c, roc->state_var), &roc->state_value, + var_get_width (roc->state_var))); if ( positive ) { @@ -354,6 +361,8 @@ match_positives (const struct ccase *c, void *aux) #define N_EQ 1 #define N_PRED 2 +/* Some intermediate state for calculating the cutpoints and the + standard error values */ struct roc_state { double auc; @@ -371,8 +380,6 @@ struct roc_state double max; }; - - #define CUTPOINT 0 #define TP 1 #define FN 2 @@ -380,6 +387,15 @@ struct roc_state #define FP 4 +/* + Return a new casereader based upon CUTPOINT_RDR. + The number of "positive" cases are placed into + the position TRUE_INDEX, and the number of "negative" cases + into FALSE_INDEX. + POS_COND and RESULT determine the semantics of what is + "positive". + WEIGHT is the value of a single count. 
+ */ static struct casereader * accumulate_counts (struct casereader *cutpoint_rdr, double result, double weight, @@ -407,13 +423,9 @@ accumulate_counts (struct casereader *cutpoint_rdr, new_case = case_clone (cpc); if ( pos_cond (result, cp)) - { - case_data_rw_idx (new_case, true_index)->f += weight; - } + case_data_rw_idx (new_case, true_index)->f += weight; else - { - case_data_rw_idx (new_case, false_index)->f += weight; - } + case_data_rw_idx (new_case, false_index)->f += weight; prev_cp = cp; @@ -509,6 +521,12 @@ process_group (const struct variable *var, struct casereader *reader, return casewriter_make_reader (wtr); } +/* Some more indices into case data */ +#define N_POS_EQ 1 /* number of positive cases with values equal to n */ +#define N_POS_GT 2 /* number of positive cases with values greater than n */ +#define N_NEG_EQ 3 /* number of negative cases with values equal to n */ +#define N_NEG_LT 4 /* number of negative cases with values less than n */ + static bool gt (double d1, double d2) { @@ -528,6 +546,14 @@ lt (double d1, double d2) return d1 < d2; } + +/* + Return a casereader with width 3, + populated with cases based upon READER. + The cases will have the values: + (N, number of cases equal to N, number of cases greater than N) + As a side effect, update RS->n1 with the number of positive cases. +*/ static struct casereader * process_positive_group (const struct variable *var, struct casereader *reader, const struct dictionary *dict, @@ -539,7 +565,13 @@ process_positive_group (const struct variable *var, struct casereader *reader, TP, FN); } - +/* + Return a casereader with width 3, + populated with cases based upon READER. + The cases will have the values: + (N, number of cases equal to N, number of cases less than N) + As a side effect, update RS->n2 with the number of negative cases.
+*/ static struct casereader * process_negative_group (const struct variable *var, struct casereader *reader, const struct dictionary *dict, @@ -565,12 +597,17 @@ append_cutpoint (struct casewriter *writer, double cutpoint) case_data_rw_idx (cc, TN)->f = 0; case_data_rw_idx (cc, FP)->f = 0; - casewriter_write (writer, cc); } -/* Prepare the cutpoints */ +/* + Create and initialise the rs[x].cutpoint_rdr casereaders. That is, the readers will + be created with width 5, ready to take the values (cutpoint, TP, FN, TN, FP), and the + reader will be populated with its final number of cases. + However on exit from this function, only CUTPOINT entries will be set to their final + value. The other entries will be initialised to zero. +*/ static void prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input) { @@ -664,6 +701,8 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict) prepare_cutpoints (roc, rs, input); + + /* Separate the positive actual state cases from the negative ones */ positives = casereader_create_filter_func (input, match_positives, @@ -696,6 +735,7 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict) struct casereader *neg ; struct casereader *pos = casereader_clone (positives); + struct casereader *n_pos = process_positive_group (var, pos, dict, &rs[i]); @@ -708,6 +748,8 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict) n_neg = process_negative_group (var, neg, dict, &rs[i]); + + /* Merge the n_pos and n_neg casereaders */ w = sort_create_writer (&up_ordering, n_proto); for ( ; (cpos = casereader_read (n_pos) ); case_unref (cpos)) { @@ -722,12 +764,12 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict) const double jneg = case_data_idx (cneg, VALUE)->f; case_data_rw_idx (nc, VALUE)->f = jneg; - case_data_rw_idx (nc, N_EQ)->f = 0; + case_data_rw_idx (nc, N_POS_EQ)->f = 0; - case_data_rw_idx (nc, 
N_PRED)->f = SYSMIS; + case_data_rw_idx (nc, N_POS_GT)->f = SYSMIS; - *case_data_rw_idx (nc, 3) = *case_data_idx (cneg, N_EQ); - *case_data_rw_idx (nc, 4) = *case_data_idx (cneg, N_PRED); + *case_data_rw_idx (nc, N_NEG_EQ) = *case_data_idx (cneg, N_EQ); + *case_data_rw_idx (nc, N_NEG_LT) = *case_data_idx (cneg, N_PRED); casewriter_write (w, nc); @@ -737,29 +779,35 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict) } case_data_rw_idx (pos_case, VALUE)->f = jpos; - *case_data_rw_idx (pos_case, N_EQ) = *case_data_idx (cpos, N_EQ); - *case_data_rw_idx (pos_case, N_PRED) = *case_data_idx (cpos, N_PRED); - case_data_rw_idx (pos_case, 3)->f = 0; - case_data_rw_idx (pos_case, 4)->f = SYSMIS; + *case_data_rw_idx (pos_case, N_POS_EQ) = *case_data_idx (cpos, N_EQ); + *case_data_rw_idx (pos_case, N_POS_GT) = *case_data_idx (cpos, N_PRED); + case_data_rw_idx (pos_case, N_NEG_EQ)->f = 0; + case_data_rw_idx (pos_case, N_NEG_LT)->f = SYSMIS; casewriter_write (w, pos_case); } +/* These aren't used anymore */ +#undef N_EQ +#undef N_PRED + r = casewriter_make_reader (w); + /* Propagate the N_POS_GT values from the positive cases + to the negative ones */ { double prev_pos_gt = rs[i].n1; w = sort_create_writer (&down_ordering, n_proto); for ( ; (c = casereader_read (r) ); case_unref (c)) { - double n_pos_gt = case_data_idx (c, N_PRED)->f; + double n_pos_gt = case_data_idx (c, N_POS_GT)->f; struct ccase *nc = case_clone (c); if ( n_pos_gt == SYSMIS) { n_pos_gt = prev_pos_gt; - case_data_rw_idx (nc, N_PRED)->f = n_pos_gt; + case_data_rw_idx (nc, N_POS_GT)->f = n_pos_gt; } casewriter_write (w, nc); @@ -769,19 +817,21 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict) r = casewriter_make_reader (w); } + /* Propagate the N_NEG_LT values from the negative cases + to the positive ones */ { double prev_neg_lt = rs[i].n2; w = sort_create_writer (&up_ordering, n_proto); for ( ; (c = casereader_read (r) ); case_unref (c)) { - 
double n_neg_lt = case_data_idx (c, 4)->f; + double n_neg_lt = case_data_idx (c, N_NEG_LT)->f; struct ccase *nc = case_clone (c); if ( n_neg_lt == SYSMIS) { n_neg_lt = prev_neg_lt; - case_data_rw_idx (nc, 4)->f = n_neg_lt; + case_data_rw_idx (nc, N_NEG_LT)->f = n_neg_lt; } casewriter_write (w, nc); @@ -798,23 +848,23 @@ do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict) const struct ccase *next_case = casereader_peek (r, 0); const double j = case_data_idx (c, VALUE)->f; - double n_pos_eq = case_data_idx (c, N_EQ)->f; - double n_pos_gt = case_data_idx (c, N_PRED)->f; - double n_neg_eq = case_data_idx (c, 3)->f; - double n_neg_lt = case_data_idx (c, 4)->f; + double n_pos_eq = case_data_idx (c, N_POS_EQ)->f; + double n_pos_gt = case_data_idx (c, N_POS_GT)->f; + double n_neg_eq = case_data_idx (c, N_NEG_EQ)->f; + double n_neg_lt = case_data_idx (c, N_NEG_LT)->f; if ( prev_case && j == case_data_idx (prev_case, VALUE)->f) { - if ( 0 == case_data_idx (c, N_EQ)->f) + if ( 0 == case_data_idx (c, N_POS_EQ)->f) { - n_pos_eq = case_data_idx (prev_case, N_EQ)->f; - n_pos_gt = case_data_idx (prev_case, N_PRED)->f; + n_pos_eq = case_data_idx (prev_case, N_POS_EQ)->f; + n_pos_gt = case_data_idx (prev_case, N_POS_GT)->f; } - if ( 0 == case_data_idx (c, 3)->f) + if ( 0 == case_data_idx (c, N_NEG_EQ)->f) { - n_neg_eq = case_data_idx (prev_case, 3)->f; - n_neg_lt = case_data_idx (prev_case, 4)->f; + n_neg_eq = case_data_idx (prev_case, N_NEG_EQ)->f; + n_neg_lt = case_data_idx (prev_case, N_NEG_LT)->f; } } -- 2.30.2