#include <config.h>
-#include "roc.h"
#include <data/procedure.h>
#include <language/lexer/variable-parser.h>
#include <language/lexer/value-parser.h>
+#include <language/command.h>
#include <language/lexer/lexer.h>
#include <data/casegrouper.h>
{
size_t n_vars;
const struct variable **vars;
+ const struct dictionary *dict;
- struct variable *state_var ;
+ const struct variable *state_var ;
union value state_value;
/* Plot the roc curve */
bool invert ; /* True iff a smaller test result variable indicates
a positive result */
+
+ double pos;
+ double neg;
+ double pos_weighted;
+ double neg_weighted;
};
static int run_roc (struct dataset *ds, struct cmd_roc *roc);
roc.ci = 95;
roc.bi_neg_exp = false;
roc.invert = false;
+ roc.pos = roc.pos_weighted = 0;
+ roc.neg = roc.neg_weighted = 0;
+ roc.dict = dataset_dict (ds);
if (!parse_variables_const (lexer, dict, &roc.vars, &roc.n_vars,
PV_APPEND | PV_NO_DUPLICATE | PV_NUMERIC))
- return 2;
+ goto error;
if ( ! lex_force_match (lexer, T_BY))
{
- return 2;
+ goto error;
}
roc.state_var = parse_variable (lexer, dict);
+ /* parse_variable is expected to return NULL after reporting an
+    error; do not hand that to var_get_width below. */
+ if (roc.state_var == NULL)
+ goto error;
if ( !lex_force_match (lexer, '('))
{
- return 2;
+ goto error;
}
parse_value (lexer, &roc.state_value, var_get_width (roc.state_var));
if ( !lex_force_match (lexer, ')'))
{
- return 2;
+ goto error;
}
else
{
lex_error (lexer, NULL);
- return 2;
+ goto error;
}
}
}
else
{
lex_error (lexer, NULL);
- return 2;
+ goto error;
}
}
else if (lex_match_id (lexer, "PRINT"))
else
{
lex_error (lexer, NULL);
- return 2;
+ goto error;
}
}
}
else
{
lex_error (lexer, NULL);
- return 2;
+ goto error;
}
lex_force_match (lexer, ')');
}
else
{
lex_error (lexer, NULL);
- return 2;
+ goto error;
}
lex_force_match (lexer, ')');
}
else
{
lex_error (lexer, NULL);
- return 2;
+ goto error;
}
lex_force_match (lexer, ')');
}
else
{
lex_error (lexer, NULL);
- return 2;
+ goto error;
}
}
}
}
}
- run_roc (ds, &roc);
+ if ( ! run_roc (ds, &roc))
+ goto error;
+
+ /* ROC is local to this command, so its variable list must be
+    released on the success path as well as on the error path. */
+ free (roc.vars);
+ return CMD_SUCCESS;
- return 1;
+ error:
+ free (roc.vars);
+ return CMD_FAILURE;
}
match_positives (const struct ccase *c, void *aux)
{
struct cmd_roc *roc = aux;
+ const struct variable *wv = dict_get_weight (roc->dict); /* may be null; handled on the next line */
+ const double weight = wv ? case_data (c, wv)->f : 1.0; /* a null WV gives every case weight 1.0 */
- return 0 == value_compare_3way (case_data (c, roc->state_var),
+ bool positive = ( 0 == value_compare_3way (case_data (c, roc->state_var),
&roc->state_value,
- var_get_width (roc->state_var));
+ var_get_width (roc->state_var))); /* true when the comparison with roc->state_value yields 0 */
+
+ if ( positive ) /* side effect: tally this case in ROC, both raw and weighted */
+ {
+ roc->pos++;
+ roc->pos_weighted += weight;
+ }
+ else
+ {
+ roc->neg++;
+ roc->neg_weighted += weight;
+ }
+
+ return positive; /* the same value the removed code returned directly */
}
struct casereader **cutpoint_rdr,
bool (*pos_cond) (double, double),
int true_index,
- int false_index
- )
+ int false_index)
{
const struct variable *w = dict_get_weight (dict);
+
struct casereader *r1 =
casereader_create_distinct (sort_execute_1var (reader, var), var, w);
}
+/* Prepare the cutpoints.
+
+   Reads every case from a clone of INPUT.  For each of ROC's test
+   variables I, a value that is missing according to ROC->exclude is
+   skipped; otherwise RS[I].min and RS[I].max are updated and, when
+   the value differs from the previously read one, the mean of the
+   two is appended to RS[I].cutpoint_wtr.  Finally the points
+   RS[I].min - 1 and RS[I].max + 1 are appended and each writer is
+   turned into the reader RS[I].cutpoint_rdr.
+
+   RS must have at least ROC->n_vars elements. */
static void
-do_roc (struct cmd_roc *roc, struct casereader *input, struct dictionary *dict)
+prepare_cutpoints (struct cmd_roc *roc, struct roc_state *rs, struct casereader *input)
{
int i;
+ struct casereader *r = casereader_clone (input);
+ struct ccase *c;
+ struct caseproto *proto = caseproto_create ();
- struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
+ struct subcase ordering;
+ subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND);
- struct casewriter *neg_wtr = autopaging_writer_create (casereader_get_proto (input));
+ proto = caseproto_add_width (proto, 0); /* cutpoint */
+ proto = caseproto_add_width (proto, 0); /* TP */
+ proto = caseproto_add_width (proto, 0); /* FN */
+ proto = caseproto_add_width (proto, 0); /* TN */
+ proto = caseproto_add_width (proto, 0); /* FP */
- struct casereader *negatives = NULL;
- struct casereader *positives = NULL;
+ for (i = 0 ; i < roc->n_vars; ++i)
+ {
+ rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
+ rs[i].prev_result = SYSMIS;
+ rs[i].max = -DBL_MAX;
+ rs[i].min = DBL_MAX;
+ }
+
+ /* The sort writers are expected to keep their own reference to
+    PROTO and their own copy of ORDERING, so release ours here
+    rather than leaking them on every run. */
+ caseproto_unref (proto);
+ subcase_destroy (&ordering);
- struct caseproto *n_proto = caseproto_create ();
+ for (; (c = casereader_read (r)) != NULL; case_unref (c))
+ {
+ for (i = 0 ; i < roc->n_vars; ++i)
+ {
+ const union value *v = case_data (c, roc->vars[i]);
+ const double result = v->f;
- struct subcase up_ordering;
- struct subcase down_ordering;
+ if ( mv_is_value_missing (var_get_missing_values (roc->vars[i]), v, roc->exclude))
+ continue;
- /* Prepare the cutpoints */
- {
- struct casereader *r = casereader_clone (input);
- struct ccase *c;
- struct caseproto *proto = caseproto_create ();
+ minimize (&rs[i].min, result);
+ maximize (&rs[i].max, result);
- struct subcase ordering;
- subcase_init (&ordering, CUTPOINT, 0, SC_ASCEND);
+ if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
+ {
+ const double mean = (result + rs[i].prev_result ) / 2.0;
+ append_cutpoint (rs[i].cutpoint_wtr, mean);
+ }
- proto = caseproto_add_width (proto, 0); /* cutpoint */
- proto = caseproto_add_width (proto, 0); /* TP */
- proto = caseproto_add_width (proto, 0); /* FN */
- proto = caseproto_add_width (proto, 0); /* TN */
- proto = caseproto_add_width (proto, 0); /* FP */
+ rs[i].prev_result = result;
+ }
+ }
+ casereader_destroy (r);
- for (i = 0 ; i < roc->n_vars; ++i)
- {
- rs[i].cutpoint_wtr = sort_create_writer (&ordering, proto);
- rs[i].prev_result = SYSMIS;
- rs[i].max = -DBL_MAX;
- rs[i].min = DBL_MAX;
- }
+ /* Append the min and max cutpoints */
+ for (i = 0 ; i < roc->n_vars; ++i)
+ {
+ append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
+ append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
- for (; (c = casereader_read (r)) != NULL; case_unref (c))
- {
- for (i = 0 ; i < roc->n_vars; ++i)
- {
- const double result = case_data (c, roc->vars[i])->f;
+ rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
+ }
+}
- minimize (&rs[i].min, result);
- maximize (&rs[i].max, result);
+static void
+do_roc (struct cmd_roc *roc, struct casereader *reader, struct dictionary *dict)
+{
+ int i;
- if ( rs[i].prev_result != SYSMIS && rs[i].prev_result != result )
- {
- const double mean = (result + rs[i].prev_result ) / 2.0;
- append_cutpoint (rs[i].cutpoint_wtr, mean);
- }
+ struct roc_state *rs = xcalloc (roc->n_vars, sizeof *rs);
- rs[i].prev_result = result;
- }
- }
- casereader_destroy (r);
+ struct casereader *negatives = NULL;
+ struct casereader *positives = NULL;
+ struct caseproto *n_proto = caseproto_create ();
- /* Append the min and max cutpoints */
- for (i = 0 ; i < roc->n_vars; ++i)
- {
- append_cutpoint (rs[i].cutpoint_wtr, rs[i].min - 1);
- append_cutpoint (rs[i].cutpoint_wtr, rs[i].max + 1);
+ struct subcase up_ordering;
+ struct subcase down_ordering;
- rs[i].cutpoint_rdr = casewriter_make_reader (rs[i].cutpoint_wtr);
- }
- }
+ struct casewriter *neg_wtr = NULL;
+
+ struct casereader *input = casereader_create_filter_missing (reader,
+ roc->vars, roc->n_vars,
+ roc->exclude,
+ NULL,
+ NULL);
+
+ input = casereader_create_filter_missing (input,
+ &roc->state_var, 1,
+ roc->exclude,
+ NULL,
+ NULL);
+
+ neg_wtr = autopaging_writer_create (casereader_get_proto (input));
+
+ prepare_cutpoints (roc, rs, input);
positives =
casereader_create_filter_func (input,
tab_text (tbl, 0, 3, TAB_LEFT, _("Negative"));
-#if 0
tab_double (tbl, 1, 2, 0, roc->pos, &F_8_0);
tab_double (tbl, 1, 3, 0, roc->neg, &F_8_0);
tab_double (tbl, 2, 2, 0, roc->pos_weighted, 0);
tab_double (tbl, 2, 3, 0, roc->neg_weighted, 0);
-#endif
tab_submit (tbl);
}