X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=blobdiff_plain;f=src%2Flanguage%2Fstats%2Fquick-cluster.c;h=be04af929fb1dc435a090f481965d359b054ab20;hb=60c545e6e958d868db3399a8989d37d8f9e0c131;hp=1570e34d3e4a736b1cd147f1171612d088f89ce7;hpb=7635ce0697c163bd9c80adb8b382df7a9aa97f42;p=pspp diff --git a/src/language/stats/quick-cluster.c b/src/language/stats/quick-cluster.c index 1570e34d3e..be04af929f 100644 --- a/src/language/stats/quick-cluster.c +++ b/src/language/stats/quick-cluster.c @@ -1,5 +1,5 @@ /* PSPP - a program for statistical analysis. - Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc. + Copyright (C) 2011, 2012, 2015, 2019 Free Software Foundation, Inc. This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -39,8 +39,8 @@ #include "libpspp/assertion.h" #include "libpspp/str.h" #include "math/random.h" -#include "output/tab.h" -#include "output/text-item.h" +#include "output/pivot-table.h" +#include "output/output-item.h" #include "gettext.h" #define _(msgid) gettext (msgid) @@ -53,12 +53,38 @@ enum missing_type }; +struct save_trans_data +{ + /* A writer which contains the values (if any) to be appended to + each case in the active dataset */ + struct casewriter *writer; + + /* A reader created from the writer above. */ + struct casereader *appending_reader; + + /* The indices to be used to access values in the above, + reader/writer */ + int CASE_IDX_MEMBERSHIP; + int CASE_IDX_DISTANCE; + + /* The variables created to hold the values appended to the dataset */ + struct variable *membership; + struct variable *distance; +}; + + +#define SAVE_MEMBERSHIP 0x1 +#define SAVE_DISTANCE 0x2 + struct qc { + struct dataset *dataset; + struct dictionary *dict; + const struct variable **vars; size_t n_vars; - double epsilon; /* The convergence criterium */ + double epsilon; /* The convergence criterion */ int ngroups; /* Number of group. (Given by the user) */ int maxiter; /* Maximum iterations (Given by the user) */ @@ -71,6 +97,18 @@ struct qc enum missing_type missing_type; enum mv_class exclude; + + /* Which values are to be saved? */ + int save_values; + + /* The name of the new variable to contain the cluster of each case. */ + char *var_membership; + + /* The name of the new variable to contain the distance of each case + from its cluster centre. */ + char *var_distance; + + struct save_trans_data *save_trans_data; }; /* Holds all of the information for the functions. int n, holds the number of @@ -91,19 +129,28 @@ struct Kmeans static struct Kmeans *kmeans_create (const struct qc *qc); -static void kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c, const struct qc *, int *, double *, int *, double *); +static void kmeans_get_nearest_group (const struct Kmeans *kmeans, + struct ccase *c, const struct qc *, + int *, double *, int *, double *); static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *); -static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *); +static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, + const struct qc *); -static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *); +static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, + const struct qc *); -static void quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *); +static void quick_cluster_show_membership (struct Kmeans *kmeans, + const struct casereader *reader, + struct qc *); -static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *); +static void quick_cluster_show_number_cases (struct Kmeans *kmeans, + const struct qc *); -static void quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *); +static void quick_cluster_show_results (struct Kmeans *kmeans, + const struct casereader *reader, + struct qc *); int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds); @@ -149,7 +196,7 @@ diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2) double diff = 0; for (j = 0; j < m1->size2; ++j) { - diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j) ); + diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j)); } if (diff > max_diff) max_diff = diff; @@ -160,7 +207,7 @@ diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2) -static double +static double matrix_mindist (const gsl_matrix *m, int *mn, int *mm) { int i, j; @@ -192,19 +239,20 @@ matrix_mindist (const gsl_matrix *m, int *mn, int *mm) /* Return the distance of C from the group whose index is WHICH */ static double -dist_from_case (const struct Kmeans *kmeans, const struct ccase *c, const struct qc *qc, int which) +dist_from_case (const struct Kmeans *kmeans, const struct ccase *c, + const struct qc *qc, int which) { int j; double dist = 0; for (j = 0; j < qc->n_vars; j++) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val, qc->exclude)) NOT_REACHED (); - + dist += pow2 (gsl_matrix_get (kmeans->centers, which, j) - val->f); } - + return dist; } @@ -223,9 +271,10 @@ min_dist_from (const struct Kmeans *kmeans, const struct qc *qc, int which) double dist = 0; for (j = 0; j < qc->n_vars; j++) { - dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - gsl_matrix_get (kmeans->centers, which, j)); + dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) + - gsl_matrix_get (kmeans->centers, which, j)); } - + if (dist < mindist) { mindist = dist; @@ -237,9 +286,11 @@ min_dist_from (const struct Kmeans *kmeans, const struct qc *qc, int which) -/* Calculate the intial cluster centers. */ +/* Calculate the initial cluster centers. */ static void -kmeans_initial_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc) +kmeans_initial_centers (struct Kmeans *kmeans, + const struct casereader *reader, + const struct qc *qc) { struct ccase *c; int nc = 0, j; @@ -251,7 +302,7 @@ kmeans_initial_centers (struct Kmeans *kmeans, const struct casereader *reader, for (j = 0; j < qc->n_vars; ++j) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val, qc->exclude)) { missing = true; break; @@ -278,10 +329,12 @@ kmeans_initial_centers (struct Kmeans *kmeans, const struct casereader *reader, kmeans_get_nearest_group (kmeans, c, qc, &mq, &delta, &mp, NULL); if (delta > m) /* If the distance between C and the nearest group, is greater than the distance - between the two groups which are clostest to each other, then one group must be replaced */ + between the two groups which are clostest to each + other, then one group must be replaced. */ { /* Out of mn and mm, which is the clostest of the two groups to C ? */ - int which = (dist_from_case (kmeans, c, qc, mn) > dist_from_case (kmeans, c, qc, mm)) ? mm : mn; + int which = (dist_from_case (kmeans, c, qc, mn) + > dist_from_case (kmeans, c, qc, mm)) ? mm : mn; for (j = 0; j < qc->n_vars; ++j) { @@ -290,9 +343,10 @@ kmeans_initial_centers (struct Kmeans *kmeans, const struct casereader *reader, } } else if (dist_from_case (kmeans, c, qc, mp) > min_dist_from (kmeans, qc, mq)) - /* If the distance between C and the second nearest group (MP) is greater than the - smallest distance between the nearest group (MQ) and any other group, then replace - MQ with C */ + /* If the distance between C and the second nearest group + (MP) is greater than the smallest distance between the + nearest group (MQ) and any other group, then replace + MQ with C. */ { for (j = 0; j < qc->n_vars; ++j) { @@ -316,7 +370,9 @@ kmeans_initial_centers (struct Kmeans *kmeans, const struct casereader *reader, /* Return the index of the group which is nearest to the case C */ static void -kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c, const struct qc *qc, int *g_q, double *delta_q, int *g_p, double *delta_p) +kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c, + const struct qc *qc, int *g_q, double *delta_q, + int *g_p, double *delta_p) { int result0 = -1; int result1 = -1; @@ -329,7 +385,7 @@ kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c, const st for (j = 0; j < qc->n_vars; j++) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val, qc->exclude)) continue; dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f); @@ -378,7 +434,8 @@ kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc) /* Main algorithm. Does iterations, checks convergency. */ static void -kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc) +kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, + const struct qc *qc) { int j; @@ -405,10 +462,10 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct q for (j = 0; j < qc->n_vars; j++) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val, qc->exclude)) missing = true; } - + if (missing) continue; @@ -425,18 +482,18 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct q } long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group); - *n += qc->wv ? case_data (c, qc->wv)->f : 1.0; + *n += qc->wv ? case_num (c, qc->wv) : 1.0; kmeans->n++; for (j = 0; j < qc->n_vars; ++j) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val, qc->exclude)) continue; double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j); - *x += val->f * (qc->wv ? case_data (c, qc->wv)->f : 1.0); + *x += val->f * (qc->wv ? case_num (c, qc->wv) : 1.0); } - } + } casereader_destroy (r); } @@ -453,27 +510,26 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct q *x /= n + 1; // Plus 1 for the initial centers } } - + gsl_matrix_memcpy (kmeans->centers, kmeans->updated_centers); { kmeans->n = 0; - int i; /* Step 3 */ gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0); gsl_matrix_set_all (kmeans->updated_centers, 0.0); struct ccase *c; struct casereader *cs = casereader_clone (reader); - for (; (c = casereader_read (cs)) != NULL; i++, case_unref (c)) + for (; (c = casereader_read (cs)) != NULL; case_unref (c)) { - int group = -1; + int group = -1; kmeans_get_nearest_group (kmeans, c, qc, &group, NULL, NULL, NULL); for (j = 0; j < qc->n_vars; ++j) { const union value *val = case_data (c, qc->vars[j]); - if ( var_is_value_missing (qc->vars[j], val, qc->exclude)) + if (var_is_value_missing (qc->vars[j], val, qc->exclude)) continue; double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j); @@ -481,10 +537,8 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct q } long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group); - *n += qc->wv ? case_data (c, qc->wv)->f : 1.0; + *n += qc->wv ? case_num (c, qc->wv) : 1.0; kmeans->n++; - - } casereader_destroy (cs); @@ -517,97 +571,167 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct q static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc) { - struct tab_table *t; - int nc, nr, currow; - int i, j; - nc = qc->ngroups + 1; - nr = qc->n_vars + 4; - t = tab_create (nc, nr); - tab_headers (t, 0, nc - 1, 0, 1); - currow = 0; - if (!initial) - { - tab_title (t, _("Final Cluster Centers")); - } - else - { - tab_title (t, _("Initial Cluster Centers")); - } - tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1); - tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster")); - tab_hline (t, TAL_1, 1, nc - 1, 2); - currow += 2; + struct pivot_table *table + = pivot_table_create (initial + ? N_("Initial Cluster Centers") + : N_("Final Cluster Centers")); + + struct pivot_dimension *clusters + = pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster")); + + clusters->root->show_label = true; + for (size_t i = 0; i < qc->ngroups; i++) + pivot_category_create_leaf (clusters->root, + pivot_value_new_integer (i + 1)); + + struct pivot_dimension *variables + = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Variable")); + + for (size_t i = 0; i < qc->n_vars; i++) + pivot_category_create_leaf (variables->root, + pivot_value_new_variable (qc->vars[i])); + + const gsl_matrix *matrix = (initial + ? kmeans->initial_centers + : kmeans->centers); + for (size_t i = 0; i < qc->ngroups; i++) + for (size_t j = 0; j < qc->n_vars; j++) + { + double x = gsl_matrix_get (matrix, kmeans->group_order->data[i], j); + union value v = { .f = x }; + pivot_table_put2 (table, i, j, + pivot_value_new_var_value (qc->vars[j], &v)); + } - for (i = 0; i < qc->ngroups; i++) - { - tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1)); - } - currow++; - tab_hline (t, TAL_1, 1, nc - 1, currow); - currow++; - for (i = 0; i < qc->n_vars; i++) - { - tab_text (t, 0, currow + i, TAB_LEFT, - var_to_string (qc->vars[i])); - } + pivot_table_submit (table); +} - for (i = 0; i < qc->ngroups; i++) - { - for (j = 0; j < qc->n_vars; j++) - { - if (!initial) - { - tab_double (t, i + 1, j + 4, TAB_CENTER, - gsl_matrix_get (kmeans->centers, - kmeans->group_order->data[i], j), - var_get_print_format (qc->vars[j]), RC_OTHER); - } - else - { - tab_double (t, i + 1, j + 4, TAB_CENTER, - gsl_matrix_get (kmeans->initial_centers, - kmeans->group_order->data[i], j), - var_get_print_format (qc->vars[j]), RC_OTHER); - } - } - } - tab_submit (t); + +/* A transformation function which juxtaposes the dataset with the + (pre-prepared) dataset containing membership and/or distance + values. */ +static enum trns_result +save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED) +{ + const struct save_trans_data *std = aux; + struct ccase *ca = casereader_read (std->appending_reader); + if (ca == NULL) + return TRNS_CONTINUE; + + *c = case_unshare (*c); + + if (std->CASE_IDX_MEMBERSHIP >= 0) + *case_num_rw (*c, std->membership) = case_num_idx (ca, std->CASE_IDX_MEMBERSHIP); + + if (std->CASE_IDX_DISTANCE >= 0) + *case_num_rw (*c, std->distance) = case_num_idx (ca, std->CASE_IDX_DISTANCE); + + case_unref (ca); + + return TRNS_CONTINUE; +} + + +/* Free the resources of the transformation. */ +static bool +save_trans_destroy (void *aux) +{ + struct save_trans_data *std = aux; + casereader_destroy (std->appending_reader); + free (std); + return true; } -/* Reports cluster membership for each case. */ + +/* Reports cluster membership for each case, and is requested +saves the membership and the distance of the case from the cluster +centre. */ static void -quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc) +quick_cluster_show_membership (struct Kmeans *kmeans, + const struct casereader *reader, + struct qc *qc) { - struct tab_table *t; - int nc, nr, i; + struct pivot_table *table = NULL; + struct pivot_dimension *cases = NULL; + if (qc->print_cluster_membership) + { + table = pivot_table_create (N_("Cluster Membership")); - struct ccase *c; - struct casereader *cs = casereader_clone (reader); - nc = 2; - nr = kmeans->n + 1; - t = tab_create (nc, nr); - tab_headers (t, 0, nc - 1, 0, 0); - tab_title (t, _("Cluster Membership")); - tab_text (t, 0, 0, TAB_CENTER, _("Case Number")); - tab_text (t, 1, 0, TAB_CENTER, _("Cluster")); - tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1); - tab_hline (t, TAL_1, 0, nc - 1, 1); + pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"), + N_("Cluster")); + + cases + = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number")); + + cases->root->show_label = true; + } gsl_permutation *ip = gsl_permutation_alloc (qc->ngroups); gsl_permutation_inverse (ip, kmeans->group_order); - for (i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c)) + struct caseproto *proto = caseproto_create (); + if (qc->save_values) + { + /* Prepare data which may potentially be used in a + transformation appending new variables to the active + dataset. */ + qc->save_trans_data = xzalloc (sizeof *qc->save_trans_data); + qc->save_trans_data->CASE_IDX_MEMBERSHIP = -1; + qc->save_trans_data->CASE_IDX_DISTANCE = -1; + qc->save_trans_data->writer = autopaging_writer_create (proto); + + int idx = 0; + if (qc->save_values & SAVE_MEMBERSHIP) + { + proto = caseproto_add_width (proto, 0); + qc->save_trans_data->CASE_IDX_MEMBERSHIP = idx++; + } + + if (qc->save_values & SAVE_DISTANCE) + { + proto = caseproto_add_width (proto, 0); + qc->save_trans_data->CASE_IDX_DISTANCE = idx++; + } + } + + struct casereader *cs = casereader_clone (reader); + struct ccase *c; + for (int i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c)) { - int clust = -1; assert (i < kmeans->n); + int clust; kmeans_get_nearest_group (kmeans, c, qc, &clust, NULL, NULL, NULL); - clust = ip->data[clust]; - tab_text_format (t, 0, i+1, TAB_CENTER, "%d", (i + 1)); - tab_text_format (t, 1, i+1, TAB_CENTER, "%d", (clust + 1)); + int cluster = ip->data[clust]; + + if (qc->save_trans_data) + { + /* Calculate the membership and distance values. */ + struct ccase *outc = case_create (proto); + if (qc->save_values & SAVE_MEMBERSHIP) + *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_MEMBERSHIP) = cluster + 1; + + if (qc->save_values & SAVE_DISTANCE) + *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_DISTANCE) + = sqrt (dist_from_case (kmeans, c, qc, clust)); + + casewriter_write (qc->save_trans_data->writer, outc); + } + + if (qc->print_cluster_membership) + { + /* Print the cluster membership to the table. */ + int case_idx = pivot_category_create_leaf (cases->root, + pivot_value_new_integer (i + 1)); + pivot_table_put2 (table, 0, case_idx, + pivot_value_new_integer (cluster + 1)); + } } + + caseproto_unref (proto); gsl_permutation_free (ip); - assert (i == kmeans->n); - tab_submit (t); + + if (qc->print_cluster_membership) + pivot_table_submit (table); casereader_destroy (cs); } @@ -616,65 +740,55 @@ quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *r static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc) { - struct tab_table *t; - int nc, nr; - int i, numelem; - long int total; - nc = 3; - nr = qc->ngroups + 1; - t = tab_create (nc, nr); - tab_headers (t, 0, nc - 1, 0, 0); - tab_title (t, _("Number of Cases in each Cluster")); - tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1); - tab_text (t, 0, 0, TAB_LEFT, _("Cluster")); - - total = 0; - for (i = 0; i < qc->ngroups; i++) + struct pivot_table *table + = pivot_table_create (N_("Number of Cases in each Cluster")); + + pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"), + N_("Count")); + + struct pivot_dimension *clusters + = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Clusters")); + + struct pivot_category *group + = pivot_category_create_group (clusters->root, N_("Cluster")); + + long int total = 0; + for (int i = 0; i < qc->ngroups; i++) { - tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1)); - numelem = - kmeans->num_elements_groups->data[kmeans->group_order->data[i]]; - tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem); - total += numelem; + int cluster_idx + = pivot_category_create_leaf (group, pivot_value_new_integer (i + 1)); + int count = kmeans->num_elements_groups->data [kmeans->group_order->data[i]]; + pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (count)); + total += count; } - tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid")); - tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total); - tab_submit (t); + int cluster_idx = pivot_category_create_leaf (clusters->root, + pivot_value_new_text (N_("Valid"))); + pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (total)); + pivot_table_submit (table); } /* Reports. */ static void -quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc) +quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, + struct qc *qc) { kmeans_order_groups (kmeans, qc); /* what does this do? */ - - if( qc->print_initial_clusters ) + + if (qc->print_initial_clusters) quick_cluster_show_centers (kmeans, true, qc); quick_cluster_show_centers (kmeans, false, qc); quick_cluster_show_number_cases (kmeans, qc); - if( qc->print_cluster_membership ) - quick_cluster_show_membership(kmeans, reader, qc); + + quick_cluster_show_membership (kmeans, reader, qc); } -int -cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) +/* Parse the QUICK CLUSTER command and populate QC accordingly. + Returns false on error. */ +static bool +quick_cluster_parse (struct lexer *lexer, struct qc *qc) { - struct qc qc; - struct Kmeans *kmeans; - bool ok; - const struct dictionary *dict = dataset_dict (ds); - qc.ngroups = 2; - qc.maxiter = 10; - qc.epsilon = DBL_EPSILON; - qc.missing_type = MISS_LISTWISE; - qc.exclude = MV_ANY; - qc.print_cluster_membership = false; /* default = do not output case cluster membership */ - qc.print_initial_clusters = false; /* default = do not print initial clusters */ - qc.no_initial = false; /* default = use well separated initial clusters */ - qc.no_update = false; /* default = iterate until convergence or max iterations */ - - if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars, + if (!parse_variables_const (lexer, qc->dict, &qc->vars, &qc->n_vars, PV_NO_DUPLICATE | PV_NUMERIC)) { return (CMD_FAILURE); @@ -690,28 +804,29 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH) { - if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT")) + if (lex_match_id (lexer, "LISTWISE") + || lex_match_id (lexer, "DEFAULT")) { - qc.missing_type = MISS_LISTWISE; + qc->missing_type = MISS_LISTWISE; } else if (lex_match_id (lexer, "PAIRWISE")) { - qc.missing_type = MISS_PAIRWISE; + qc->missing_type = MISS_PAIRWISE; } else if (lex_match_id (lexer, "INCLUDE")) { - qc.exclude = MV_SYSTEM; + qc->exclude = MV_SYSTEM; } else if (lex_match_id (lexer, "EXCLUDE")) { - qc.exclude = MV_ANY; + qc->exclude = MV_ANY; } else { lex_error (lexer, NULL); - goto error; + return false; } - } + } } else if (lex_match_id (lexer, "PRINT")) { @@ -720,13 +835,79 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) && lex_token (lexer) != T_SLASH) { if (lex_match_id (lexer, "CLUSTER")) - qc.print_cluster_membership = true; + qc->print_cluster_membership = true; else if (lex_match_id (lexer, "INITIAL")) - qc.print_initial_clusters = true; + qc->print_initial_clusters = true; else { lex_error (lexer, NULL); - goto error; + return false; + } + } + } + else if (lex_match_id (lexer, "SAVE")) + { + lex_match (lexer, T_EQUALS); + while (lex_token (lexer) != T_ENDCMD + && lex_token (lexer) != T_SLASH) + { + if (lex_match_id (lexer, "CLUSTER")) + { + qc->save_values |= SAVE_MEMBERSHIP; + if (lex_match (lexer, T_LPAREN)) + { + if (!lex_force_id (lexer)) + return false; + + free (qc->var_membership); + qc->var_membership = xstrdup (lex_tokcstr (lexer)); + if (NULL != dict_lookup_var (qc->dict, qc->var_membership)) + { + lex_error (lexer, + _("A variable called `%s' already exists."), + qc->var_membership); + free (qc->var_membership); + qc->var_membership = NULL; + return false; + } + + lex_get (lexer); + + if (!lex_force_match (lexer, T_RPAREN)) + return false; + } + } + else if (lex_match_id (lexer, "DISTANCE")) + { + qc->save_values |= SAVE_DISTANCE; + if (lex_match (lexer, T_LPAREN)) + { + if (!lex_force_id (lexer)) + return false; + + free (qc->var_distance); + qc->var_distance = xstrdup (lex_tokcstr (lexer)); + if (NULL != dict_lookup_var (qc->dict, qc->var_distance)) + { + lex_error (lexer, + _("A variable called `%s' already exists."), + qc->var_distance); + free (qc->var_distance); + qc->var_distance = NULL; + return false; + } + + lex_get (lexer); + + if (!lex_force_match (lexer, T_RPAREN)) + return false; + } + } + else + { + lex_error (lexer, _("Expecting %s or %s."), + "CLUSTER", "DISTANCE"); + return false; } } } @@ -739,17 +920,12 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) if (lex_match_id (lexer, "CLUSTERS")) { if (lex_force_match (lexer, T_LPAREN) && - lex_force_int (lexer)) + lex_force_int_range (lexer, "CLUSTERS", 1, INT_MAX)) { - qc.ngroups = lex_integer (lexer); - if (qc.ngroups <= 0) - { - lex_error (lexer, _("The number of clusters must be positive")); - goto error; - } + qc->ngroups = lex_integer (lexer); lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else if (lex_match_id (lexer, "CONVERGE")) @@ -757,66 +933,82 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) if (lex_force_match (lexer, T_LPAREN) && lex_force_num (lexer)) { - qc.epsilon = lex_number (lexer); - if (qc.epsilon <= 0) + qc->epsilon = lex_number (lexer); + if (qc->epsilon <= 0) { - lex_error (lexer, _("The convergence criterium must be positive")); - goto error; + lex_error (lexer, _("The convergence criterion must be positive")); + return false; } lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else if (lex_match_id (lexer, "MXITER")) { if (lex_force_match (lexer, T_LPAREN) && - lex_force_int (lexer)) + lex_force_int_range (lexer, "MXITER", 1, INT_MAX)) { - qc.maxiter = lex_integer (lexer); - if (qc.maxiter <= 0) - { - lex_error (lexer, _("The number of iterations must be positive")); - goto error; - } + qc->maxiter = lex_integer (lexer); lex_get (lexer); if (!lex_force_match (lexer, T_RPAREN)) - goto error; + return false; } } else if (lex_match_id (lexer, "NOINITIAL")) { - qc.no_initial = true; + qc->no_initial = true; } else if (lex_match_id (lexer, "NOUPDATE")) { - qc.no_update = true; + qc->no_update = true; } else { lex_error (lexer, NULL); - goto error; + return false; } } } else { lex_error (lexer, NULL); - goto error; + return false; } } + return true; +} + +int +cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) +{ + struct qc qc; + struct Kmeans *kmeans; + bool ok; + memset (&qc, 0, sizeof qc); + qc.dataset = ds; + qc.dict = dataset_dict (ds); + qc.ngroups = 2; + qc.maxiter = 10; + qc.epsilon = DBL_EPSILON; + qc.missing_type = MISS_LISTWISE; + qc.exclude = MV_ANY; + + + if (!quick_cluster_parse (lexer, &qc)) + goto error; - qc.wv = dict_get_weight (dict); + qc.wv = dict_get_weight (qc.dict); { struct casereader *group; - struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict); + struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), qc.dict); while (casegrouper_get_next_group (grouper, &group)) { - if ( qc.missing_type == MISS_LISTWISE ) + if (qc.missing_type == MISS_LISTWISE) { - group = casereader_create_filter_missing (group, qc.vars, qc.n_vars, + group = casereader_create_filter_missing (group, qc.vars, qc.n_vars, qc.exclude, NULL, NULL); } @@ -831,11 +1023,75 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) } ok = proc_commit (ds) && ok; - free (qc.vars); + /* If requested, set a transformation to append the cluster and + distance values to the current dataset. */ + if (qc.save_trans_data) + { + struct save_trans_data *std = qc.save_trans_data; + std->appending_reader = casewriter_make_reader (std->writer); + std->writer = NULL; + + if (qc.save_values & SAVE_MEMBERSHIP) + { + /* Invent a variable name if necessary. */ + int idx = 0; + struct string name; + ds_init_empty (&name); + while (qc.var_membership == NULL) + { + ds_clear (&name); + ds_put_format (&name, "QCL_%d", idx++); + + if (!dict_lookup_var (qc.dict, ds_cstr (&name))) + { + qc.var_membership = strdup (ds_cstr (&name)); + break; + } + } + ds_destroy (&name); + + std->membership = dict_create_var_assert (qc.dict, qc.var_membership, 0); + } + + if (qc.save_values & SAVE_DISTANCE) + { + /* Invent a variable name if necessary. */ + int idx = 0; + struct string name; + ds_init_empty (&name); + while (qc.var_distance == NULL) + { + ds_clear (&name); + ds_put_format (&name, "QCL_%d", idx++); + + if (!dict_lookup_var (qc.dict, ds_cstr (&name))) + { + qc.var_distance = strdup (ds_cstr (&name)); + break; + } + } + ds_destroy (&name); + + std->distance = dict_create_var_assert (qc.dict, qc.var_distance, 0); + } + + static const struct trns_class trns_class = { + .name = "QUICK CLUSTER", + .execute = save_trans_func, + .destroy = save_trans_destroy, + }; + add_transformation (qc.dataset, &trns_class, std); + } + + free (qc.var_distance); + free (qc.var_membership); + free (qc.vars); return (ok); error: + free (qc.var_distance); + free (qc.var_membership); free (qc.vars); return CMD_FAILURE; }