QUICK CLUSTER: Implement the /SAVE sub-command.
authorJohn Darrington <john@darrington.wattle.id.au>
Tue, 7 May 2019 08:07:05 +0000 (10:07 +0200)
committerJohn Darrington <john@darrington.wattle.id.au>
Tue, 7 May 2019 18:16:08 +0000 (20:16 +0200)
NEWS
doc/statistics.texi
src/language/stats/quick-cluster.c
tests/language/stats/quick-cluster.at

diff --git a/NEWS b/NEWS
index 5d51210d421af740558b3331a484a95c56be191f..00bf0c4e50df65add0f5612e64096d6c7c7933e3 100644 (file)
--- a/NEWS
+++ b/NEWS
@@ -21,9 +21,13 @@ Changes from 1.2.0 to 1.3.0:
  * The EXAMINE command will now perform the Shapiro-Wilk test when
    one or more plots are requested.
 
-* The REGRESSION command now supports the /STATISTICS=TOL which
+ * The REGRESSION command now supports the /STATISTICS=TOL option which
    outputs tolerance and variance inflation factor metrics for the data.
 
+ * The QUICK CLUSTER command now supports the /SAVE option which can
+   be used to save the cases' cluster membership and/or their distance
+   from the cluster centre to the active file.
+
  * A bug where the GUI would crash when T-TEST was executed whilst
    a filter was set has been fixed.
 
index 259c9abe53e1626449f251668d25258394879724..51cbb95516b80b59f7458086df54ee980714e92c 100644 (file)
@@ -1819,6 +1819,7 @@ QUICK CLUSTER @var{var_list}
       [/CRITERIA=CLUSTERS(@var{k}) [MXITER(@var{max_iter})] CONVERGE(@var{epsilon}) [NOINITIAL]]
       [/MISSING=@{EXCLUDE,INCLUDE@} @{LISTWISE, PAIRWISE@}]
       [/PRINT=@{INITIAL@} @{CLUSTER@}]
