Clean up how transformations work.
[pspp] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011, 2012, 2015, 2019 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25
26 #include "data/case.h"
27 #include "data/casegrouper.h"
28 #include "data/casereader.h"
29 #include "data/casewriter.h"
30 #include "data/dataset.h"
31 #include "data/dictionary.h"
32 #include "data/format.h"
33 #include "data/missing-values.h"
34 #include "language/command.h"
35 #include "language/lexer/lexer.h"
36 #include "language/lexer/variable-parser.h"
37 #include "libpspp/message.h"
38 #include "libpspp/misc.h"
39 #include "libpspp/assertion.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/pivot-table.h"
43 #include "output/output-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 enum missing_type
50   {
51     MISS_LISTWISE,
52     MISS_PAIRWISE,
53   };
54
55
56 struct save_trans_data
57 {
58   /* A writer which contains the values (if any) to be appended to
59      each case in the active dataset   */
60   struct casewriter *writer;
61
62   /* A reader created from the writer above. */
63   struct casereader *appending_reader;
64
65   /* The indices to be used to access values in the above,
66      reader/writer  */
67   int CASE_IDX_MEMBERSHIP;
68   int CASE_IDX_DISTANCE;
69
70   /* The variables created to hold the values appended to the dataset  */
71   struct variable *membership;
72   struct variable *distance;
73 };
74
75
76 #define SAVE_MEMBERSHIP 0x1
77 #define SAVE_DISTANCE   0x2
78
79 struct qc
80 {
81   struct dataset *dataset;
82   struct dictionary *dict;
83
84   const struct variable **vars;
85   size_t n_vars;
86
87   double epsilon;               /* The convergence criterion */
88
89   int ngroups;                  /* Number of group. (Given by the user) */
90   int maxiter;                  /* Maximum iterations (Given by the user) */
91   bool print_cluster_membership; /* true => print membership */
92   bool print_initial_clusters;   /* true => print initial cluster */
93   bool no_initial;              /* true => simplified initial cluster selection */
94   bool no_update;               /* true => do not iterate  */
95
96   const struct variable *wv;    /* Weighting variable. */
97
98   enum missing_type missing_type;
99   enum mv_class exclude;
100
101   /* Which values are to be saved?  */
102   int save_values;
103
104   /* The name of the new variable to contain the cluster of each case.  */
105   char *var_membership;
106
107   /* The name of the new variable to contain the distance of each case
108      from its cluster centre.  */
109   char *var_distance;
110
111   struct save_trans_data *save_trans_data;
112 };
113
114 /* Holds all of the information for the functions.  int n, holds the number of
115    observation and its default value is -1.  We set it in
116    kmeans_recalculate_centers in first invocation. */
117 struct Kmeans
118 {
119   gsl_matrix *centers;          /* Centers for groups. */
120   gsl_matrix *updated_centers;
121   casenumber n;
122
123   gsl_vector_long *num_elements_groups;
124
125   gsl_matrix *initial_centers;  /* Initial random centers. */
126   double convergence_criteria;
127   gsl_permutation *group_order; /* Group order for reporting. */
128 };
129
130 static struct Kmeans *kmeans_create (const struct qc *qc);
131
132 static void kmeans_get_nearest_group (const struct Kmeans *kmeans,
133                                       struct ccase *c, const struct qc *,
134                                       int *, double *, int *, double *);
135
136 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
137
138 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
139                             const struct qc *);
140
141 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial,
142                                         const struct qc *);
143
144 static void quick_cluster_show_membership (struct Kmeans *kmeans,
145                                            const struct casereader *reader,
146                                            struct qc *);
147
148 static void quick_cluster_show_number_cases (struct Kmeans *kmeans,
149                                              const struct qc *);
150
151 static void quick_cluster_show_results (struct Kmeans *kmeans,
152                                         const struct casereader *reader,
153                                         struct qc *);
154
155 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
156
157 static void kmeans_destroy (struct Kmeans *kmeans);
158
159 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
160    variables 'variables', number of cases 'n', number of variables 'm', number
161    of clusters and amount of maximum iterations. */
162 static struct Kmeans *
163 kmeans_create (const struct qc *qc)
164 {
165   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
166   kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
167   kmeans->updated_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
168   kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
169   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
170   kmeans->initial_centers = NULL;
171
172   return (kmeans);
173 }
174
175 static void
176 kmeans_destroy (struct Kmeans *kmeans)
177 {
178   gsl_matrix_free (kmeans->centers);
179   gsl_matrix_free (kmeans->updated_centers);
180   gsl_matrix_free (kmeans->initial_centers);
181
182   gsl_vector_long_free (kmeans->num_elements_groups);
183
184   gsl_permutation_free (kmeans->group_order);
185
186   free (kmeans);
187 }
188
189 static double
190 diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2)
191 {
192   int i,j;
193   double max_diff = -INFINITY;
194   for (i = 0; i < m1->size1; ++i)
195     {
196       double diff = 0;
197       for (j = 0; j < m1->size2; ++j)
198         {
199           diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j));
200         }
201       if (diff > max_diff)
202         max_diff = diff;
203     }
204
205   return max_diff;
206 }
207
208
209
210 static double
211 matrix_mindist (const gsl_matrix *m, int *mn, int *mm)
212 {
213   int i, j;
214   double mindist = INFINITY;
215   for (i = 0; i < m->size1 - 1; ++i)
216     {
217       for (j = i + 1; j < m->size1; ++j)
218         {
219           int k;
220           double diff_sq = 0;
221           for (k = 0; k < m->size2; ++k)
222             {
223               diff_sq += pow2 (gsl_matrix_get (m, j, k) - gsl_matrix_get (m, i, k));
224             }
225           if (diff_sq < mindist)
226             {
227               mindist = diff_sq;
228               if (mn)
229                 *mn = i;
230               if (mm)
231                 *mm = j;
232             }
233         }
234     }
235
236   return mindist;
237 }
238
239
240 /* Return the distance of C from the group whose index is WHICH */
241 static double
242 dist_from_case (const struct Kmeans *kmeans, const struct ccase *c,
243                 const struct qc *qc, int which)
244 {
245   int j;
246   double dist = 0;
247   for (j = 0; j < qc->n_vars; j++)
248     {
249       const union value *val = case_data (c, qc->vars[j]);
250       if (var_is_value_missing (qc->vars[j], val, qc->exclude))
251         NOT_REACHED ();
252
253       dist += pow2 (gsl_matrix_get (kmeans->centers, which, j) - val->f);
254     }
255
256   return dist;
257 }
258
259 /* Return the minimum distance of the group WHICH and all other groups */
260 static double
261 min_dist_from (const struct Kmeans *kmeans, const struct qc *qc, int which)
262 {
263   int j, i;
264
265   double mindist = INFINITY;
266   for (i = 0; i < qc->ngroups; i++)
267     {
268       if (i == which)
269         continue;
270
271       double dist = 0;
272       for (j = 0; j < qc->n_vars; j++)
273         {
274           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j)
275                         - gsl_matrix_get (kmeans->centers, which, j));
276         }
277
278       if (dist < mindist)
279         {
280           mindist = dist;
281         }
282     }
283
284   return mindist;
285 }
286
287
288
289 /* Calculate the initial cluster centers. */
290 static void
291 kmeans_initial_centers (struct Kmeans *kmeans,
292                         const struct casereader *reader,
293                         const struct qc *qc)
294 {
295   struct ccase *c;
296   int nc = 0, j;
297
298   struct casereader *cs = casereader_clone (reader);
299   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
300     {
301       bool missing = false;
302       for (j = 0; j < qc->n_vars; ++j)
303         {
304           const union value *val = case_data (c, qc->vars[j]);
305           if (var_is_value_missing (qc->vars[j], val, qc->exclude))
306             {
307               missing = true;
308               break;
309             }
310
311           if (nc < qc->ngroups)
312             gsl_matrix_set (kmeans->centers, nc, j, val->f);
313         }
314
315       if (missing)
316         continue;
317
318       if (nc++ < qc->ngroups)
319         continue;
320
321       if (!qc->no_initial)
322         {
323           int mq, mp;
324           double delta;
325
326           int mn, mm;
327           double m = matrix_mindist (kmeans->centers, &mn, &mm);
328
329           kmeans_get_nearest_group (kmeans, c, qc, &mq, &delta, &mp, NULL);
330           if (delta > m)
331             /* If the distance between C and the nearest group, is greater than the distance
332                between the two  groups which are clostest to each
333                other, then one group must be replaced.  */
334             {
335               /* Out of mn and mm, which is the clostest of the two groups to C ? */
336               int which = (dist_from_case (kmeans, c, qc, mn)
337                            > dist_from_case (kmeans, c, qc, mm)) ? mm : mn;
338
339               for (j = 0; j < qc->n_vars; ++j)
340                 {
341                   const union value *val = case_data (c, qc->vars[j]);
342                   gsl_matrix_set (kmeans->centers, which, j, val->f);
343                 }
344             }
345           else if (dist_from_case (kmeans, c, qc, mp) > min_dist_from (kmeans, qc, mq))
346             /* If the distance between C and the second nearest group
347                (MP) is greater than the smallest distance between the
348                nearest group (MQ) and any other group, then replace
349                MQ with C.  */
350             {
351               for (j = 0; j < qc->n_vars; ++j)
352                 {
353                   const union value *val = case_data (c, qc->vars[j]);
354                   gsl_matrix_set (kmeans->centers, mq, j, val->f);
355                 }
356             }
357         }
358     }
359
360   casereader_destroy (cs);
361
362   kmeans->convergence_criteria = qc->epsilon * matrix_mindist (kmeans->centers, NULL, NULL);
363
364   /* As it is the first iteration, the variable kmeans->initial_centers is NULL
365      and it is created once for reporting issues. */
366   kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
367   gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
368 }
369
370
371 /* Return the index of the group which is nearest to the case C */
372 static void
373 kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c,
374                           const struct qc *qc, int *g_q, double *delta_q,
375                           int *g_p, double *delta_p)
376 {
377   int result0 = -1;
378   int result1 = -1;
379   int i, j;
380   double mindist0 = INFINITY;
381   double mindist1 = INFINITY;
382   for (i = 0; i < qc->ngroups; i++)
383     {
384       double dist = 0;
385       for (j = 0; j < qc->n_vars; j++)
386         {
387           const union value *val = case_data (c, qc->vars[j]);
388           if (var_is_value_missing (qc->vars[j], val, qc->exclude))
389             continue;
390
391           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
392         }
393
394       if (dist < mindist0)
395         {
396           mindist1 = mindist0;
397           result1 = result0;
398
399           mindist0 = dist;
400           result0 = i;
401         }
402       else if (dist < mindist1)
403         {
404           mindist1 = dist;
405           result1 = i;
406         }
407     }
408
409   if (delta_q)
410     *delta_q = mindist0;
411
412   if (g_q)
413     *g_q = result0;
414
415
416   if (delta_p)
417     *delta_p = mindist1;
418
419   if (g_p)
420     *g_p = result1;
421 }
422
423
424
425 static void
426 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
427 {
428   gsl_vector *v = gsl_vector_alloc (qc->ngroups);
429   gsl_matrix_get_col (v, kmeans->centers, 0);
430   gsl_sort_vector_index (kmeans->group_order, v);
431   gsl_vector_free (v);
432 }
433
434 /* Main algorithm.
435    Does iterations, checks convergency. */
436 static void
437 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
438                 const struct qc *qc)
439 {
440   int j;
441
442   kmeans_initial_centers (kmeans, reader, qc);
443
444   gsl_matrix_memcpy (kmeans->updated_centers, kmeans->centers);
445
446
447   for (int xx = 0 ; xx < qc->maxiter ; ++xx)
448     {
449       gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
450
451       kmeans->n = 0;
452       if (!qc->no_update)
453         {
454           struct casereader *r = casereader_clone (reader);
455           struct ccase *c;
456           for (; (c = casereader_read (r)) != NULL; case_unref (c))
457             {
458               int group = -1;
459               int g;
460               bool missing = false;
461
462               for (j = 0; j < qc->n_vars; j++)
463                 {
464                   const union value *val = case_data (c, qc->vars[j]);
465                   if (var_is_value_missing (qc->vars[j], val, qc->exclude))
466                     missing = true;
467                 }
468
469               if (missing)
470                 continue;
471
472               double mindist = INFINITY;
473               for (g = 0; g < qc->ngroups; ++g)
474                 {
475                   double d = dist_from_case (kmeans, c, qc, g);
476
477                   if (d < mindist)
478                     {
479                       mindist = d;
480                       group = g;
481                     }
482                 }
483
484               long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
485               *n += qc->wv ? case_num (c, qc->wv) : 1.0;
486               kmeans->n++;
487
488               for (j = 0; j < qc->n_vars; ++j)
489                 {
490                   const union value *val = case_data (c, qc->vars[j]);
491                   if (var_is_value_missing (qc->vars[j], val, qc->exclude))
492                     continue;
493                   double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
494                   *x += val->f * (qc->wv ? case_num (c, qc->wv) : 1.0);
495                 }
496             }
497
498           casereader_destroy (r);
499         }
500
501       int g;
502
503       /* Divide the cluster sums by the number of items in each cluster */
504       for (g = 0; g < qc->ngroups; ++g)
505         {
506           for (j = 0; j < qc->n_vars; ++j)
507             {
508               long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
509               double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
510               *x /= n + 1;  // Plus 1 for the initial centers
511             }
512         }
513
514
515       gsl_matrix_memcpy (kmeans->centers, kmeans->updated_centers);
516
517       {
518         kmeans->n = 0;
519         /* Step 3 */
520         gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
521         gsl_matrix_set_all (kmeans->updated_centers, 0.0);
522         struct ccase *c;
523         struct casereader *cs = casereader_clone (reader);
524         for (; (c = casereader_read (cs)) != NULL; case_unref (c))
525           {
526             int group = -1;
527             kmeans_get_nearest_group (kmeans, c, qc, &group, NULL, NULL, NULL);
528
529             for (j = 0; j < qc->n_vars; ++j)
530               {
531                 const union value *val = case_data (c, qc->vars[j]);
532                 if (var_is_value_missing (qc->vars[j], val, qc->exclude))
533                   continue;
534
535                 double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
536                 *x += val->f;
537               }
538
539             long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
540             *n += qc->wv ? case_num (c, qc->wv) : 1.0;
541             kmeans->n++;
542           }
543         casereader_destroy (cs);
544
545
546         /* Divide the cluster sums by the number of items in each cluster */
547         for (g = 0; g < qc->ngroups; ++g)
548           {
549             for (j = 0; j < qc->n_vars; ++j)
550               {
551                 long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
552                 double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
553                 *x /= n ;
554               }
555           }
556
557         double d = diff_matrix (kmeans->updated_centers, kmeans->centers);
558         if (d < kmeans->convergence_criteria)
559           break;
560       }
561
562       if (qc->no_update)
563         break;
564     }
565 }
566
567 /* Reports centers of clusters.
568    Initial parameter is optional for future use.
569    If initial is true, initial cluster centers are reported.  Otherwise,
570    resulted centers are reported. */
571 static void
572 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
573 {
574   struct pivot_table *table
575     = pivot_table_create (initial
576                           ? N_("Initial Cluster Centers")
577                           : N_("Final Cluster Centers"));
578
579   struct pivot_dimension *clusters
580     = pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"));
581
582   clusters->root->show_label = true;
583   for (size_t i = 0; i < qc->ngroups; i++)
584     pivot_category_create_leaf (clusters->root,
585                                 pivot_value_new_integer (i + 1));
586
587   struct pivot_dimension *variables
588     = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Variable"));
589
590   for (size_t i = 0; i < qc->n_vars; i++)
591     pivot_category_create_leaf (variables->root,
592                                 pivot_value_new_variable (qc->vars[i]));
593
594   const gsl_matrix *matrix = (initial
595                               ? kmeans->initial_centers
596                               : kmeans->centers);
597   for (size_t i = 0; i < qc->ngroups; i++)
598     for (size_t j = 0; j < qc->n_vars; j++)
599       {
600         double x = gsl_matrix_get (matrix, kmeans->group_order->data[i], j);
601         union value v = { .f = x };
602         pivot_table_put2 (table, i, j,
603                           pivot_value_new_var_value (qc->vars[j], &v));
604       }
605
606   pivot_table_submit (table);
607 }
608
609
610 /* A transformation function which juxtaposes the dataset with the
611    (pre-prepared) dataset containing membership and/or distance
612    values.  */
613 static enum trns_result
614 save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED)
615 {
616   const struct save_trans_data *std = aux;
617   struct ccase *ca  = casereader_read (std->appending_reader);
618   if (ca == NULL)
619     return TRNS_CONTINUE;
620
621   *c = case_unshare (*c);
622
623   if (std->CASE_IDX_MEMBERSHIP >= 0)
624     *case_num_rw (*c, std->membership) = case_num_idx (ca, std->CASE_IDX_MEMBERSHIP);
625
626   if (std->CASE_IDX_DISTANCE >= 0)
627     *case_num_rw (*c, std->distance) = case_num_idx (ca, std->CASE_IDX_DISTANCE);
628
629   case_unref (ca);
630
631   return TRNS_CONTINUE;
632 }
633
634
635 /* Free the resources of the transformation.  */
636 static bool
637 save_trans_destroy (void *aux)
638 {
639   struct save_trans_data *std = aux;
640   casereader_destroy (std->appending_reader);
641   free (std);
642   return true;
643 }
644
645
646 /* Reports cluster membership for each case, and is requested
647 saves the membership and the distance of the case from the cluster
648 centre.  */
649 static void
650 quick_cluster_show_membership (struct Kmeans *kmeans,
651                                const struct casereader *reader,
652                                struct qc *qc)
653 {
654   struct pivot_table *table = NULL;
655   struct pivot_dimension *cases = NULL;
656   if (qc->print_cluster_membership)
657     {
658       table = pivot_table_create (N_("Cluster Membership"));
659
660       pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"),
661                               N_("Cluster"));
662
663       cases
664         = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number"));
665
666       cases->root->show_label = true;
667     }
668
669   gsl_permutation *ip = gsl_permutation_alloc (qc->ngroups);
670   gsl_permutation_inverse (ip, kmeans->group_order);
671
672   struct caseproto *proto = caseproto_create ();
673   if (qc->save_values)
674     {
675       /* Prepare data which may potentially be used in a
676          transformation appending new variables to the active
677          dataset.  */
678       qc->save_trans_data = xzalloc (sizeof *qc->save_trans_data);
679       qc->save_trans_data->CASE_IDX_MEMBERSHIP = -1;
680       qc->save_trans_data->CASE_IDX_DISTANCE = -1;
681       qc->save_trans_data->writer = autopaging_writer_create (proto);
682
683       int idx = 0;
684       if (qc->save_values & SAVE_MEMBERSHIP)
685         {
686           proto = caseproto_add_width (proto, 0);
687           qc->save_trans_data->CASE_IDX_MEMBERSHIP = idx++;
688         }
689
690       if (qc->save_values & SAVE_DISTANCE)
691         {
692           proto = caseproto_add_width (proto, 0);
693           qc->save_trans_data->CASE_IDX_DISTANCE = idx++;
694         }
695     }
696
697   struct casereader *cs = casereader_clone (reader);
698   struct ccase *c;
699   for (int i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
700     {
701       assert (i < kmeans->n);
702       int clust;
703       kmeans_get_nearest_group (kmeans, c, qc, &clust, NULL, NULL, NULL);
704       int cluster = ip->data[clust];
705
706       if (qc->save_trans_data)
707       {
708         /* Calculate the membership and distance values.  */
709         struct ccase *outc = case_create (proto);
710         if (qc->save_values & SAVE_MEMBERSHIP)
711           *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_MEMBERSHIP) = cluster + 1;
712
713         if (qc->save_values & SAVE_DISTANCE)
714           *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_DISTANCE)
715             = sqrt (dist_from_case (kmeans, c, qc, clust));
716
717         casewriter_write (qc->save_trans_data->writer, outc);
718       }
719
720       if (qc->print_cluster_membership)
721         {
722           /* Print the cluster membership to the table.  */
723           int case_idx = pivot_category_create_leaf (cases->root,
724                                                  pivot_value_new_integer (i + 1));
725           pivot_table_put2 (table, 0, case_idx,
726                             pivot_value_new_integer (cluster + 1));
727         }
728     }
729
730   caseproto_unref (proto);
731   gsl_permutation_free (ip);
732
733   if (qc->print_cluster_membership)
734     pivot_table_submit (table);
735   casereader_destroy (cs);
736 }
737
738
739 /* Reports number of cases of each single cluster. */
740 static void
741 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
742 {
743   struct pivot_table *table
744     = pivot_table_create (N_("Number of Cases in each Cluster"));
745
746   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
747                           N_("Count"));
748
749   struct pivot_dimension *clusters
750     = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Clusters"));
751
752   struct pivot_category *group
753     = pivot_category_create_group (clusters->root, N_("Cluster"));
754
755   long int total = 0;
756   for (int i = 0; i < qc->ngroups; i++)
757     {
758       int cluster_idx
759         = pivot_category_create_leaf (group, pivot_value_new_integer (i + 1));
760       int count = kmeans->num_elements_groups->data [kmeans->group_order->data[i]];
761       pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (count));
762       total += count;
763     }
764
765   int cluster_idx = pivot_category_create_leaf (clusters->root,
766                                                 pivot_value_new_text (N_("Valid")));
767   pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (total));
768   pivot_table_submit (table);
769 }
770
771 /* Reports. */
772 static void
773 quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader,
774                             struct qc *qc)
775 {
776   kmeans_order_groups (kmeans, qc); /* what does this do? */
777
778   if (qc->print_initial_clusters)
779     quick_cluster_show_centers (kmeans, true, qc);
780   quick_cluster_show_centers (kmeans, false, qc);
781   quick_cluster_show_number_cases (kmeans, qc);
782
783   quick_cluster_show_membership (kmeans, reader, qc);
784 }
785
786 /* Parse the QUICK CLUSTER command and populate QC accordingly.
787    Returns false on error.  */
788 static bool
789 quick_cluster_parse (struct lexer *lexer, struct qc *qc)
790 {
791   if (!parse_variables_const (lexer, qc->dict, &qc->vars, &qc->n_vars,
792                               PV_NO_DUPLICATE | PV_NUMERIC))
793     {
794       return (CMD_FAILURE);
795     }
796
797   while (lex_token (lexer) != T_ENDCMD)
798     {
799       lex_match (lexer, T_SLASH);
800
801       if (lex_match_id (lexer, "MISSING"))
802         {
803           lex_match (lexer, T_EQUALS);
804           while (lex_token (lexer) != T_ENDCMD
805                  && lex_token (lexer) != T_SLASH)
806             {
807               if (lex_match_id (lexer, "LISTWISE")
808                   || lex_match_id (lexer, "DEFAULT"))
809                 {
810                   qc->missing_type = MISS_LISTWISE;
811                 }
812               else if (lex_match_id (lexer, "PAIRWISE"))
813                 {
814                   qc->missing_type = MISS_PAIRWISE;
815                 }
816               else if (lex_match_id (lexer, "INCLUDE"))
817                 {
818                   qc->exclude = MV_SYSTEM;
819                 }
820               else if (lex_match_id (lexer, "EXCLUDE"))
821                 {
822                   qc->exclude = MV_ANY;
823                 }
824               else
825                 {
826                   lex_error (lexer, NULL);
827                   return false;
828                 }
829             }
830         }
831       else if (lex_match_id (lexer, "PRINT"))
832         {
833           lex_match (lexer, T_EQUALS);
834           while (lex_token (lexer) != T_ENDCMD
835                  && lex_token (lexer) != T_SLASH)
836             {
837               if (lex_match_id (lexer, "CLUSTER"))
838                 qc->print_cluster_membership = true;
839               else if (lex_match_id (lexer, "INITIAL"))
840                 qc->print_initial_clusters = true;
841               else
842                 {
843                   lex_error (lexer, NULL);
844                   return false;
845                 }
846             }
847         }
848       else if (lex_match_id (lexer, "SAVE"))
849         {
850           lex_match (lexer, T_EQUALS);
851           while (lex_token (lexer) != T_ENDCMD
852                  && lex_token (lexer) != T_SLASH)
853             {
854               if (lex_match_id (lexer, "CLUSTER"))
855                 {
856                   qc->save_values |= SAVE_MEMBERSHIP;
857                   if (lex_match (lexer, T_LPAREN))
858                     {
859                       if (!lex_force_id (lexer))
860                         return false;
861
862                       free (qc->var_membership);
863                       qc->var_membership = xstrdup (lex_tokcstr (lexer));
864                       if (NULL != dict_lookup_var (qc->dict, qc->var_membership))
865                         {
866                           lex_error (lexer,
867                                      _("A variable called `%s' already exists."),
868                                      qc->var_membership);
869                           free (qc->var_membership);
870                           qc->var_membership = NULL;
871                           return false;
872                         }
873
874                       lex_get (lexer);
875
876                       if (!lex_force_match (lexer, T_RPAREN))
877                         return false;
878                     }
879                 }
880               else if (lex_match_id (lexer, "DISTANCE"))
881                 {
882                   qc->save_values |= SAVE_DISTANCE;
883                   if (lex_match (lexer, T_LPAREN))
884                     {
885                       if (!lex_force_id (lexer))
886                         return false;
887
888                       free (qc->var_distance);
889                       qc->var_distance = xstrdup (lex_tokcstr (lexer));
890                       if (NULL != dict_lookup_var (qc->dict, qc->var_distance))
891                         {
892                           lex_error (lexer,
893                                      _("A variable called `%s' already exists."),
894                                      qc->var_distance);
895                           free (qc->var_distance);
896                           qc->var_distance = NULL;
897                           return false;
898                         }
899
900                       lex_get (lexer);
901
902                       if (!lex_force_match (lexer, T_RPAREN))
903                         return false;
904                     }
905                 }
906               else
907                 {
908                   lex_error (lexer, _("Expecting %s or %s."),
909                              "CLUSTER", "DISTANCE");
910                   return false;
911                 }
912             }
913         }
914       else if (lex_match_id (lexer, "CRITERIA"))
915         {
916           lex_match (lexer, T_EQUALS);
917           while (lex_token (lexer) != T_ENDCMD
918                  && lex_token (lexer) != T_SLASH)
919             {
920               if (lex_match_id (lexer, "CLUSTERS"))
921                 {
922                   if (lex_force_match (lexer, T_LPAREN) &&
923                       lex_force_int_range (lexer, "CLUSTERS", 1, INT_MAX))
924                     {
925                       qc->ngroups = lex_integer (lexer);
926                       lex_get (lexer);
927                       if (!lex_force_match (lexer, T_RPAREN))
928                         return false;
929                     }
930                 }
931               else if (lex_match_id (lexer, "CONVERGE"))
932                 {
933                   if (lex_force_match (lexer, T_LPAREN) &&
934                       lex_force_num (lexer))
935                     {
936                       qc->epsilon = lex_number (lexer);
937                       if (qc->epsilon <= 0)
938                         {
939                           lex_error (lexer, _("The convergence criterion must be positive"));
940                           return false;
941                         }
942                       lex_get (lexer);
943                       if (!lex_force_match (lexer, T_RPAREN))
944                         return false;
945                     }
946                 }
947               else if (lex_match_id (lexer, "MXITER"))
948                 {
949                   if (lex_force_match (lexer, T_LPAREN) &&
950                       lex_force_int_range (lexer, "MXITER", 1, INT_MAX))
951                     {
952                       qc->maxiter = lex_integer (lexer);
953                       lex_get (lexer);
954                       if (!lex_force_match (lexer, T_RPAREN))
955                         return false;
956                     }
957                 }
958               else if (lex_match_id (lexer, "NOINITIAL"))
959                 {
960                   qc->no_initial = true;
961                 }
962               else if (lex_match_id (lexer, "NOUPDATE"))
963                 {
964                   qc->no_update = true;
965                 }
966               else
967                 {
968                   lex_error (lexer, NULL);
969                   return false;
970                 }
971             }
972         }
973       else
974         {
975           lex_error (lexer, NULL);
976           return false;
977         }
978     }
979   return true;
980 }
981
982 int
983 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
984 {
985   struct qc qc;
986   struct Kmeans *kmeans;
987   bool ok;
988   memset (&qc, 0, sizeof qc);
989   qc.dataset = ds;
990   qc.dict =  dataset_dict (ds);
991   qc.ngroups = 2;
992   qc.maxiter = 10;
993   qc.epsilon = DBL_EPSILON;
994   qc.missing_type = MISS_LISTWISE;
995   qc.exclude = MV_ANY;
996
997
998   if (!quick_cluster_parse (lexer, &qc))
999     goto error;
1000
1001   qc.wv = dict_get_weight (qc.dict);
1002
1003   {
1004     struct casereader *group;
1005     struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), qc.dict);
1006
1007     while (casegrouper_get_next_group (grouper, &group))
1008       {
1009         if (qc.missing_type == MISS_LISTWISE)
1010           {
1011             group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
1012                                                        qc.exclude,
1013                                                        NULL,  NULL);
1014           }
1015
1016         kmeans = kmeans_create (&qc);
1017         kmeans_cluster (kmeans, group, &qc);
1018         quick_cluster_show_results (kmeans, group, &qc);
1019         kmeans_destroy (kmeans);
1020         casereader_destroy (group);
1021       }
1022     ok = casegrouper_destroy (grouper);
1023   }
1024   ok = proc_commit (ds) && ok;
1025
1026
1027   /* If requested, set a transformation to append the cluster and
1028      distance values to the current dataset.  */
1029   if (qc.save_trans_data)
1030     {
1031       struct save_trans_data *std = qc.save_trans_data;
1032       std->appending_reader = casewriter_make_reader (std->writer);
1033       std->writer = NULL;
1034
1035       if (qc.save_values & SAVE_MEMBERSHIP)
1036         {
1037           /* Invent a variable name if necessary.  */
1038           int idx = 0;
1039           struct string name;
1040           ds_init_empty (&name);
1041           while (qc.var_membership == NULL)
1042             {
1043               ds_clear (&name);
1044               ds_put_format (&name, "QCL_%d", idx++);
1045
1046               if (!dict_lookup_var (qc.dict, ds_cstr (&name)))
1047                 {
1048                   qc.var_membership = strdup (ds_cstr (&name));
1049                   break;
1050                 }
1051             }
1052           ds_destroy (&name);
1053
1054           std->membership = dict_create_var_assert (qc.dict, qc.var_membership, 0);
1055         }
1056
1057       if (qc.save_values & SAVE_DISTANCE)
1058         {
1059           /* Invent a variable name if necessary.  */
1060           int idx = 0;
1061           struct string name;
1062           ds_init_empty (&name);
1063           while (qc.var_distance == NULL)
1064             {
1065               ds_clear (&name);
1066               ds_put_format (&name, "QCL_%d", idx++);
1067
1068               if (!dict_lookup_var (qc.dict, ds_cstr (&name)))
1069                 {
1070                   qc.var_distance = strdup (ds_cstr (&name));
1071                   break;
1072                 }
1073             }
1074           ds_destroy (&name);
1075
1076           std->distance = dict_create_var_assert (qc.dict, qc.var_distance, 0);
1077         }
1078
1079       static const struct trns_class trns_class = {
1080         .name = "QUICK CLUSTER",
1081         .execute = save_trans_func,
1082         .destroy = save_trans_destroy,
1083       };
1084       add_transformation (qc.dataset, &trns_class, std);
1085     }
1086
1087   free (qc.var_distance);
1088   free (qc.var_membership);
1089   free (qc.vars);
1090   return (ok);
1091
1092  error:
1093   free (qc.var_distance);
1094   free (qc.var_membership);
1095   free (qc.vars);
1096   return CMD_FAILURE;
1097 }