X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=blobdiff_plain;f=src%2Flanguage%2Fstats%2Fquick-cluster.c;h=0cd08d02a9268224436bac93339ffc876c9669a7;hb=604b7adbf9c26922f7a20887b2baf16a3f0acef6;hp=46a6ca4adb1dfb7f02cae55071a540921f80a86e;hpb=d74ca27a133a6facb0cb8b2ea6a59e9db33744a2;p=pspp diff --git a/src/language/stats/quick-cluster.c b/src/language/stats/quick-cluster.c index 46a6ca4adb..0cd08d02a9 100644 --- a/src/language/stats/quick-cluster.c +++ b/src/language/stats/quick-cluster.c @@ -40,7 +40,7 @@ #include "libpspp/str.h" #include "math/random.h" #include "output/pivot-table.h" -#include "output/text-item.h" +#include "output/output-item.h" #include "gettext.h" #define _(msgid) gettext (msgid) @@ -196,7 +196,7 @@ diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2) double diff = 0; for (j = 0; j < m1->size2; ++j) { - diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j) ); + diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j)); } if (diff > max_diff) max_diff = diff; @@ -247,7 +247,7 @@ dist_from_case (const struct Kmeans *kmeans, const struct ccase *c, for (j = 0; j < qc->n_vars; j++) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val) & qc->exclude) NOT_REACHED (); dist += pow2 (gsl_matrix_get (kmeans->centers, which, j) - val->f); @@ -302,7 +302,7 @@ kmeans_initial_centers (struct Kmeans *kmeans, for (j = 0; j < qc->n_vars; ++j) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val) & qc->exclude) { missing = true; break; @@ -385,7 +385,7 @@ kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c, for (j = 0; j < qc->n_vars; j++) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val) & qc->exclude) continue; dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f); @@ -462,7 +462,7 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, for (j = 0; j < qc->n_vars; j++) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val) & qc->exclude) missing = true; } @@ -482,16 +482,16 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, } long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group); - *n += qc->wv ? case_data (c, qc->wv)->f : 1.0; + *n += qc->wv ? case_num (c, qc->wv) : 1.0; kmeans->n++; for (j = 0; j < qc->n_vars; ++j) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val) & qc->exclude) continue; double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j); - *x += val->f * (qc->wv ? case_data (c, qc->wv)->f : 1.0); + *x += val->f * (qc->wv ? case_num (c, qc->wv) : 1.0); } } @@ -529,7 +529,7 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, for (j = 0; j < qc->n_vars; ++j) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val) & qc->exclude) continue; double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j); @@ -537,7 +537,7 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, } long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group); - *n += qc->wv ? case_data (c, qc->wv)->f : 1.0; + *n += qc->wv ? case_num (c, qc->wv) : 1.0; kmeans->n++; } casereader_destroy (cs); @@ -610,7 +610,7 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc /* A transformation function which juxtaposes the dataset with the (pre-prepared) dataset containing membership and/or distance values. */ -static int +static enum trns_result save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED) { const struct save_trans_data *std = aux; @@ -621,10 +621,10 @@ save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED) *c = case_unshare (*c); if (std->CASE_IDX_MEMBERSHIP >= 0) - case_data_rw (*c, std->membership)->f = case_data_idx (ca, std->CASE_IDX_MEMBERSHIP)->f; + *case_num_rw (*c, std->membership) = case_num_idx (ca, std->CASE_IDX_MEMBERSHIP); if (std->CASE_IDX_DISTANCE >= 0) - case_data_rw (*c, std->distance)->f = case_data_idx (ca, std->CASE_IDX_DISTANCE)->f; + *case_num_rw (*c, std->distance) = case_num_idx (ca, std->CASE_IDX_DISTANCE); case_unref (ca); @@ -651,8 +651,8 @@ quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, struct qc *qc) { - struct pivot_table *table; - struct pivot_dimension *cases; + struct pivot_table *table = NULL; + struct pivot_dimension *cases = NULL; if (qc->print_cluster_membership) { table = pivot_table_create (N_("Cluster Membership")); @@ -708,10 +708,10 @@ quick_cluster_show_membership (struct Kmeans *kmeans, /* Calculate the membership and distance values. */ struct ccase *outc = case_create (proto); if (qc->save_values & SAVE_MEMBERSHIP) - case_data_rw_idx (outc, qc->save_trans_data->CASE_IDX_MEMBERSHIP)->f = cluster + 1; + *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_MEMBERSHIP) = cluster + 1; if (qc->save_values & SAVE_DISTANCE) - case_data_rw_idx (outc, qc->save_trans_data->CASE_IDX_DISTANCE)->f + *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_DISTANCE) = sqrt (dist_from_case (kmeans, c, qc, clust)); casewriter_write (qc->save_trans_data->writer, outc); @@ -783,22 +783,12 @@ quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *read quick_cluster_show_membership (kmeans, reader, qc); } -int -cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) +/* Parse the QUICK CLUSTER command and populate QC accordingly. + Returns false on error. */ +static bool +quick_cluster_parse (struct lexer *lexer, struct qc *qc) { - struct qc qc; - struct Kmeans *kmeans; - bool ok; - memset (&qc, 0, sizeof qc); - qc.dataset = ds; - qc.dict = dataset_dict (ds); - qc.ngroups = 2; - qc.maxiter = 10; - qc.epsilon = DBL_EPSILON; - qc.missing_type = MISS_LISTWISE; - qc.exclude = MV_ANY; - - if (!parse_variables_const (lexer, qc.dict, &qc.vars, &qc.n_vars, + if (!parse_variables_const (lexer, qc->dict, &qc->vars, &qc->n_vars, PV_NO_DUPLICATE | PV_NUMERIC)) { return (CMD_FAILURE); @@ -817,24 +807,24 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT")) { - qc.missing_type = MISS_LISTWISE; + qc->missing_type = MISS_LISTWISE; } else if (lex_match_id (lexer, "PAIRWISE")) { - qc.missing_type = MISS_PAIRWISE; + qc->missing_type = MISS_PAIRWISE; } else if (lex_match_id (lexer, "INCLUDE")) { - qc.exclude = MV_SYSTEM; + qc->exclude = MV_SYSTEM; } else if (lex_match_id (lexer, "EXCLUDE")) { - qc.exclude = MV_ANY; + qc->exclude = MV_ANY; } else { lex_error (lexer, NULL); - goto error; + return false; } } } @@ -845,13 +835,13 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) && lex_token (lexer) != T_SLASH) { if (lex_match_id (lexer, "CLUSTER")) - qc.print_cluster_membership = true; + qc->print_cluster_membership = true; else if (lex_match_id (lexer, "INITIAL")) - qc.print_initial_clusters = true; + qc->print_initial_clusters = true; else { lex_error (lexer, NULL); - goto error; + return false; } } } @@ -863,61 +853,61 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) { if (lex_match_id (lexer, "CLUSTER")) { - qc.save_values |= SAVE_MEMBERSHIP; + qc->save_values |= SAVE_MEMBERSHIP; if (lex_match (lexer, T_LPAREN)) { if (!lex_force_id (lexer)) - goto error; + return false; - free (qc.var_membership); - qc.var_membership = xstrdup (lex_tokcstr (lexer)); - if (NULL != dict_lookup_var (dataset_dict (ds), qc.var_membership)) + free (qc->var_membership); + qc->var_membership = xstrdup (lex_tokcstr (lexer)); + if (NULL != dict_lookup_var (qc->dict, qc->var_membership)) { lex_error (lexer, _("A variable called `%s' already exists."), - qc.var_membership); - free (qc.var_membership); - qc.var_membership = NULL; - goto error; + qc->var_membership); + free (qc->var_membership); + qc->var_membership = NULL; + return false; } lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else if (lex_match_id (lexer, "DISTANCE")) { - qc.save_values |= SAVE_DISTANCE; + qc->save_values |= SAVE_DISTANCE; if (lex_match (lexer, T_LPAREN)) { if (!lex_force_id (lexer)) - goto error; + return false; - free (qc.var_distance); - qc.var_distance = xstrdup (lex_tokcstr (lexer)); - if (NULL != dict_lookup_var (dataset_dict (ds), qc.var_distance)) + free (qc->var_distance); + qc->var_distance = xstrdup (lex_tokcstr (lexer)); + if (NULL != dict_lookup_var (qc->dict, qc->var_distance)) { lex_error (lexer, _("A variable called `%s' already exists."), - qc.var_distance); - free (qc.var_distance); - qc.var_distance = NULL; - goto error; + qc->var_distance); + free (qc->var_distance); + qc->var_distance = NULL; + return false; } lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else { lex_error (lexer, _("Expecting %s or %s."), "CLUSTER", "DISTANCE"); - goto error; + return false; } } } @@ -930,72 +920,78 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) if (lex_match_id (lexer, "CLUSTERS")) { if (lex_force_match (lexer, T_LPAREN) && - lex_force_int (lexer)) + lex_force_int_range (lexer, "CLUSTERS", 1, INT_MAX)) { - qc.ngroups = lex_integer (lexer); - if (qc.ngroups <= 0) - { - lex_error (lexer, _("The number of clusters must be positive")); - goto error; - } + qc->ngroups = lex_integer (lexer); lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else if (lex_match_id (lexer, "CONVERGE")) { if (lex_force_match (lexer, T_LPAREN) && - lex_force_num (lexer)) + lex_force_num_range_open (lexer, "CONVERGE", 0, DBL_MAX)) { - qc.epsilon = lex_number (lexer); - if (qc.epsilon <= 0) - { - lex_error (lexer, _("The convergence criterion must be positive")); - goto error; - } + qc->epsilon = lex_number (lexer); lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else if (lex_match_id (lexer, "MXITER")) { if (lex_force_match (lexer, T_LPAREN) && - lex_force_int (lexer)) + lex_force_int_range (lexer, "MXITER", 1, INT_MAX)) { - qc.maxiter = lex_integer (lexer); - if (qc.maxiter <= 0) - { - lex_error (lexer, _("The number of iterations must be positive")); - goto error; - } + qc->maxiter = lex_integer (lexer); lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else if (lex_match_id (lexer, "NOINITIAL")) { - qc.no_initial = true; + qc->no_initial = true; } else if (lex_match_id (lexer, "NOUPDATE")) { - qc.no_update = true; + qc->no_update = true; } else { lex_error (lexer, NULL); - goto error; + return false; } } } else { lex_error (lexer, NULL); - goto error; + return false; } } + return true; +} + +int +cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) +{ + struct qc qc; + struct Kmeans *kmeans; + bool ok; + memset (&qc, 0, sizeof qc); + qc.dataset = ds; + qc.dict = dataset_dict (ds); + qc.ngroups = 2; + qc.maxiter = 10; + qc.epsilon = DBL_EPSILON; + qc.missing_type = MISS_LISTWISE; + qc.exclude = MV_ANY; + + + if (!quick_cluster_parse (lexer, &qc)) + goto error; qc.wv = dict_get_weight (qc.dict); @@ -1005,7 +1001,7 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) while (casegrouper_get_next_group (grouper, &group)) { - if ( qc.missing_type == MISS_LISTWISE ) + if (qc.missing_type == MISS_LISTWISE) { group = casereader_create_filter_missing (group, qc.vars, qc.n_vars, qc.exclude, @@ -1075,7 +1071,12 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) std->distance = dict_create_var_assert (qc.dict, qc.var_distance, 0); } - add_transformation (qc.dataset, save_trans_func, save_trans_destroy, std); + static const struct trns_class trns_class = { + .name = "QUICK CLUSTER", + .execute = save_trans_func, + .destroy = save_trans_destroy, + }; + add_transformation (qc.dataset, &trns_class, std); } free (qc.var_distance);