+      [/SAVE[=[CLUSTER[(@var{membership_var})]] [DISTANCE[(@var{distance_var})]]]
 @end display
 
 The @cmd{QUICK CLUSTER} command performs k-means clustering on the
@@ -1871,6 +1872,12 @@ be printed.
 If @subcmd{CLUSTER} is set, the cluster memberships of the individual
 cases will be displayed (potentially generating lengthy output).
 
+You can specify the subcommand @subcmd{SAVE} to ask that each case's cluster membership
+and the euclidean distance between the case and its cluster center be saved to
+a new variable in the active dataset.   To save the cluster membership use the
+@subcmd{CLUSTER} keyword and to save the distance use the @subcmd{DISTANCE} keyword.
+Each keyword may optionally be followed by a variable in parenthesis to specify
+the new variable which is to contain the saved parameter.
 
 @node RANK
 @section RANK
index 42a463970484e17123ddb6fb8a74c190ff3bb51b..d20d0f3f1e430b4073126016d14b92e6eb3c7893 100644 (file)
@@ -1,5 +1,5 @@
 /* PSPP - a program for statistical analysis.
-   Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc.
+   Copyright (C) 2011, 2012, 2015, 2019 Free Software Foundation, Inc.
 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
@@ -53,8 +53,34 @@ enum missing_type
   };
 
 
+struct save_trans_data
+{
+  /* A writer which contains the values (if any) to be appended to
+     each case in the active dataset   */
+  struct casewriter *writer;
+
+  /* A reader created from the writer above. */
+  struct casereader *appending_reader;
+
+  /* The indices to be used to access values in the above,
+     reader/writer  */
+  int CASE_IDX_MEMBERSHIP;
+  int CASE_IDX_DISTANCE;
+
+  /* The variables created to hold the values appended to the dataset  */
+  struct variable *membership;
+  struct variable *distance;
+};
+
+
+#define SAVE_MEMBERSHIP 0x1
+#define SAVE_DISTANCE   0x2
+
 struct qc
 {
+  struct dataset *dataset;
+  struct dictionary *dict;
+
   const struct variable **vars;
   size_t n_vars;
 
@@ -71,6 +97,18 @@ struct qc
 
   enum missing_type missing_type;
   enum mv_class exclude;
+
+  /* Which values are to be saved?  */
+  int save_values;
+
+  /* The name of the new variable to contain the cluster of each case.  */
+  char *var_membership;
+
+  /* The name of the new variable to contain the distance of each case
+     from its cluster centre.  */
+  char *var_distance;
+
+  struct save_trans_data *save_trans_data;
 };
 
 /* Holds all of the information for the functions.  int n, holds the number of
@@ -105,14 +143,14 @@ static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial,
 
 static void quick_cluster_show_membership (struct Kmeans *kmeans,
                                           const struct casereader *reader,
-                                          const struct qc *);
+                                          struct qc *);
 
 static void quick_cluster_show_number_cases (struct Kmeans *kmeans,
                                             const struct qc *);
 
 static void quick_cluster_show_results (struct Kmeans *kmeans,
                                        const struct casereader *reader,
-                                       const struct qc *);
+                                       struct qc *);
 
 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
 
@@ -568,25 +606,94 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc
   pivot_table_submit (table);
 }
 
-/* Reports cluster membership for each case. */
+
+/* A transformation function which juxtaposes the dataset with the
+   (pre-prepared) dataset containing membership and/or distance
+   values.  */
+static int
+save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED)
+{
+  const struct save_trans_data *std = aux;
+  struct ccase *ca  = casereader_read (std->appending_reader);
+  if (ca == NULL)
+    return TRNS_CONTINUE;
+
+  *c = case_unshare (*c);
+
+  if (std->CASE_IDX_MEMBERSHIP >= 0)
+    case_data_rw (*c, std->membership)->f = case_data_idx (ca, std->CASE_IDX_MEMBERSHIP)->f;
+
+  if (std->CASE_IDX_DISTANCE >= 0)
+    case_data_rw (*c, std->distance)->f = case_data_idx (ca, std->CASE_IDX_DISTANCE)->f;
+
+  case_unref (ca);
+
+  return TRNS_CONTINUE;
+}
+
+
+/* Free the resources of the transformation.  */
+static bool
+save_trans_destroy (void *aux)
+{
+  struct save_trans_data *std = aux;
+  casereader_destroy (std->appending_reader);
+  free (std);
+  return true;
+}
+
+
+/* Reports cluster membership for each case, and is requested
+saves the membership and the distance of the case from the cluster
+centre.  */
 static void
 quick_cluster_show_membership (struct Kmeans *kmeans,
                               const struct casereader *reader,
-                              const struct qc *qc)
+                              struct qc *qc)
 {
-  struct pivot_table *table = pivot_table_create (N_("Cluster Membership"));
+  struct pivot_table *table;
+  struct pivot_dimension *cases;
+  if (qc->print_cluster_membership)
+    {
+      table = pivot_table_create (N_("Cluster Membership"));
 
-  pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"),
-                          N_("Cluster"));
+      pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"),
+                             N_("Cluster"));
 
-  struct pivot_dimension *cases
-    = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number"));
+      cases
+       = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number"));
 
-  cases->root->show_label = true;
+      cases->root->show_label = true;
+    }
 
   gsl_permutation *ip = gsl_permutation_alloc (qc->ngroups);
   gsl_permutation_inverse (ip, kmeans->group_order);
 
