From eee5d5538428d0d74c757552f74a3af0d2c4d0f5 Mon Sep 17 00:00:00 2001 From: John Darrington Date: Tue, 7 May 2019 10:07:05 +0200 Subject: [PATCH] QUICK CLUSTER: Implement the /SAVE sub-command. --- NEWS | 6 +- doc/statistics.texi | 7 + src/language/stats/quick-cluster.c | 305 +++++++++++++++++++++++--- tests/language/stats/quick-cluster.at | 138 ++++++++++++ 4 files changed, 428 insertions(+), 28 deletions(-) diff --git a/NEWS b/NEWS index 5d51210d42..00bf0c4e50 100644 --- a/NEWS +++ b/NEWS @@ -21,9 +21,13 @@ Changes from 1.2.0 to 1.3.0: * The EXAMINE command will now perform the Shapiro-Wilk test when one or more plots are requested. -* The REGRESSION command now supports the /STATISTICS=TOL which + * The REGRESSION command now supports the /STATISTICS=TOL option which outputs tolerance and variance inflation factor metrics for the data. + * The QUICK CLUSTER command now supports the /SAVE option which can + be used to save the cases' cluster membership and/or their distance + from the cluster centre to the active file. + * A bug where the GUI would crash when T-TEST was executed whilst a filter was set has been fixed. diff --git a/doc/statistics.texi b/doc/statistics.texi index 259c9abe53..51cbb95516 100644 --- a/doc/statistics.texi +++ b/doc/statistics.texi @@ -1819,6 +1819,7 @@ QUICK CLUSTER @var{var_list} [/CRITERIA=CLUSTERS(@var{k}) [MXITER(@var{max_iter})] CONVERGE(@var{epsilon}) [NOINITIAL]] [/MISSING=@{EXCLUDE,INCLUDE@} @{LISTWISE, PAIRWISE@}] [/PRINT=@{INITIAL@} @{CLUSTER@}] + [/SAVE[=[CLUSTER[(@var{membership_var})]] [DISTANCE[(@var{distance_var})]]] @end display The @cmd{QUICK CLUSTER} command performs k-means clustering on the @@ -1871,6 +1872,12 @@ be printed. If @subcmd{CLUSTER} is set, the cluster memberships of the individual cases will be displayed (potentially generating lengthy output). +You can specify the subcommand @subcmd{SAVE} to ask that each case's cluster membership +and the euclidean distance between the case and its cluster center be saved to +a new variable in the active dataset. To save the cluster membership use the +@subcmd{CLUSTER} keyword and to save the distance use the @subcmd{DISTANCE} keyword. +Each keyword may optionally be followed by a variable in parenthesis to specify +the new variable which is to contain the saved parameter. @node RANK @section RANK diff --git a/src/language/stats/quick-cluster.c b/src/language/stats/quick-cluster.c index 42a4639704..d20d0f3f1e 100644 --- a/src/language/stats/quick-cluster.c +++ b/src/language/stats/quick-cluster.c @@ -1,5 +1,5 @@ /* PSPP - a program for statistical analysis. - Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc. + Copyright (C) 2011, 2012, 2015, 2019 Free Software Foundation, Inc. This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -53,8 +53,34 @@ enum missing_type }; +struct save_trans_data +{ + /* A writer which contains the values (if any) to be appended to + each case in the active dataset */ + struct casewriter *writer; + + /* A reader created from the writer above. */ + struct casereader *appending_reader; + + /* The indices to be used to access values in the above, + reader/writer */ + int CASE_IDX_MEMBERSHIP; + int CASE_IDX_DISTANCE; + + /* The variables created to hold the values appended to the dataset */ + struct variable *membership; + struct variable *distance; +}; + + +#define SAVE_MEMBERSHIP 0x1 +#define SAVE_DISTANCE 0x2 + struct qc { + struct dataset *dataset; + struct dictionary *dict; + const struct variable **vars; size_t n_vars; @@ -71,6 +97,18 @@ struct qc enum missing_type missing_type; enum mv_class exclude; + + /* Which values are to be saved? */ + int save_values; + + /* The name of the new variable to contain the cluster of each case. */ + char *var_membership; + + /* The name of the new variable to contain the distance of each case + from its cluster centre. */ + char *var_distance; + + struct save_trans_data *save_trans_data; }; /* Holds all of the information for the functions. int n, holds the number of @@ -105,14 +143,14 @@ static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, static void quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, - const struct qc *); + struct qc *); static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *); static void quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, - const struct qc *); + struct qc *); int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds); @@ -568,25 +606,94 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc pivot_table_submit (table); } -/* Reports cluster membership for each case. */ + +/* A transformation function which juxtaposes the dataset with the + (pre-prepared) dataset containing membership and/or distance + values. */ +static int +save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED) +{ + const struct save_trans_data *std = aux; + struct ccase *ca = casereader_read (std->appending_reader); + if (ca == NULL) + return TRNS_CONTINUE; + + *c = case_unshare (*c); + + if (std->CASE_IDX_MEMBERSHIP >= 0) + case_data_rw (*c, std->membership)->f = case_data_idx (ca, std->CASE_IDX_MEMBERSHIP)->f; + + if (std->CASE_IDX_DISTANCE >= 0) + case_data_rw (*c, std->distance)->f = case_data_idx (ca, std->CASE_IDX_DISTANCE)->f; + + case_unref (ca); + + return TRNS_CONTINUE; +} + + +/* Free the resources of the transformation. */ +static bool +save_trans_destroy (void *aux) +{ + struct save_trans_data *std = aux; + casereader_destroy (std->appending_reader); + free (std); + return true; +} + + +/* Reports cluster membership for each case, and is requested +saves the membership and the distance of the case from the cluster +centre. */ static void quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, - const struct qc *qc) + struct qc *qc) { - struct pivot_table *table = pivot_table_create (N_("Cluster Membership")); + struct pivot_table *table; + struct pivot_dimension *cases; + if (qc->print_cluster_membership) + { + table = pivot_table_create (N_("Cluster Membership")); - pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"), - N_("Cluster")); + pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"), + N_("Cluster")); - struct pivot_dimension *cases - = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number")); + cases + = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number")); - cases->root->show_label = true; + cases->root->show_label = true; + } gsl_permutation *ip = gsl_permutation_alloc (qc->ngroups); gsl_permutation_inverse (ip, kmeans->group_order); + struct caseproto *proto = caseproto_create (); + if (qc->save_values) + { + /* Prepare data which may potentially be used in a + transformation appending new variables to the active + dataset. */ + qc->save_trans_data = xzalloc (sizeof *qc->save_trans_data); + qc->save_trans_data->CASE_IDX_MEMBERSHIP = -1; + qc->save_trans_data->CASE_IDX_DISTANCE = -1; + qc->save_trans_data->writer = autopaging_writer_create (proto); + + int idx = 0; + if (qc->save_values & SAVE_MEMBERSHIP) + { + proto = caseproto_add_width (proto, 0); + qc->save_trans_data->CASE_IDX_MEMBERSHIP = idx++; + } + + if (qc->save_values & SAVE_DISTANCE) + { + proto = caseproto_add_width (proto, 0); + qc->save_trans_data->CASE_IDX_DISTANCE = idx++; + } + } + struct casereader *cs = casereader_clone (reader); struct ccase *c; for (int i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c)) @@ -596,14 +703,35 @@ quick_cluster_show_membership (struct Kmeans *kmeans, kmeans_get_nearest_group (kmeans, c, qc, &clust, NULL, NULL, NULL); int cluster = ip->data[clust]; - int case_idx = pivot_category_create_leaf (cases->root, + if (qc->save_trans_data) + { + /* Calculate the membership and distance values. */ + struct ccase *outc = case_create (proto); + if (qc->save_values & SAVE_MEMBERSHIP) + case_data_rw_idx (outc, qc->save_trans_data->CASE_IDX_MEMBERSHIP)->f = cluster + 1; + + if (qc->save_values & SAVE_DISTANCE) + case_data_rw_idx (outc, qc->save_trans_data->CASE_IDX_DISTANCE)->f + = sqrt (dist_from_case (kmeans, c, qc, clust)); + + casewriter_write (qc->save_trans_data->writer, outc); + } + + if (qc->print_cluster_membership) + { + /* Print the cluster membership to the table. */ + int case_idx = pivot_category_create_leaf (cases->root, pivot_value_new_integer (i + 1)); - pivot_table_put2 (table, 0, case_idx, - pivot_value_new_integer (cluster + 1)); + pivot_table_put2 (table, 0, case_idx, + pivot_value_new_integer (cluster + 1)); + } } + caseproto_unref (proto); gsl_permutation_free (ip); - pivot_table_submit (table); + + if (qc->print_cluster_membership) + pivot_table_submit (table); casereader_destroy (cs); } @@ -643,7 +771,7 @@ quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc) /* Reports. */ static void quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, - const struct qc *qc) + struct qc *qc) { kmeans_order_groups (kmeans, qc); /* what does this do? */ @@ -651,8 +779,8 @@ quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *read quick_cluster_show_centers (kmeans, true, qc); quick_cluster_show_centers (kmeans, false, qc); quick_cluster_show_number_cases (kmeans, qc); - if (qc->print_cluster_membership) - quick_cluster_show_membership (kmeans, reader, qc); + + quick_cluster_show_membership (kmeans, reader, qc); } int @@ -661,18 +789,16 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) struct qc qc; struct Kmeans *kmeans; bool ok; - const struct dictionary *dict = dataset_dict (ds); + memset (&qc, 0, sizeof qc); + qc.dataset = ds; + qc.dict = dataset_dict (ds); qc.ngroups = 2; qc.maxiter = 10; qc.epsilon = DBL_EPSILON; qc.missing_type = MISS_LISTWISE; qc.exclude = MV_ANY; - qc.print_cluster_membership = false; /* default = do not output case cluster membership */ - qc.print_initial_clusters = false; /* default = do not print initial clusters */ - qc.no_initial = false; /* default = use well separated initial clusters */ - qc.no_update = false; /* default = iterate until convergence or max iterations */ - if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars, + if (!parse_variables_const (lexer, qc.dict, &qc.vars, &qc.n_vars, PV_NO_DUPLICATE | PV_NUMERIC)) { return (CMD_FAILURE); @@ -729,6 +855,72 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) } } } + else if (lex_match_id (lexer, "SAVE")) + { + lex_match (lexer, T_EQUALS); + while (lex_token (lexer) != T_ENDCMD + && lex_token (lexer) != T_SLASH) + { + if (lex_match_id (lexer, "CLUSTER")) + { + qc.save_values |= SAVE_MEMBERSHIP; + if (lex_match (lexer, T_LPAREN)) + { + if (!lex_force_id (lexer)) + goto error; + + free (qc.var_membership); + qc.var_membership = xstrdup (lex_tokcstr (lexer)); + if (NULL != dict_lookup_var (dataset_dict (ds), qc.var_membership)) + { + lex_error (lexer, + _("A variable called `%s' already exists."), + qc.var_membership); + free (qc.var_membership); + qc.var_membership = NULL; + goto error; + } + + lex_get (lexer); + + if (!lex_force_match (lexer, T_RPAREN)) + goto error; + } + } + else if (lex_match_id (lexer, "DISTANCE")) + { + qc.save_values |= SAVE_DISTANCE; + if (lex_match (lexer, T_LPAREN)) + { + if (!lex_force_id (lexer)) + goto error; + + free (qc.var_distance); + qc.var_distance = xstrdup (lex_tokcstr (lexer)); + if (NULL != dict_lookup_var (dataset_dict (ds), qc.var_distance)) + { + lex_error (lexer, + _("A variable called `%s' already exists."), + qc.var_distance); + free (qc.var_distance); + qc.var_distance = NULL; + goto error; + } + + lex_get (lexer); + + if (!lex_force_match (lexer, T_RPAREN)) + goto error; + } + } + else + { + lex_error (lexer, _("Expecting %s or %s."), + "CLUSTER", "DISTANCE"); + goto error; + } + } + } else if (lex_match_id (lexer, "CRITERIA")) { lex_match (lexer, T_EQUALS); @@ -805,11 +997,11 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) } } - qc.wv = dict_get_weight (dict); + qc.wv = dict_get_weight (qc.dict); { struct casereader *group; - struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict); + struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), qc.dict); while (casegrouper_get_next_group (grouper, &group)) { @@ -830,11 +1022,70 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds) } ok = proc_commit (ds) && ok; - free (qc.vars); + /* If requested, set a transformation to append the cluster and + distance values to the current dataset. */ + if (qc.save_trans_data) + { + struct save_trans_data *std = qc.save_trans_data; + std->appending_reader = casewriter_make_reader (std->writer); + std->writer = NULL; + + if (qc.save_values & SAVE_MEMBERSHIP) + { + /* Invent a variable name if necessary. */ + int idx = 0; + struct string name; + ds_init_empty (&name); + while (qc.var_membership == NULL) + { + ds_clear (&name); + ds_put_format (&name, "QCL_%d", idx++); + + if (!dict_lookup_var (qc.dict, ds_cstr (&name))) + { + qc.var_distance = strdup (ds_cstr (&name)); + break; + } + } + ds_destroy (&name); + + std->membership = dict_create_var_assert (qc.dict, qc.var_membership, 0); + } + + if (qc.save_values & SAVE_DISTANCE) + { + /* Invent a variable name if necessary. */ + int idx = 0; + struct string name; + ds_init_empty (&name); + while (qc.var_distance == NULL) + { + ds_clear (&name); + ds_put_format (&name, "QCL_%d", idx++); + + if (!dict_lookup_var (qc.dict, ds_cstr (&name))) + { + qc.var_distance = strdup (ds_cstr (&name)); + break; + } + } + ds_destroy (&name); + + std->distance = dict_create_var_assert (qc.dict, qc.var_distance, 0); + } + + add_transformation (qc.dataset, save_trans_func, save_trans_destroy, std); + } + + free (qc.var_distance); + free (qc.var_membership); + free (qc.vars); return (ok); error: + free (qc.var_distance); + free (qc.var_membership); free (qc.vars); return CMD_FAILURE; } diff --git a/tests/language/stats/quick-cluster.at b/tests/language/stats/quick-cluster.at index 34c04945e9..34294468c6 100644 --- a/tests/language/stats/quick-cluster.at +++ b/tests/language/stats/quick-cluster.at @@ -447,3 +447,141 @@ AT_CHECK([pspp -o pspp.csv empty-parens.sps], [1], [ignore]) AT_CLEANUP + + +AT_SETUP([QUICK CLUSTER with save]) +AT_DATA([quick-cluster.sps], [dnl +DATA LIST notable LIST /x y z. +BEGIN DATA. +22,2930,4099 +17,3350,4749 +22,2640,3799 +20, 3250,4816 +15,4080,7827 +4,5,4 +5,6,5 +6,7,6 +7,8,7 +8,9,8 +9,10,9 +END DATA. +QUICK CLUSTER x y z + /CRITERIA=CLUSTER(2) MXITER(20) + /SAVE = CLUSTER (cluster) DISTANCE (distance). + +list. +]) + +AT_CHECK([pspp -O format=csv quick-cluster.sps], [0], [dnl +Table: Final Cluster Centers +,Cluster, +,1,2 +x,6.50,19.20 +y,7.50,3250.00 +z,6.50,5058.00 + +Table: Number of Cases in each Cluster +,,Count +Cluster,1,6 +,2,5 +Valid,,11 + +Table: Data List +x,y,z,cluster,distance +22.00,2930.00,4099.00,2.00,1010.98 +17.00,3350.00,4749.00,2.00,324.79 +22.00,2640.00,3799.00,2.00,1399.00 +20.00,3250.00,4816.00,2.00,242.00 +15.00,4080.00,7827.00,2.00,2890.72 +4.00,5.00,4.00,1.00,4.33 +5.00,6.00,5.00,1.00,2.60 +6.00,7.00,6.00,1.00,.87 +7.00,8.00,7.00,1.00,.87 +8.00,9.00,8.00,1.00,2.60 +9.00,10.00,9.00,1.00,4.33 +]) +AT_CLEANUP + + +AT_SETUP([QUICK CLUSTER with single save]) +AT_DATA([quick-cluster.sps], [dnl +DATA LIST notable LIST /x y z. +BEGIN DATA. +22,2930,4099 +17,3350,4749 +22,2640,3799 +20, 3250,4816 +15,4080,7827 +4,5,4 +5,6,5 +6,7,6 +7,8,7 +8,9,8 +9,10,9 +END DATA. +QUICK CLUSTER x y z + /CRITERIA=CLUSTER(2) MXITER(20) + /SAVE = DISTANCE. + +list. +]) + +AT_CHECK([pspp -O format=csv quick-cluster.sps], [0], [dnl +Table: Final Cluster Centers +,Cluster, +,1,2 +x,6.50,19.20 +y,7.50,3250.00 +z,6.50,5058.00 + +Table: Number of Cases in each Cluster +,,Count +Cluster,1,6 +,2,5 +Valid,,11 + +Table: Data List +x,y,z,QCL_0 +22.00,2930.00,4099.00,1010.98 +17.00,3350.00,4749.00,324.79 +22.00,2640.00,3799.00,1399.00 +20.00,3250.00,4816.00,242.00 +15.00,4080.00,7827.00,2890.72 +4.00,5.00,4.00,4.33 +5.00,6.00,5.00,2.60 +6.00,7.00,6.00,.87 +7.00,8.00,7.00,.87 +8.00,9.00,8.00,2.60 +9.00,10.00,9.00,4.33 +]) +AT_CLEANUP + + +dnl This one was noticed to crash at one point. +AT_SETUP([QUICK CLUSTER crash on bizarre input]) +AT_DATA([badn.sps], [dnl +data list notable list /x. +begin da\a* +22 +17 +22 +22 +15 +4, +5, +6, +7,j8, +9, +end data. + +quick cluster x +" /criteria=cluster(2) mxiter(20) + /save = distance + . + +list. +]) + +AT_CHECK([pspp -O format=csv badn.sps], [1], [ignore]) + +AT_CLEANUP -- 2.30.2