4f512b475661819ca1b08a6de551595fccf21bbf
[pspp] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011, 2012, 2015, 2019 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25
26 #include "data/case.h"
27 #include "data/casegrouper.h"
28 #include "data/casereader.h"
29 #include "data/casewriter.h"
30 #include "data/dataset.h"
31 #include "data/dictionary.h"
32 #include "data/format.h"
33 #include "data/missing-values.h"
34 #include "language/command.h"
35 #include "language/lexer/lexer.h"
36 #include "language/lexer/variable-parser.h"
37 #include "libpspp/message.h"
38 #include "libpspp/misc.h"
39 #include "libpspp/assertion.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/pivot-table.h"
43 #include "output/output-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 enum missing_type
50   {
51     MISS_LISTWISE,
52     MISS_PAIRWISE,
53   };
54
55
56 struct save_trans_data
57   {
58     /* A writer which contains the values (if any) to be appended to
59        each case in the active dataset   */
60     struct casewriter *writer;
61
62     /* A reader created from the writer above. */
63     struct casereader *appending_reader;
64
65     /* The indices to be used to access values in the above,
66        reader/writer  */
67     int membership_case_idx;
68     int distance_case_idx;
69
70     /* The variables created to hold the values appended to the dataset  */
71     struct variable *membership;
72     struct variable *distance;
73   };
74
75
76 struct qc
77   {
78     struct dataset *dataset;
79     struct dictionary *dict;
80
81     const struct variable **vars;
82     size_t n_vars;
83
84     double epsilon;               /* The convergence criterion */
85
86     int ngroups;                        /* Number of group. (Given by the user) */
87     int maxiter;                        /* Maximum iterations (Given by the user) */
88     bool print_cluster_membership; /* true => print membership */
89     bool print_initial_clusters;   /* true => print initial cluster */
90     bool initial;             /* false => simplified initial cluster selection */
91     bool update;               /* false => do not iterate  */
92
93     const struct variable *wv;  /* Weighting variable. */
94
95     enum missing_type missing_type;
96     enum mv_class exclude;
97
98     /* Which values are to be saved?  */
99     bool save_membership;
100     bool save_distance;
101
102     /* The name of the new variable to contain the cluster of each case.  */
103     char *var_membership;
104
105     /* The name of the new variable to contain the distance of each case
106        from its cluster centre.  */
107     char *var_distance;
108
109     struct save_trans_data *save_trans_data;
110   };
111
112 /* Holds all of the information for the functions.  int n, holds the number of
113    observation and its default value is -1.  We set it in
114    kmeans_recalculate_centers in first invocation. */
115 struct Kmeans
116   {
117     gsl_matrix *centers;                /* Centers for groups. */
118     gsl_matrix *updated_centers;
119     casenumber n;
120
121     gsl_vector_long *num_elements_groups;
122
123     gsl_matrix *initial_centers;        /* Initial random centers. */
124     double convergence_criteria;
125     gsl_permutation *group_order;       /* Group order for reporting. */
126   };
127
128 static struct Kmeans *kmeans_create (const struct qc *);
129
130 static void kmeans_get_nearest_group (const struct Kmeans *,
131                                       struct ccase *, const struct qc *,
132                                       int *, double *, int *, double *);
133
134 static void kmeans_order_groups (struct Kmeans *, const struct qc *);
135
136 static void kmeans_cluster (struct Kmeans *, struct casereader *,
137                             const struct qc *);
138
139 static void quick_cluster_show_centers (struct Kmeans *, bool initial,
140                                         const struct qc *);
141
142 static void quick_cluster_show_membership (struct Kmeans *,
143                                            const struct casereader *,
144                                            struct qc *);
145
146 static void quick_cluster_show_number_cases (struct Kmeans *,
147                                              const struct qc *);
148
149 static void quick_cluster_show_results (struct Kmeans *,
150                                         const struct casereader *,
151                                         struct qc *);
152
153 int cmd_quick_cluster (struct lexer *, struct dataset *);
154
155 static void kmeans_destroy (struct Kmeans *);
156
157 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
158    variables 'variables', number of cases 'n', number of variables 'm', number
159    of clusters and amount of maximum iterations. */
160 static struct Kmeans *
161 kmeans_create (const struct qc *qc)
162 {
163   struct Kmeans *kmeans = xmalloc (sizeof *kmeans);
164   *kmeans = (struct Kmeans) {
165     .centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars),
166     .updated_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars),
167     .num_elements_groups = gsl_vector_long_alloc (qc->ngroups),
168     .group_order = gsl_permutation_alloc (qc->ngroups),
169   };
170   return kmeans;
171 }
172
173 static void
174 kmeans_destroy (struct Kmeans *kmeans)
175 {
176   gsl_matrix_free (kmeans->centers);
177   gsl_matrix_free (kmeans->updated_centers);
178   gsl_matrix_free (kmeans->initial_centers);
179
180   gsl_vector_long_free (kmeans->num_elements_groups);
181
182   gsl_permutation_free (kmeans->group_order);
183
184   free (kmeans);
185 }
186
187 static double
188 diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2)
189 {
190   double max_diff = -INFINITY;
191   for (size_t i = 0; i < m1->size1; ++i)
192     {
193       double diff = 0;
194       for (size_t j = 0; j < m1->size2; ++j)
195         diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j));
196       if (diff > max_diff)
197         max_diff = diff;
198     }
199
200   return max_diff;
201 }
202
203
204
205 static double
206 matrix_mindist (const gsl_matrix *m, int *mn, int *mm)
207 {
208   double mindist = INFINITY;
209   for (size_t i = 0; i + 1 < m->size1; ++i)
210     for (size_t j = i + 1; j < m->size1; ++j)
211       {
212         double diff_sq = 0;
213         for (size_t k = 0; k < m->size2; ++k)
214           diff_sq += pow2 (gsl_matrix_get (m, j, k) - gsl_matrix_get (m, i, k));
215         if (diff_sq < mindist)
216           {
217             mindist = diff_sq;
218             if (mn)
219               *mn = i;
220             if (mm)
221               *mm = j;
222           }
223       }
224   return mindist;
225 }
226
227 /* Return the distance of C from the group whose index is WHICH */
228 static double
229 dist_from_case (const struct Kmeans *kmeans, const struct ccase *c,
230                 const struct qc *qc, int which)
231 {
232   double dist = 0;
233   for (size_t j = 0; j < qc->n_vars; j++)
234     {
235       const union value *val = case_data (c, qc->vars[j]);
236       assert (!(var_is_value_missing (qc->vars[j], val) & qc->exclude));
237       dist += pow2 (gsl_matrix_get (kmeans->centers, which, j) - val->f);
238     }
239
240   return dist;
241 }
242
243 /* Return the minimum distance of the group WHICH and all other groups */
244 static double
245 min_dist_from (const struct Kmeans *kmeans, const struct qc *qc, int which)
246 {
247    double mindist = INFINITY;
248   for (size_t i = 0; i < qc->ngroups; i++)
249     {
250       if (i == which)
251         continue;
252
253       double dist = 0;
254       for (size_t j = 0; j < qc->n_vars; j++)
255         dist += pow2 (gsl_matrix_get (kmeans->centers, i, j)
256                       - gsl_matrix_get (kmeans->centers, which, j));
257
258       if (dist < mindist)
259         mindist = dist;
260     }
261
262   return mindist;
263 }
264
265 /* Calculate the initial cluster centers. */
266 static void
267 kmeans_initial_centers (struct Kmeans *kmeans,
268                         const struct casereader *reader,
269                         const struct qc *qc)
270 {
271   int nc = 0;
272
273   struct casereader *cs = casereader_clone (reader);
274   struct ccase *c;
275   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
276     {
277       bool missing = false;
278       for (size_t j = 0; j < qc->n_vars; ++j)
279         {
280           const union value *val = case_data (c, qc->vars[j]);
281           if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
282             {
283               missing = true;
284               break;
285             }
286
287           if (nc < qc->ngroups)
288             gsl_matrix_set (kmeans->centers, nc, j, val->f);
289         }
290       if (missing)
291         continue;
292
293       if (nc++ < qc->ngroups)
294         continue;
295
296       if (qc->initial)
297         {
298           int mn, mm;
299           double m = matrix_mindist (kmeans->centers, &mn, &mm);
300
301           int mq, mp;
302           double delta;
303           kmeans_get_nearest_group (kmeans, c, qc, &mq, &delta, &mp, NULL);
304           if (delta > m)
305             /* If the distance between C and the nearest group, is greater than the distance
306                between the two  groups which are clostest to each
307                other, then one group must be replaced.  */
308             {
309               /* Out of mn and mm, which is the clostest of the two groups to C ? */
310               int which = (dist_from_case (kmeans, c, qc, mn)
311                            > dist_from_case (kmeans, c, qc, mm)) ? mm : mn;
312
313               for (size_t j = 0; j < qc->n_vars; ++j)
314                 {
315                   const union value *val = case_data (c, qc->vars[j]);
316                   gsl_matrix_set (kmeans->centers, which, j, val->f);
317                 }
318             }
319           else if (dist_from_case (kmeans, c, qc, mp) > min_dist_from (kmeans, qc, mq))
320             /* If the distance between C and the second nearest group
321                (MP) is greater than the smallest distance between the
322                nearest group (MQ) and any other group, then replace
323                MQ with C.  */
324             {
325               for (size_t j = 0; j < qc->n_vars; ++j)
326                 {
327                   const union value *val = case_data (c, qc->vars[j]);
328                   gsl_matrix_set (kmeans->centers, mq, j, val->f);
329                 }
330             }
331         }
332     }
333
334   casereader_destroy (cs);
335
336   kmeans->convergence_criteria = qc->epsilon * matrix_mindist (kmeans->centers, NULL, NULL);
337
338   /* As it is the first iteration, the variable kmeans->initial_centers is NULL
339      and it is created once for reporting issues. */
340   kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
341   gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
342 }
343
344 /* Return the index of the group which is nearest to the case C */
345 static void
346 kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c,
347                           const struct qc *qc, int *g_q, double *delta_q,
348                           int *g_p, double *delta_p)
349 {
350   int result0 = -1;
351   int result1 = -1;
352   double mindist0 = INFINITY;
353   double mindist1 = INFINITY;
354   for (size_t i = 0; i < qc->ngroups; i++)
355     {
356       double dist = 0;
357       for (size_t j = 0; j < qc->n_vars; j++)
358         {
359           const union value *val = case_data (c, qc->vars[j]);
360           if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
361             continue;
362
363           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
364         }
365
366       if (dist < mindist0)
367         {
368           mindist1 = mindist0;
369           result1 = result0;
370
371           mindist0 = dist;
372           result0 = i;
373         }
374       else if (dist < mindist1)
375         {
376           mindist1 = dist;
377           result1 = i;
378         }
379     }
380
381   if (delta_q)
382     *delta_q = mindist0;
383
384   if (g_q)
385     *g_q = result0;
386
387   if (delta_p)
388     *delta_p = mindist1;
389
390   if (g_p)
391     *g_p = result1;
392 }
393
394 static void
395 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
396 {
397   gsl_vector *v = gsl_vector_alloc (qc->ngroups);
398   gsl_matrix_get_col (v, kmeans->centers, 0);
399   gsl_sort_vector_index (kmeans->group_order, v);
400   gsl_vector_free (v);
401 }
402
403 /* Main algorithm.
404    Does iterations, checks convergency. */
405 static void
406 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
407                 const struct qc *qc)
408 {
409   kmeans_initial_centers (kmeans, reader, qc);
410
411   gsl_matrix_memcpy (kmeans->updated_centers, kmeans->centers);
412   for (int xx = 0; xx < qc->maxiter; ++xx)
413     {
414       gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
415
416       kmeans->n = 0;
417       if (qc->update)
418         {
419           struct casereader *r = casereader_clone (reader);
420           struct ccase *c;
421           for (; (c = casereader_read (r)) != NULL; case_unref (c))
422             {
423               bool missing = false;
424               for (size_t j = 0; j < qc->n_vars; j++)
425                 {
426                   const union value *val = case_data (c, qc->vars[j]);
427                   if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
428                     missing = true;
429                 }
430               if (missing)
431                 continue;
432
433               double mindist = INFINITY;
434               int group = -1;
435               for (size_t g = 0; g < qc->ngroups; ++g)
436                 {
437                   double d = dist_from_case (kmeans, c, qc, g);
438
439                   if (d < mindist)
440                     {
441                       mindist = d;
442                       group = g;
443                     }
444                 }
445
446               long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
447               *n += qc->wv ? case_num (c, qc->wv) : 1.0;
448               kmeans->n++;
449
450               for (size_t j = 0; j < qc->n_vars; ++j)
451                 {
452                   const union value *val = case_data (c, qc->vars[j]);
453                   if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
454                     continue;
455                   double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
456                   *x += val->f * (qc->wv ? case_num (c, qc->wv) : 1.0);
457                 }
458             }
459
460           casereader_destroy (r);
461         }
462
463       /* Divide the cluster sums by the number of items in each cluster */
464       for (size_t g = 0; g < qc->ngroups; ++g)
465         for (size_t j = 0; j < qc->n_vars; ++j)
466           {
467             long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
468             double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
469             *x /= n + 1;  // Plus 1 for the initial centers
470           }
471       gsl_matrix_memcpy (kmeans->centers, kmeans->updated_centers);
472
473       kmeans->n = 0;
474       /* Step 3 */
475       gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
476       gsl_matrix_set_all (kmeans->updated_centers, 0.0);
477       struct ccase *c;
478       struct casereader *cs = casereader_clone (reader);
479       for (; (c = casereader_read (cs)) != NULL; case_unref (c))
480         {
481           int group = -1;
482           kmeans_get_nearest_group (kmeans, c, qc, &group, NULL, NULL, NULL);
483
484           for (size_t j = 0; j < qc->n_vars; ++j)
485             {
486               const union value *val = case_data (c, qc->vars[j]);
487               if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
488                 continue;
489
490               double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
491               *x += val->f;
492             }
493
494           long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
495           *n += qc->wv ? case_num (c, qc->wv) : 1.0;
496           kmeans->n++;
497         }
498       casereader_destroy (cs);
499
500       /* Divide the cluster sums by the number of items in each cluster */
501       for (size_t g = 0; g < qc->ngroups; ++g)
502         for (size_t j = 0; j < qc->n_vars; ++j)
503           {
504             long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
505             double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
506             *x /= n;
507           }
508
509       double d = diff_matrix (kmeans->updated_centers, kmeans->centers);
510       if (d < kmeans->convergence_criteria)
511         break;
512
513       if (!qc->update)
514         break;
515     }
516 }
517
518 /* Reports centers of clusters.
519    Initial parameter is optional for future use.
520    If initial is true, initial cluster centers are reported.  Otherwise,
521    resulted centers are reported. */
522 static void
523 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
524 {
525   struct pivot_table *table
526     = pivot_table_create (initial
527                           ? N_("Initial Cluster Centers")
528                           : N_("Final Cluster Centers"));
529
530   struct pivot_dimension *clusters
531     = pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"));
532
533   clusters->root->show_label = true;
534   for (size_t i = 0; i < qc->ngroups; i++)
535     pivot_category_create_leaf (clusters->root,
536                                 pivot_value_new_integer (i + 1));
537
538   struct pivot_dimension *variables
539     = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Variable"));
540
541   for (size_t i = 0; i < qc->n_vars; i++)
542     pivot_category_create_leaf (variables->root,
543                                 pivot_value_new_variable (qc->vars[i]));
544
545   const gsl_matrix *matrix = (initial
546                               ? kmeans->initial_centers
547                               : kmeans->centers);
548   for (size_t i = 0; i < qc->ngroups; i++)
549     for (size_t j = 0; j < qc->n_vars; j++)
550       {
551         double x = gsl_matrix_get (matrix, kmeans->group_order->data[i], j);
552         union value v = { .f = x };
553         pivot_table_put2 (table, i, j,
554                           pivot_value_new_var_value (qc->vars[j], &v));
555       }
556
557   pivot_table_submit (table);
558 }
559
560
561 /* A transformation function which juxtaposes the dataset with the
562    (pre-prepared) dataset containing membership and/or distance
563    values.  */
564 static enum trns_result
565 save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED)
566 {
567   const struct save_trans_data *std = aux;
568   struct ccase *ca  = casereader_read (std->appending_reader);
569   if (ca == NULL)
570     return TRNS_CONTINUE;
571
572   *c = case_unshare (*c);
573
574   if (std->membership_case_idx >= 0)
575     *case_num_rw (*c, std->membership) = case_num_idx (ca, std->membership_case_idx);
576
577   if (std->distance_case_idx >= 0)
578     *case_num_rw (*c, std->distance) = case_num_idx (ca, std->distance_case_idx);
579
580   case_unref (ca);
581
582   return TRNS_CONTINUE;
583 }
584
585 /* Free the resources of the transformation.  */
586 static bool
587 save_trans_destroy (void *aux)
588 {
589   struct save_trans_data *std = aux;
590   casereader_destroy (std->appending_reader);
591   free (std);
592   return true;
593 }
594
595 /* Reports cluster membership for each case, and is requested saves the
596    membership and the distance of the case from the cluster centre.  */
597 static void
598 quick_cluster_show_membership (struct Kmeans *kmeans,
599                                const struct casereader *reader,
600                                struct qc *qc)
601 {
602   struct pivot_table *table = NULL;
603   struct pivot_dimension *cases = NULL;
604   if (qc->print_cluster_membership)
605     {
606       table = pivot_table_create (N_("Cluster Membership"));
607
608       pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"),
609                               N_("Cluster"));
610
611       cases
612         = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number"));
613
614       cases->root->show_label = true;
615     }
616
617   gsl_permutation *ip = gsl_permutation_alloc (qc->ngroups);
618   gsl_permutation_inverse (ip, kmeans->group_order);
619
620   struct caseproto *proto = caseproto_create ();
621   if (qc->save_membership || qc->save_distance)
622     {
623       /* Prepare data which may potentially be used in a
624          transformation appending new variables to the active
625          dataset.  */
626       int idx = 0;
627       int membership_case_idx = -1;
628       if (qc->save_membership)
629         {
630           proto = caseproto_add_width (proto, 0);
631           membership_case_idx = idx++;
632         }
633
634       int distance_case_idx = -1;
635       if (qc->save_distance)
636         {
637           proto = caseproto_add_width (proto, 0);
638           distance_case_idx = idx++;
639         }
640
641       qc->save_trans_data = xmalloc (sizeof *qc->save_trans_data);
642       *qc->save_trans_data = (struct save_trans_data) {
643         .membership_case_idx = membership_case_idx,
644         .distance_case_idx = distance_case_idx,
645         .writer = autopaging_writer_create (proto),
646       };
647     }
648
649   struct casereader *cs = casereader_clone (reader);
650   struct ccase *c;
651   for (int i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
652     {
653       assert (i < kmeans->n);
654       int clust;
655       kmeans_get_nearest_group (kmeans, c, qc, &clust, NULL, NULL, NULL);
656       int cluster = ip->data[clust];
657
658       if (qc->save_trans_data)
659         {
660           /* Calculate the membership and distance values.  */
661           struct ccase *outc = case_create (proto);
662           if (qc->save_membership)
663             *case_num_rw_idx (outc, qc->save_trans_data->membership_case_idx) = cluster + 1;
664
665           if (qc->save_distance)
666             *case_num_rw_idx (outc, qc->save_trans_data->distance_case_idx)
667               = sqrt (dist_from_case (kmeans, c, qc, clust));
668
669           casewriter_write (qc->save_trans_data->writer, outc);
670         }
671
672       if (qc->print_cluster_membership)
673         {
674           /* Print the cluster membership to the table.  */
675           int case_idx = pivot_category_create_leaf (cases->root,
676                                                  pivot_value_new_integer (i + 1));
677           pivot_table_put2 (table, 0, case_idx,
678                             pivot_value_new_integer (cluster + 1));
679         }
680     }
681
682   caseproto_unref (proto);
683   gsl_permutation_free (ip);
684
685   if (qc->print_cluster_membership)
686     pivot_table_submit (table);
687   casereader_destroy (cs);
688 }
689
690
691 /* Reports number of cases of each single cluster. */
692 static void
693 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
694 {
695   struct pivot_table *table
696     = pivot_table_create (N_("Number of Cases in each Cluster"));
697
698   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
699                           N_("Count"));
700
701   struct pivot_dimension *clusters
702     = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Clusters"));
703
704   struct pivot_category *group
705     = pivot_category_create_group (clusters->root, N_("Cluster"));
706
707   long int total = 0;
708   for (int i = 0; i < qc->ngroups; i++)
709     {
710       int cluster_idx
711         = pivot_category_create_leaf (group, pivot_value_new_integer (i + 1));
712       int count = kmeans->num_elements_groups->data [kmeans->group_order->data[i]];
713       pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (count));
714       total += count;
715     }
716
717   int cluster_idx = pivot_category_create_leaf (clusters->root,
718                                                 pivot_value_new_text (N_("Valid")));
719   pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (total));
720   pivot_table_submit (table);
721 }
722
723 /* Reports. */
724 static void
725 quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader,
726                             struct qc *qc)
727 {
728   kmeans_order_groups (kmeans, qc); /* what does this do? */
729
730   if (qc->print_initial_clusters)
731     quick_cluster_show_centers (kmeans, true, qc);
732   quick_cluster_show_centers (kmeans, false, qc);
733   quick_cluster_show_number_cases (kmeans, qc);
734
735   quick_cluster_show_membership (kmeans, reader, qc);
736 }
737
738 /* Parse the QUICK CLUSTER command and populate QC accordingly.
739    Returns false on error.  */
740 static bool
741 quick_cluster_parse (struct lexer *lexer, struct qc *qc)
742 {
743   if (!parse_variables_const (lexer, qc->dict, &qc->vars, &qc->n_vars,
744                               PV_NO_DUPLICATE | PV_NUMERIC))
745     return false;
746
747   while (lex_token (lexer) != T_ENDCMD)
748     {
749       lex_match (lexer, T_SLASH);
750
751       if (lex_match_id (lexer, "MISSING"))
752         {
753           lex_match (lexer, T_EQUALS);
754           while (lex_token (lexer) != T_ENDCMD
755                  && lex_token (lexer) != T_SLASH)
756             {
757               if (lex_match_id (lexer, "LISTWISE")
758                   || lex_match_id (lexer, "DEFAULT"))
759                 qc->missing_type = MISS_LISTWISE;
760               else if (lex_match_id (lexer, "PAIRWISE"))
761                 qc->missing_type = MISS_PAIRWISE;
762               else if (lex_match_id (lexer, "INCLUDE"))
763                 qc->exclude = MV_SYSTEM;
764               else if (lex_match_id (lexer, "EXCLUDE"))
765                 qc->exclude = MV_ANY;
766               else
767                 {
768                   lex_error_expecting (lexer, "LISTWISE", "DEFAULT",
769                                        "PAIRWISE", "INCLUDE", "EXCLUDE");
770                   return false;
771                 }
772             }
773         }
774       else if (lex_match_id (lexer, "PRINT"))
775         {
776           lex_match (lexer, T_EQUALS);
777           while (lex_token (lexer) != T_ENDCMD
778                  && lex_token (lexer) != T_SLASH)
779             {
780               if (lex_match_id (lexer, "CLUSTER"))
781                 qc->print_cluster_membership = true;
782               else if (lex_match_id (lexer, "INITIAL"))
783                 qc->print_initial_clusters = true;
784               else
785                 {
786                   lex_error_expecting (lexer, "CLUSTER", "INITIAL");
787                   return false;
788                 }
789             }
790         }
791       else if (lex_match_id (lexer, "SAVE"))
792         {
793           lex_match (lexer, T_EQUALS);
794           while (lex_token (lexer) != T_ENDCMD
795                  && lex_token (lexer) != T_SLASH)
796             {
797               if (lex_match_id (lexer, "CLUSTER"))
798                 {
799                   qc->save_membership = true;
800                   if (lex_match (lexer, T_LPAREN))
801                     {
802                       if (!lex_force_id (lexer))
803                         return false;
804
805                       free (qc->var_membership);
806                       qc->var_membership = xstrdup (lex_tokcstr (lexer));
807                       if (NULL != dict_lookup_var (qc->dict, qc->var_membership))
808                         {
809                           lex_error (lexer,
810                                      _("A variable called `%s' already exists."),
811                                      qc->var_membership);
812                           free (qc->var_membership);
813                           qc->var_membership = NULL;
814                           return false;
815                         }
816
817                       lex_get (lexer);
818
819                       if (!lex_force_match (lexer, T_RPAREN))
820                         return false;
821                     }
822                 }
823               else if (lex_match_id (lexer, "DISTANCE"))
824                 {
825                   qc->save_distance = true;
826                   if (lex_match (lexer, T_LPAREN))
827                     {
828                       if (!lex_force_id (lexer))
829                         return false;
830
831                       free (qc->var_distance);
832                       qc->var_distance = xstrdup (lex_tokcstr (lexer));
833                       if (NULL != dict_lookup_var (qc->dict, qc->var_distance))
834                         {
835                           lex_error (lexer,
836                                      _("A variable called `%s' already exists."),
837                                      qc->var_distance);
838                           free (qc->var_distance);
839                           qc->var_distance = NULL;
840                           return false;
841                         }
842
843                       lex_get (lexer);
844
845                       if (!lex_force_match (lexer, T_RPAREN))
846                         return false;
847                     }
848                 }
849               else
850                 {
851                   lex_error_expecting (lexer, "CLUSTER", "DISTANCE");
852                   return false;
853                 }
854             }
855         }
856       else if (lex_match_id (lexer, "CRITERIA"))
857         {
858           lex_match (lexer, T_EQUALS);
859           while (lex_token (lexer) != T_ENDCMD
860                  && lex_token (lexer) != T_SLASH)
861             {
862               if (lex_match_id (lexer, "CLUSTERS"))
863                 {
864                   if (!lex_force_match (lexer, T_LPAREN)
865                       || !lex_force_int_range (lexer, "CLUSTERS", 1, INT_MAX))
866                     return false;
867                   qc->ngroups = lex_integer (lexer);
868                   lex_get (lexer);
869                   if (!lex_force_match (lexer, T_RPAREN))
870                     return false;
871                 }
872               else if (lex_match_id (lexer, "CONVERGE"))
873                 {
874                   if (!lex_force_match (lexer, T_LPAREN)
875                       || !lex_force_num_range_open (lexer, "CONVERGE",
876                                                     0, DBL_MAX))
877                     return false;
878                   qc->epsilon = lex_number (lexer);
879                   lex_get (lexer);
880                   if (!lex_force_match (lexer, T_RPAREN))
881                     return false;
882                 }
883               else if (lex_match_id (lexer, "MXITER"))
884                 {
885                   if (!lex_force_match (lexer, T_LPAREN)
886                       || !lex_force_int_range (lexer, "MXITER", 1, INT_MAX))
887                     return false;
888                   qc->maxiter = lex_integer (lexer);
889                   lex_get (lexer);
890                   if (!lex_force_match (lexer, T_RPAREN))
891                     return false;
892                 }
893               else if (lex_match_id (lexer, "NOINITIAL"))
894                 qc->initial = false;
895               else if (lex_match_id (lexer, "NOUPDATE"))
896                 qc->update = false;
897               else
898                 {
899                   lex_error_expecting (lexer, "CLUSTERS", "CONVERGE", "MXITER",
900                                        "NOINITIAL", "NOUPDATE");
901                   return false;
902                 }
903             }
904         }
905       else
906         {
907           lex_error_expecting (lexer, "MISSING", "PRINT", "SAVE", "CRITERIA");
908           return false;
909         }
910     }
911   return true;
912 }
913
914 int
915 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
916 {
917   struct qc qc = {
918     .dataset = ds,
919     .dict = dataset_dict (ds),
920     .ngroups = 2,
921     .maxiter = 10,
922     .epsilon = DBL_EPSILON,
923     .missing_type = MISS_LISTWISE,
924     .exclude = MV_ANY,
925     .initial = true,
926     .update = true,
927   };
928
929   if (!quick_cluster_parse (lexer, &qc))
930     goto error;
931
932   qc.wv = dict_get_weight (qc.dict);
933
934   struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), qc.dict);
935   struct casereader *group;
936   while (casegrouper_get_next_group (grouper, &group))
937     {
938       if (qc.missing_type == MISS_LISTWISE)
939         group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
940                                                   qc.exclude, NULL, NULL);
941
942       struct Kmeans *kmeans = kmeans_create (&qc);
943       kmeans_cluster (kmeans, group, &qc);
944       quick_cluster_show_results (kmeans, group, &qc);
945       kmeans_destroy (kmeans);
946       casereader_destroy (group);
947     }
948   bool ok = casegrouper_destroy (grouper);
949   ok = proc_commit (ds) && ok;
950
951   /* If requested, set a transformation to append the cluster and
952      distance values to the current dataset.  */
953   if (qc.save_trans_data)
954     {
955       struct save_trans_data *std = qc.save_trans_data;
956
957       std->appending_reader = casewriter_make_reader (std->writer);
958
959       if (qc.save_membership)
960         {
961           /* Invent a variable name if necessary.  */
962           int idx = 0;
963           struct string name;
964           ds_init_empty (&name);
965           while (qc.var_membership == NULL)
966             {
967               ds_clear (&name);
968               ds_put_format (&name, "QCL_%d", idx++);
969
970               if (!dict_lookup_var (qc.dict, ds_cstr (&name)))
971                 {
972                   qc.var_membership = strdup (ds_cstr (&name));
973                   break;
974                 }
975             }
976           ds_destroy (&name);
977
978           std->membership = dict_create_var_assert (qc.dict, qc.var_membership, 0);
979         }
980
981       if (qc.save_distance)
982         {
983           /* Invent a variable name if necessary.  */
984           int idx = 0;
985           struct string name;
986           ds_init_empty (&name);
987           while (qc.var_distance == NULL)
988             {
989               ds_clear (&name);
990               ds_put_format (&name, "QCL_%d", idx++);
991
992               if (!dict_lookup_var (qc.dict, ds_cstr (&name)))
993                 {
994                   qc.var_distance = strdup (ds_cstr (&name));
995                   break;
996                 }
997             }
998           ds_destroy (&name);
999
1000           std->distance = dict_create_var_assert (qc.dict, qc.var_distance, 0);
1001         }
1002
1003       static const struct trns_class trns_class = {
1004         .name = "QUICK CLUSTER",
1005         .execute = save_trans_func,
1006         .destroy = save_trans_destroy,
1007       };
1008       add_transformation (qc.dataset, &trns_class, std);
1009     }
1010
1011   free (qc.var_distance);
1012   free (qc.var_membership);
1013   free (qc.vars);
1014   return ok;
1015
1016  error:
1017   free (qc.var_distance);
1018   free (qc.var_membership);
1019   free (qc.vars);
1020   return CMD_FAILURE;
1021 }