+  struct caseproto *proto = caseproto_create ();
+  if (qc->save_values)
+    {
+      /* Prepare data which may potentially be used in a
+        transformation appending new variables to the active
+        dataset.  */
+      qc->save_trans_data = xzalloc (sizeof *qc->save_trans_data);
+      qc->save_trans_data->CASE_IDX_MEMBERSHIP = -1;
+      qc->save_trans_data->CASE_IDX_DISTANCE = -1;
+      qc->save_trans_data->writer = autopaging_writer_create (proto);
+
+      int idx = 0;
+      if (qc->save_values & SAVE_MEMBERSHIP)
+       {
+         proto = caseproto_add_width (proto, 0);
+         qc->save_trans_data->CASE_IDX_MEMBERSHIP = idx++;
+       }
+
+      if (qc->save_values & SAVE_DISTANCE)
+       {
+         proto = caseproto_add_width (proto, 0);
+         qc->save_trans_data->CASE_IDX_DISTANCE = idx++;
+       }
+    }
+
   struct casereader *cs = casereader_clone (reader);
   struct ccase *c;
   for (int i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
@@ -596,14 +703,35 @@ quick_cluster_show_membership (struct Kmeans *kmeans,
       kmeans_get_nearest_group (kmeans, c, qc, &clust, NULL, NULL, NULL);
       int cluster = ip->data[clust];
 
-      int case_idx = pivot_category_create_leaf (cases->root,
+      if (qc->save_trans_data)
+      {
+       /* Calculate the membership and distance values.  */
+       struct ccase *outc = case_create (proto);
+       if (qc->save_values & SAVE_MEMBERSHIP)
+         case_data_rw_idx (outc, qc->save_trans_data->CASE_IDX_MEMBERSHIP)->f = cluster + 1;
+
+       if (qc->save_values & SAVE_DISTANCE)
+         case_data_rw_idx (outc, qc->save_trans_data->CASE_IDX_DISTANCE)->f
+           = sqrt (dist_from_case (kmeans, c, qc, clust));
+
+       casewriter_write (qc->save_trans_data->writer, outc);
+      }
+
+      if (qc->print_cluster_membership)
+       {
+         /* Print the cluster membership to the table.  */
+         int case_idx = pivot_category_create_leaf (cases->root,
                                                 pivot_value_new_integer (i + 1));
-      pivot_table_put2 (table, 0, case_idx,
-                        pivot_value_new_integer (cluster + 1));
+         pivot_table_put2 (table, 0, case_idx,
+                           pivot_value_new_integer (cluster + 1));
+       }
     }
 
+  caseproto_unref (proto);
   gsl_permutation_free (ip);
-  pivot_table_submit (table);
+
+  if (qc->print_cluster_membership)
+    pivot_table_submit (table);
   casereader_destroy (cs);
 }
 
@@ -643,7 +771,7 @@ quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
 /* Reports. */
 static void
 quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader,
-                           const struct qc *qc)
+                           struct qc *qc)
 {
   kmeans_order_groups (kmeans, qc); /* what does this do? */
 
@@ -651,8 +779,8 @@ quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *read
     quick_cluster_show_centers (kmeans, true, qc);
   quick_cluster_show_centers (kmeans, false, qc);
   quick_cluster_show_number_cases (kmeans, qc);
-  if (qc->print_cluster_membership)
-    quick_cluster_show_membership (kmeans, reader, qc);
+
+  quick_cluster_show_membership (kmeans, reader, qc);
 }
 
 int
@@ -661,18 +789,16 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
   struct qc qc;
   struct Kmeans *kmeans;
   bool ok;
-  const struct dictionary *dict = dataset_dict (ds);
+  memset (&qc, 0, sizeof qc);
+  qc.dataset = ds;
+  qc.dict =  dataset_dict (ds);
   qc.ngroups = 2;
   qc.maxiter = 10;
   qc.epsilon = DBL_EPSILON;
   qc.missing_type = MISS_LISTWISE;
   qc.exclude = MV_ANY;
-  qc.print_cluster_membership = false; /* default = do not output case cluster membership */
-  qc.print_initial_clusters = false;   /* default = do not print initial clusters */
-  qc.no_initial = false;               /* default = use well separated initial clusters */
-  qc.no_update = false;               /* default = iterate until convergence or max iterations */
 
-  if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
+  if (!parse_variables_const (lexer, qc.dict, &qc.vars, &qc.n_vars,
                              PV_NO_DUPLICATE | PV_NUMERIC))
     {
       return (CMD_FAILURE);
@@ -729,6 +855,72 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
                }
            }
        }
