QUICK CLUSTER: New subcommand: /PRINT
[pspp] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 enum missing_type
50   {
51     MISS_LISTWISE,
52     MISS_PAIRWISE,
53   };
54
55
56 struct qc
57 {
58   const struct variable **vars;
59   size_t n_vars;
60
61   int ngroups;                  /* Number of group. (Given by the user) */
62   int maxiter;                  /* Maximum iterations (Given by the user) */
63   int print_cluster_membership; /* true => print membership */
64   int print_initial_clusters;   /* true => print initial cluster */
65
66   const struct variable *wv;    /* Weighting variable. */
67
68   enum missing_type missing_type;
69   enum mv_class exclude;
70 };
71
72 /* Holds all of the information for the functions.  int n, holds the number of
73    observation and its default value is -1.  We set it in
74    kmeans_recalculate_centers in first invocation. */
75 struct Kmeans
76 {
77   gsl_matrix *centers;          /* Centers for groups. */
78   gsl_vector_long *num_elements_groups;
79
80   casenumber n;                 /* Number of observations (default -1). */
81
82   int lastiter;                 /* Iteration where it found the solution. */
83   int trials;                   /* If not convergence, how many times has
84                                    clustering done. */
85   gsl_matrix *initial_centers;  /* Initial random centers. */
86
87   gsl_permutation *group_order; /* Group order for reporting. */
88   struct caseproto *proto;
89   struct casereader *index_rdr; /* Group ids for each case. */
90 };
91
92 static struct Kmeans *kmeans_create (const struct qc *qc);
93
94 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc);
95
96 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
97
98 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
99
100 static int
101 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
102
103 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
104
105 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
106
107 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
108
109 static void quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
110
111 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
112
113 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
114
115 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
116
117 static void kmeans_destroy (struct Kmeans *kmeans);
118
119 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
120    variables 'variables', number of cases 'n', number of variables 'm', number
121    of clusters and amount of maximum iterations. */
122 static struct Kmeans *
123 kmeans_create (const struct qc *qc)
124 {
125   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
126   kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
127   kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
128   kmeans->n = 0;
129   kmeans->lastiter = 0;
130   kmeans->trials = 0;
131   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
132   kmeans->initial_centers = NULL;
133
134   kmeans->proto = caseproto_create ();
135   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
136   kmeans->index_rdr = NULL;
137   return (kmeans);
138 }
139
140 static void
141 kmeans_destroy (struct Kmeans *kmeans)
142 {
143   gsl_matrix_free (kmeans->centers);
144   gsl_matrix_free (kmeans->initial_centers);
145
146   gsl_vector_long_free (kmeans->num_elements_groups);
147
148   gsl_permutation_free (kmeans->group_order);
149
150   caseproto_unref (kmeans->proto);
151
152   casereader_destroy (kmeans->index_rdr);
153
154   free (kmeans);
155 }
156
157 /* Creates random centers using randomly selected cases from the data. */
158 static void
159 kmeans_randomize_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
160 {
161   int i, j;
162   for (i = 0; i < qc->ngroups; i++)
163     {
164       for (j = 0; j < qc->n_vars; j++)
165         {
166           if (i == j)
167             {
168               gsl_matrix_set (kmeans->centers, i, j, 1);
169             }
170           else
171             {
172               gsl_matrix_set (kmeans->centers, i, j, 0);
173             }
174         }
175     }
176   /* If it is the first iteration, the variable kmeans->initial_centers is NULL
177      and it is created once for reporting issues. In SPSS, initial centers are
178      shown in the reports but in PSPP it is not shown now. I am leaving it
179      here. */
180   if (!kmeans->initial_centers)
181     {
182       kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
183       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
184     }
185 }
186
187 static int
188 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
189 {
190   int result = -1;
191   int i, j;
192   double mindist = INFINITY;
193   for (i = 0; i < qc->ngroups; i++)
194     {
195       double dist = 0;
196       for (j = 0; j < qc->n_vars; j++)
197         {
198           const union value *val = case_data (c, qc->vars[j]);
199           if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
200             continue;
201
202           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
203         }
204       if (dist < mindist)
205         {
206           mindist = dist;
207           result = i;
208         }
209     }
210   return (result);
211 }
212
213 /* Re-calculate the cluster centers. */
214 static void
215 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
216 {
217   casenumber i = 0;
218   int v, j;
219   struct ccase *c;
220
221   struct casereader *cs = casereader_clone (reader);
222   struct casereader *cs_index = casereader_clone (kmeans->index_rdr);
223
224   gsl_matrix_set_all (kmeans->centers, 0.0);
225   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
226     {
227       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
228       struct ccase *c_index = casereader_read (cs_index);
229       int index = case_data_idx (c_index, 0)->f;
230       for (v = 0; v < qc->n_vars; ++v)
231         {
232           const union value *val = case_data (c, qc->vars[v]);
233           double x = val->f * weight;
234           double curval;
235
236           if ( var_is_value_missing (qc->vars[v], val, qc->exclude))
237             continue;
238
239           curval = gsl_matrix_get (kmeans->centers, index, v);
240           gsl_matrix_set (kmeans->centers, index, v, curval + x);
241         }
242       i++;
243       case_unref (c_index);
244     }
245   casereader_destroy (cs);
246   casereader_destroy (cs_index);
247
248   /* Getting number of cases */
249   if (kmeans->n == 0)
250     kmeans->n = i;
251
252   /* We got sum of each center but we need averages.
253      We are dividing centers to numobs. This may be inefficient and
254      we should check it again. */
255   for (i = 0; i < qc->ngroups; i++)
256     {
257       casenumber numobs = kmeans->num_elements_groups->data[i];
258       for (j = 0; j < qc->n_vars; j++)
259         {
260           if (numobs > 0)
261             {
262               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
263               *x /= numobs;
264             }
265           else
266             {
267               gsl_matrix_set (kmeans->centers, i, j, 0);
268             }
269         }
270     }
271 }
272
273 /* The variable index in struct Kmeans holds integer values that represents the
274    current groups of cases.  index[n]=a shows the nth case is belong to ath
275    cluster.  This function calculates these indexes and returns the number of
276    different cases of the new and old index variables.  If last two index
277    variables are equal, there is no any enhancement of clustering. */
278 static int
279 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
280 {
281   int totaldiff = 0;
282   struct ccase *c;
283   struct casereader *cs = casereader_clone (reader);
284
285   /* A casewriter into which we will write the indexes. */
286   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
287
288   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
289
290   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
291     {
292       /* A case to hold the new index. */
293       struct ccase *index_case_new = case_create (kmeans->proto);
294       int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
295       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
296       assert (bestindex < kmeans->num_elements_groups->size);
297       kmeans->num_elements_groups->data[bestindex] += weight;
298       if (kmeans->index_rdr)
299         {
300           /* A case from which the old index will be read. */
301           struct ccase *index_case_old = NULL;
302
303           /* Read the case from the index casereader. */
304           index_case_old = casereader_read (kmeans->index_rdr);
305
306           /* Set totaldiff, using the old_index. */
307           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
308
309           /* We have no use for the old case anymore, so unref it. */
310           case_unref (index_case_old);
311         }
312       else
313         {
314           /* If this is the first run, then assume index is zero. */
315           totaldiff += bestindex;
316         }
317
318       /* Set the value of the new inde.x */
319       case_data_rw_idx (index_case_new, 0)->f = bestindex;
320
321       /* and write the new index to the casewriter */
322       casewriter_write (index_wtr, index_case_new);
323     }
324   casereader_destroy (cs);
325   /* We have now read through the entire index_rdr, so it's of no use
326      anymore. */
327   casereader_destroy (kmeans->index_rdr);
328
329   /* Convert the writer into a reader, ready for the next iteration to read */
330   kmeans->index_rdr = casewriter_make_reader (index_wtr);
331
332   return (totaldiff);
333 }
334
335 static void
336 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
337 {
338   gsl_vector *v = gsl_vector_alloc (qc->ngroups);
339   gsl_matrix_get_col (v, kmeans->centers, 0);
340   gsl_sort_vector_index (kmeans->group_order, v);
341   gsl_vector_free (v);
342 }
343
344 /* Main algorithm.
345    Does iterations, checks convergency. */
346 static void
347 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
348 {
349   int i;
350   bool redo;
351   int diffs;
352   bool show_warning1;
353   int redo_count = 0;
354
355   show_warning1 = true;
356 cluster:
357   redo = false;
358   kmeans_randomize_centers (kmeans, reader, qc);
359   for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
360        kmeans->lastiter++)
361     {
362       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
363       kmeans_recalculate_centers (kmeans, reader, qc);
364       if (show_warning1 && qc->ngroups > kmeans->n)
365         {
366           msg (MW, _("Number of clusters may not be larger than the number "
367                      "of cases."));
368           show_warning1 = false;
369         }
370       if (diffs == 0)
371         break;
372     }
373
374   for (i = 0; i < qc->ngroups; i++)
375     {
376       if (kmeans->num_elements_groups->data[i] == 0)
377         {
378           kmeans->trials++;
379           if (kmeans->trials >= 3)
380             break;
381           redo = true;
382           break;
383         }
384     }
385
386   if (redo)
387     {
388       redo_count++;
389       assert (redo_count < 10);
390       goto cluster;
391     }
392
393 }
394
395 /* Reports centers of clusters.
396    Initial parameter is optional for future use.
397    If initial is true, initial cluster centers are reported.  Otherwise,
398    resulted centers are reported. */
399 static void
400 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
401 {
402   struct tab_table *t;
403   int nc, nr, currow;
404   int i, j;
405   nc = qc->ngroups + 1;
406   nr = qc->n_vars + 4;
407   t = tab_create (nc, nr);
408   tab_headers (t, 0, nc - 1, 0, 1);
409   currow = 0;
410   if (!initial)
411     {
412       tab_title (t, _("Final Cluster Centers"));
413     }
414   else
415     {
416       tab_title (t, _("Initial Cluster Centers"));
417     }
418   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
419   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
420   tab_hline (t, TAL_1, 1, nc - 1, 2);
421   currow += 2;
422
423   for (i = 0; i < qc->ngroups; i++)
424     {
425       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
426     }
427   currow++;
428   tab_hline (t, TAL_1, 1, nc - 1, currow);
429   currow++;
430   for (i = 0; i < qc->n_vars; i++)
431     {
432       tab_text (t, 0, currow + i, TAB_LEFT,
433                 var_to_string (qc->vars[i]));
434     }
435
436   for (i = 0; i < qc->ngroups; i++)
437     {
438       for (j = 0; j < qc->n_vars; j++)
439         {
440           if (!initial)
441             {
442               tab_double (t, i + 1, j + 4, TAB_CENTER,
443                           gsl_matrix_get (kmeans->centers,
444                                           kmeans->group_order->data[i], j),
445                           var_get_print_format (qc->vars[j]), RC_OTHER);
446             }
447           else
448             {
449               tab_double (t, i + 1, j + 4, TAB_CENTER,
450                           gsl_matrix_get (kmeans->initial_centers,
451                                           kmeans->group_order->data[i], j),
452                           var_get_print_format (qc->vars[j]), RC_OTHER);
453             }
454         }
455     }
456   tab_submit (t);
457 }
458
459 /* Reports cluster membership for each case. */
460 static void
461 quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
462 {
463   struct tab_table *t;
464   int nc, nr;
465   int i, clust; 
466   struct ccase *c;
467   struct casereader *cs = casereader_clone (reader);
468   nc = 2;
469   nr = kmeans->n + 1;
470   t = tab_create (nc, nr);
471   tab_headers (t, 0, nc - 1, 0, 0);
472   tab_title (t, _("Cluster Membership"));
473   tab_text (t, 0, 0, TAB_CENTER, _("Case Number"));
474   tab_text (t, 1, 0, TAB_CENTER, _("Cluster"));
475   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
476
477   for (i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
478     {
479       assert (i < kmeans->n);
480       clust = kmeans_get_nearest_group (kmeans, c, qc);
481       clust = kmeans->group_order->data[clust];
482       tab_text_format (t, 0, i+1, TAB_CENTER, "%d", (i + 1));
483       tab_text_format (t, 1, i+1, TAB_CENTER, "%d", (clust + 1));
484     }
485   assert (i == kmeans->n);
486   tab_submit (t);
487   casereader_destroy (cs);
488 }
489
490
491 /* Reports number of cases of each single cluster. */
492 static void
493 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
494 {
495   struct tab_table *t;
496   int nc, nr;
497   int i, numelem;
498   long int total;
499   nc = 3;
500   nr = qc->ngroups + 1;
501   t = tab_create (nc, nr);
502   tab_headers (t, 0, nc - 1, 0, 0);
503   tab_title (t, _("Number of Cases in each Cluster"));
504   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
505   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
506
507   total = 0;
508   for (i = 0; i < qc->ngroups; i++)
509     {
510       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
511       numelem =
512         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
513       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
514       total += numelem;
515     }
516
517   tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
518   tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
519   tab_submit (t);
520 }
521
522 /* Reports. */
523 static void
524 quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
525 {
526   kmeans_order_groups (kmeans, qc); /* what does this do? */
527   if( qc->print_initial_clusters )
528     quick_cluster_show_centers (kmeans, true, qc);
529   quick_cluster_show_centers (kmeans, false, qc);
530   quick_cluster_show_number_cases (kmeans, qc);
531   if( qc->print_cluster_membership )
532      quick_cluster_show_membership(kmeans, reader, qc);
533 }
534
535 int
536 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
537 {
538   struct qc qc;
539   struct Kmeans *kmeans;
540   bool ok;
541   const struct dictionary *dict = dataset_dict (ds);
542   qc.ngroups = 2;
543   qc.maxiter = 2;
544   qc.missing_type = MISS_LISTWISE;
545   qc.exclude = MV_ANY;
546   qc.print_cluster_membership = false; /* default = do not output case cluster membership */
547   qc.print_initial_clusters = false;   /* default = do not print initial clusters */
548
549   if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
550                               PV_NO_DUPLICATE | PV_NUMERIC))
551     {
552       return (CMD_FAILURE);
553     }
554
555   while (lex_token (lexer) != T_ENDCMD)
556     {
557       lex_match (lexer, T_SLASH);
558
559       if (lex_match_id (lexer, "MISSING"))
560         {
561           lex_match (lexer, T_EQUALS);
562           while (lex_token (lexer) != T_ENDCMD
563                  && lex_token (lexer) != T_SLASH)
564             {
565               if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
566                 {
567                   qc.missing_type = MISS_LISTWISE;
568                 }
569               else if (lex_match_id (lexer, "PAIRWISE"))
570                 {
571                   qc.missing_type = MISS_PAIRWISE;
572                 }
573               else if (lex_match_id (lexer, "INCLUDE"))
574                 {
575                   qc.exclude = MV_SYSTEM;
576                 }
577               else if (lex_match_id (lexer, "EXCLUDE"))
578                 {
579                   qc.exclude = MV_ANY;
580                 }
581               else
582                 goto error;
583             }     
584         }
585       else if (lex_match_id (lexer, "PRINT"))
586         {
587           lex_match (lexer, T_EQUALS);
588           while (lex_token (lexer) != T_ENDCMD
589                  && lex_token (lexer) != T_SLASH)
590             {
591               if (lex_match_id (lexer, "CLUSTER"))
592                 qc.print_cluster_membership = true;
593               else if (lex_match_id (lexer, "INITIAL"))
594                 qc.print_initial_clusters = true;
595               else
596                 goto error;
597             }
598         }
599       else if (lex_match_id (lexer, "CRITERIA"))
600         {
601           lex_match (lexer, T_EQUALS);
602           while (lex_token (lexer) != T_ENDCMD
603                  && lex_token (lexer) != T_SLASH)
604             {
605               if (lex_match_id (lexer, "CLUSTERS"))
606                 {
607                   if (lex_force_match (lexer, T_LPAREN))
608                     {
609                       lex_force_int (lexer);
610                       qc.ngroups = lex_integer (lexer);
611                       if (qc.ngroups <= 0)
612                         {
613                           lex_error (lexer, _("The number of clusters must be positive"));
614                           goto error;
615                         }
616                       lex_get (lexer);
617                       lex_force_match (lexer, T_RPAREN);
618                     }
619                 }
620               else if (lex_match_id (lexer, "MXITER"))
621                 {
622                   if (lex_force_match (lexer, T_LPAREN))
623                     {
624                       lex_force_int (lexer);
625                       qc.maxiter = lex_integer (lexer);
626                       if (qc.maxiter <= 0)
627                         {
628                           lex_error (lexer, _("The number of iterations must be positive"));
629                           goto error;
630                         }
631                       lex_get (lexer);
632                       lex_force_match (lexer, T_RPAREN);
633                     }
634                 }
635               else
636                 goto error;
637             }
638         }
639       else
640         {
641           lex_error (lexer, NULL);
642           goto error;
643         }
644     }
645
646   qc.wv = dict_get_weight (dict);
647
648   {
649     struct casereader *group;
650     struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
651
652     while (casegrouper_get_next_group (grouper, &group))
653       {
654         if ( qc.missing_type == MISS_LISTWISE )
655           {
656             group  = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
657                                                      qc.exclude,
658                                                      NULL,  NULL);
659           }
660
661         kmeans = kmeans_create (&qc);
662         kmeans_cluster (kmeans, group, &qc);
663         quick_cluster_show_results (kmeans, group, &qc);
664         kmeans_destroy (kmeans);
665         casereader_destroy (group);
666       }
667     ok = casegrouper_destroy (grouper);
668   }
669   ok = proc_commit (ds) && ok;
670
671   free (qc.vars);
672
673   return (ok);
674
675  error:
676   free (qc.vars);
677   return CMD_FAILURE;
678 }