From 3d161b40e72009aeca4bfc1cb8be82d05d7e6e2a Mon Sep 17 00:00:00 2001 From: Jason Stover Date: Wed, 16 Jul 2008 18:22:33 -0400 Subject: [PATCH] Fixed updating of covariance matrix when both variables are categorical. Reported by John Darrington --- src/data/category.c | 17 +++++++++++++ src/data/category.h | 5 ++++ src/math/covariance-matrix.c | 46 ++++++++++++++++++++++++++---------- 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/src/data/category.c b/src/data/category.c index 33078ce2..1620bc7f 100644 --- a/src/data/category.c +++ b/src/data/category.c @@ -148,6 +148,23 @@ cat_value_update (const struct variable *v, const union value *val) } } } +/* + Return the count for the sth category. + */ +size_t +cat_get_category_count (const size_t s, const struct variable *v) +{ + struct cat_vals *tmp; + size_t n_categories; + + tmp = var_get_obs_vals (v); + n_categories = cat_get_n_categories (v); + if (s < n_categories) + { + return tmp->value_counts[s]; + } + return CAT_VALUE_NOT_FOUND; +} const union value * cat_subscript_to_value (const size_t s, const struct variable *v) diff --git a/src/data/category.h b/src/data/category.h index 6ef40857..db4bb339 100644 --- a/src/data/category.h +++ b/src/data/category.h @@ -50,6 +50,11 @@ const union value *cat_subscript_to_value (const size_t, void cat_value_update (const struct variable *, const union value *); +/* + Return the count for the sth category. +*/ +size_t +cat_get_category_count (const size_t, const struct variable *); /* Return the number of categories of a categorical variable. diff --git a/src/math/covariance-matrix.c b/src/math/covariance-matrix.c index 85998dfe..69aaf533 100644 --- a/src/math/covariance-matrix.c +++ b/src/math/covariance-matrix.c @@ -69,13 +69,35 @@ covariance_update_categorical_numeric (struct design_matrix *cov, double mean, gsl_matrix_set (cov->m, row, col, (val2->f - mean) * x * weight); } } +static void +column_iterate (struct design_matrix *cov, const struct variable *v, double weight, + double ssize, double x, const union value *val1, size_t row) +{ + size_t col; + size_t i; + double y; + union value *tmp_val; + col = design_matrix_var_to_column (cov, v); + for (i = 0; i < cat_get_n_categories (v) - 1; i++) + { + col += i; + y = -1.0 * cat_get_category_count (i, v) / ssize; + tmp_val = cat_subscript_to_value (i, v); + if (compare_values (tmp_val, val1, var_get_width (v))) + { + y += -1.0; + } + gsl_matrix_set (cov->m, row, col, x * y * weight); + gsl_matrix_set (cov->m, col, row, x * y * weight); + } +} /* - Call this function in the first data pass. The central moments are + Call this function in the second data pass. The central moments are MEAN1 and MEAN2. Any categorical variables should already have their values summarized in in its OBS_VALS element. */ -void covariance_pass_one (struct design_matrix *cov, double mean1, double mean2, +void covariance_pass_two (struct design_matrix *cov, double mean1, double mean2, double weight, double ssize, const struct variable *v1, const struct variable *v2, const union value *val1, const union value *val2) { @@ -83,7 +105,7 @@ void covariance_pass_one (struct design_matrix *cov, double mean1, double mean2, size_t col; size_t i; double x; - double y; + union value *tmp_val; if (var_is_alpha (v1)) { @@ -94,18 +116,18 @@ void covariance_pass_one (struct design_matrix *cov, double mean1, double mean2, } else { - row = design_matrix_var_to_column (cov, v1); - col = design_matrix_var_to_column (cov, v2); - for (i = 0; i < cat_get_n_categories (v2); i++) + row = design_matrix_var_to_column (cov, v1); + for (i = 0; i < cat_get_n_categories (v1) - 1; i++) { - col += i; - y = -1.0 * cat_get_n_categories (v2) / ssize; - if (i == cat_value_find (v2, val2)) + row += i; + x = -1.0 * cat_get_category_count (i, v1) / ssize; + tmp_val = cat_subscript_to_value (i, v1); + if (compare_values (tmp_val, val1, var_get_width (v1))) { - y += 1.0; + x += 1.0; } - gsl_matrix_set (cov->m, row, col, x * y * weight); - gsl_matrix_set (cov->m, col, row, x * y * weight); + column_iterate (cov, v1, weight, ssize, x, val1, row); + column_iterate (cov, v2, weight, ssize, x, val2, row); } } } -- 2.30.2