+      else if (lex_match_id (lexer, "SAVE"))
+       {
+         lex_match (lexer, T_EQUALS);
+         while (lex_token (lexer) != T_ENDCMD
+                && lex_token (lexer) != T_SLASH)
+           {
+             if (lex_match_id (lexer, "CLUSTER"))
+               {
+                 qc.save_values |= SAVE_MEMBERSHIP;
+                 if (lex_match (lexer, T_LPAREN))
+                   {
+                     if (!lex_force_id (lexer))
+                       goto error;
+
+                     free (qc.var_membership);
+                     qc.var_membership = xstrdup (lex_tokcstr (lexer));
+                     if (NULL != dict_lookup_var (dataset_dict (ds), qc.var_membership))
+                       {
+                         lex_error (lexer,
+                                    _("A variable called `%s' already exists."),
+                                    qc.var_membership);
+                         free (qc.var_membership);
+                         qc.var_membership = NULL;
+                         goto error;
+                       }
+
+                     lex_get (lexer);
+
+                     if (!lex_force_match (lexer, T_RPAREN))
+                       goto error;
+                   }
+               }
+             else if (lex_match_id (lexer, "DISTANCE"))
+               {
+                 qc.save_values |= SAVE_DISTANCE;
+                 if (lex_match (lexer, T_LPAREN))
+                   {
+                     if (!lex_force_id (lexer))
+                       goto error;
+
+                     free (qc.var_distance);
+                     qc.var_distance = xstrdup (lex_tokcstr (lexer));
+                     if (NULL != dict_lookup_var (dataset_dict (ds), qc.var_distance))
+                       {
+                         lex_error (lexer,
+                                    _("A variable called `%s' already exists."),
+                                    qc.var_distance);
+                         free (qc.var_distance);
+                         qc.var_distance = NULL;
+                         goto error;
+                       }
+
+                     lex_get (lexer);
+
+                     if (!lex_force_match (lexer, T_RPAREN))
+                       goto error;
+                   }
+               }
+             else
+               {
+                 lex_error (lexer, _("Expecting %s or %s."),
+                            "CLUSTER", "DISTANCE");
+                 goto error;
+               }
+           }
+       }
       else if (lex_match_id (lexer, "CRITERIA"))
        {
          lex_match (lexer, T_EQUALS);
@@ -805,11 +997,11 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
         }
     }
 
-  qc.wv = dict_get_weight (dict);
+  qc.wv = dict_get_weight (qc.dict);
 
   {
     struct casereader *group;
-    struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
+    struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), qc.dict);
 
     while (casegrouper_get_next_group (grouper, &group))
       {
@@ -830,11 +1022,70 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
   }
   ok = proc_commit (ds) && ok;
 
-  free (qc.vars);
 
+  /* If requested, set a transformation to append the cluster and
+     distance values to the current dataset.  */
+  if (qc.save_trans_data)
+    {
+      struct save_trans_data *std = qc.save_trans_data;
+      std->appending_reader = casewriter_make_reader (std->writer);
+      std->writer = NULL;
+
+      if (qc.save_values & SAVE_MEMBERSHIP)
+       {
+         /* Invent a variable name if necessary.  */
+         int idx = 0;
+         struct string name;
+         ds_init_empty (&name);
+         while (qc.var_membership == NULL)
+           {
+             ds_clear (&name);
+             ds_put_format (&name, "QCL_%d", idx++);
+
+             if (!dict_lookup_var (qc.dict, ds_cstr (&name)))
+               {
+                 qc.var_distance = strdup (ds_cstr (&name));
+                 break;
+               }
+           }
+         ds_destroy (&name);
+
+         std->membership = dict_create_var_assert (qc.dict, qc.var_membership, 0);
+       }
+
+      if (qc.save_values & SAVE_DISTANCE)
+       {
+         /* Invent a variable name if necessary.  */
+         int idx = 0;
+         struct string name;
+         ds_init_empty (&name);
+         while (qc.var_distance == NULL)
+           {
+             ds_clear (&name);
+             ds_put_format (&name, "QCL_%d", idx++);
+
+             if (!dict_lookup_var (qc.dict, ds_cstr (&name)))
+               {
+                 qc.var_distance = strdup (ds_cstr (&name));
+                 break;
+               }
+           }
+         ds_destroy (&name);
+
+         std->distance = dict_create_var_assert (qc.dict, qc.var_distance, 0);
+       }
+
+      add_transformation (qc.dataset, save_trans_func, save_trans_destroy, std);
+    }
+
+  free (qc.var_distance);
+  free (qc.var_membership);
+  free (qc.vars);
   return (ok);
 
  error:
+  free (qc.var_distance);
+  free (qc.var_membership);
   free (qc.vars);
   return CMD_FAILURE;
 }
index 34c04945e98150bbb3bc01b74b5ff4cf476e91d6..34294468c62679269e7e3cdea9fc7440774938ce 100644 (file)
@@ -447,3 +447,141 @@ AT_CHECK([pspp -o pspp.csv empty-parens.sps], [1], [ignore])
 
 AT_CLEANUP
 
+
+
+AT_SETUP([QUICK CLUSTER with save])
+AT_DATA([quick-cluster.sps], [dnl
+DATA LIST notable LIST /x y z.
+BEGIN DATA.
+22,2930,4099
+17,3350,4749
+22,2640,3799
+20, 3250,4816
+15,4080,7827
+4,5,4
+5,6,5
+6,7,6
+7,8,7
+8,9,8
+9,10,9
+END DATA.
+QUICK CLUSTER x y z
+  /CRITERIA=CLUSTER(2) MXITER(20)
+  /SAVE = CLUSTER (cluster) DISTANCE (distance).
+
+list.
+])
+
+AT_CHECK([pspp -O format=csv quick-cluster.sps], [0], [dnl
+Table: Final Cluster Centers
+,Cluster,
+,1,2
+x,6.50,19.20
+y,7.50,3250.00
+z,6.50,5058.00
+
+Table: Number of Cases in each Cluster
+,,Count
+Cluster,1,6
+,2,5
+Valid,,11
+
+Table: Data List
+x,y,z,cluster,distance
+22.00,2930.00,4099.00,2.00,1010.98
+17.00,3350.00,4749.00,2.00,324.79
+22.00,2640.00,3799.00,2.00,1399.00
+20.00,3250.00,4816.00,2.00,242.00
+15.00,4080.00,7827.00,2.00,2890.72
+4.00,5.00,4.00,1.00,4.33
+5.00,6.00,5.00,1.00,2.60
+6.00,7.00,6.00,1.00,.87
+7.00,8.00,7.00,1.00,.87
+8.00,9.00,8.00,1.00,2.60
+9.00,10.00,9.00,1.00,4.33
+])
+AT_CLEANUP
+
+
+AT_SETUP([QUICK CLUSTER with single save])
+AT_DATA([quick-cluster.sps], [dnl
+DATA LIST notable LIST /x y z.
+BEGIN DATA.
+22,2930,4099
+17,3350,4749
+22,2640,3799
+20, 3250,4816
+15,4080,7827
+4,5,4
+5,6,5
+6,7,6
+7,8,7
+8,9,8
+9,10,9
+END DATA.
+QUICK CLUSTER x y z
+  /CRITERIA=CLUSTER(2) MXITER(20)
+  /SAVE = DISTANCE.
+
+list.
+])
+
+AT_CHECK([pspp -O format=csv quick-cluster.sps], [0], [dnl
+Table: Final Cluster Centers
+,Cluster,
+,1,2
+x,6.50,19.20
+y,7.50,3250.00
+z,6.50,5058.00
+
+Table: Number of Cases in each Cluster
+,,Count
+Cluster,1,6
+,2,5
+Valid,,11
+
+Table: Data List
+x,y,z,QCL_0
+22.00,2930.00,4099.00,1010.98
+17.00,3350.00,4749.00,324.79
+22.00,2640.00,3799.00,1399.00
+20.00,3250.00,4816.00,242.00
+15.00,4080.00,7827.00,2890.72
+4.00,5.00,4.00,4.33
+5.00,6.00,5.00,2.60
+6.00,7.00,6.00,.87
+7.00,8.00,7.00,.87
+8.00,9.00,8.00,2.60
+9.00,10.00,9.00,4.33
+])
+AT_CLEANUP
+
+
+dnl This one was noticed to crash at one point.
+AT_SETUP([QUICK CLUSTER crash on bizarre input])
+AT_DATA([badn.sps], [dnl
+data list notable list /x.
+begin da\a*
+22
+17
+22
+22
+15
+4,
+5,
+6,
+7,j8,
+9,
+end data.
+
+quick cluster x
+" /criteria=cluster(2) mxiter(20)
+  /save = distance 
+  .
+
+list.
+])
+
+AT_CHECK([pspp -O format=csv badn.sps], [1], [ignore])
+
+AT_CLEANUP