946181b01a6b22e3061a4b26c961ed6f88236bd7
[pspp] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011, 2012 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 enum missing_type
50   {
51     MISS_LISTWISE,
52     MISS_PAIRWISE,
53   };
54
55
56 struct qc
57 {
58   const struct variable **vars;
59   size_t n_vars;
60
61   int ngroups;                  /* Number of group. (Given by the user) */
62   int maxiter;                  /* Maximum iterations (Given by the user) */
63
64   const struct variable *wv;    /* Weighting variable. */
65
66   enum missing_type missing_type;
67   enum mv_class exclude;
68 };
69
70 /* Holds all of the information for the functions.  int n, holds the number of
71    observation and its default value is -1.  We set it in
72    kmeans_recalculate_centers in first invocation. */
73 struct Kmeans
74 {
75   gsl_matrix *centers;          /* Centers for groups. */
76   gsl_vector_long *num_elements_groups;
77
78   casenumber n;                 /* Number of observations (default -1). */
79
80   int lastiter;                 /* Iteration where it found the solution. */
81   int trials;                   /* If not convergence, how many times has
82                                    clustering done. */
83   gsl_matrix *initial_centers;  /* Initial random centers. */
84
85   gsl_permutation *group_order; /* Group order for reporting. */
86   struct caseproto *proto;
87   struct casereader *index_rdr; /* Group ids for each case. */
88 };
89
90 static struct Kmeans *kmeans_create (const struct qc *qc);
91
92 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
93
94 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
95
96 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
97
98 static int
99 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
100
101 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
102
103 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
104
105 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
106
107 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
108
109 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
110
111 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
112
113 static void kmeans_destroy (struct Kmeans *kmeans);
114
115 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
116    variables 'variables', number of cases 'n', number of variables 'm', number
117    of clusters and amount of maximum iterations. */
118 static struct Kmeans *
119 kmeans_create (const struct qc *qc)
120 {
121   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
122   kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
123   kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
124   kmeans->n = 0;
125   kmeans->lastiter = 0;
126   kmeans->trials = 0;
127   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
128   kmeans->initial_centers = NULL;
129
130   kmeans->proto = caseproto_create ();
131   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
132   kmeans->index_rdr = NULL;
133   return (kmeans);
134 }
135
136 static void
137 kmeans_destroy (struct Kmeans *kmeans)
138 {
139   gsl_matrix_free (kmeans->centers);
140   gsl_matrix_free (kmeans->initial_centers);
141
142   gsl_vector_long_free (kmeans->num_elements_groups);
143
144   gsl_permutation_free (kmeans->group_order);
145
146   caseproto_unref (kmeans->proto);
147
148   casereader_destroy (kmeans->index_rdr);
149
150   free (kmeans);
151 }
152
153 /* Creates random centers using randomly selected cases from the data. */
154 static void
155 kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
156 {
157   int i, j;
158   for (i = 0; i < qc->ngroups; i++)
159     {
160       for (j = 0; j < qc->n_vars; j++)
161         {
162           if (i == j)
163             {
164               gsl_matrix_set (kmeans->centers, i, j, 1);
165             }
166           else
167             {
168               gsl_matrix_set (kmeans->centers, i, j, 0);
169             }
170         }
171     }
172   /* If it is the first iteration, the variable kmeans->initial_centers is NULL
173      and it is created once for reporting issues. In SPSS, initial centers are
174      shown in the reports but in PSPP it is not shown now. I am leaving it
175      here. */
176   if (!kmeans->initial_centers)
177     {
178       kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
179       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
180     }
181 }
182
183 static int
184 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
185 {
186   int result = -1;
187   int i, j;
188   double mindist = INFINITY;
189   for (i = 0; i < qc->ngroups; i++)
190     {
191       double dist = 0;
192       for (j = 0; j < qc->n_vars; j++)
193         {
194           const union value *val = case_data (c, qc->vars[j]);
195           if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
196             continue;
197
198           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
199         }
200       if (dist < mindist)
201         {
202           mindist = dist;
203           result = i;
204         }
205     }
206   return (result);
207 }
208
209 /* Re-calculate the cluster centers. */
210 static void
211 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
212 {
213   casenumber i = 0;
214   int v, j;
215   struct ccase *c;
216
217   struct casereader *cs = casereader_clone (reader);
218   struct casereader *cs_index = casereader_clone (kmeans->index_rdr);
219
220   gsl_matrix_set_all (kmeans->centers, 0.0);
221   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
222     {
223       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
224       struct ccase *c_index = casereader_read (cs_index);
225       int index = case_data_idx (c_index, 0)->f;
226       for (v = 0; v < qc->n_vars; ++v)
227         {
228           const union value *val = case_data (c, qc->vars[v]);
229           double x = val->f * weight;
230           double curval;
231
232           if ( var_is_value_missing (qc->vars[v], val, qc->exclude))
233             continue;
234
235           curval = gsl_matrix_get (kmeans->centers, index, v);
236           gsl_matrix_set (kmeans->centers, index, v, curval + x);
237         }
238       i++;
239       case_unref (c_index);
240     }
241   casereader_destroy (cs);
242   casereader_destroy (cs_index);
243
244   /* Getting number of cases */
245   if (kmeans->n == 0)
246     kmeans->n = i;
247
248   /* We got sum of each center but we need averages.
249      We are dividing centers to numobs. This may be inefficient and
250      we should check it again. */
251   for (i = 0; i < qc->ngroups; i++)
252     {
253       casenumber numobs = kmeans->num_elements_groups->data[i];
254       for (j = 0; j < qc->n_vars; j++)
255         {
256           if (numobs > 0)
257             {
258               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
259               *x /= numobs;
260             }
261           else
262             {
263               gsl_matrix_set (kmeans->centers, i, j, 0);
264             }
265         }
266     }
267 }
268
269 /* The variable index in struct Kmeans holds integer values that represents the
270    current groups of cases.  index[n]=a shows the nth case is belong to ath
271    cluster.  This function calculates these indexes and returns the number of
272    different cases of the new and old index variables.  If last two index
273    variables are equal, there is no any enhancement of clustering. */
274 static int
275 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
276 {
277   int totaldiff = 0;
278   struct ccase *c;
279   struct casereader *cs = casereader_clone (reader);
280
281   /* A casewriter into which we will write the indexes. */
282   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
283
284   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
285
286   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
287     {
288       /* A case to hold the new index. */
289       struct ccase *index_case_new = case_create (kmeans->proto);
290       int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
291       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
292       assert (bestindex < kmeans->num_elements_groups->size);
293       kmeans->num_elements_groups->data[bestindex] += weight;
294       if (kmeans->index_rdr)
295         {
296           /* A case from which the old index will be read. */
297           struct ccase *index_case_old = NULL;
298
299           /* Read the case from the index casereader. */
300           index_case_old = casereader_read (kmeans->index_rdr);
301
302           /* Set totaldiff, using the old_index. */
303           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
304
305           /* We have no use for the old case anymore, so unref it. */
306           case_unref (index_case_old);
307         }
308       else
309         {
310           /* If this is the first run, then assume index is zero. */
311           totaldiff += bestindex;
312         }
313
314       /* Set the value of the new inde.x */
315       case_data_rw_idx (index_case_new, 0)->f = bestindex;
316
317       /* and write the new index to the casewriter */
318       casewriter_write (index_wtr, index_case_new);
319     }
320   casereader_destroy (cs);
321   /* We have now read through the entire index_rdr, so it's of no use
322      anymore. */
323   casereader_destroy (kmeans->index_rdr);
324
325   /* Convert the writer into a reader, ready for the next iteration to read */
326   kmeans->index_rdr = casewriter_make_reader (index_wtr);
327
328   return (totaldiff);
329 }
330
331 static void
332 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
333 {
334   gsl_vector *v = gsl_vector_alloc (qc->ngroups);
335   gsl_matrix_get_col (v, kmeans->centers, 0);
336   gsl_sort_vector_index (kmeans->group_order, v);
337   gsl_vector_free (v);
338 }
339
340 /* Main algorithm.
341    Does iterations, checks convergency. */
342 static void
343 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
344 {
345   int i;
346   bool redo;
347   int diffs;
348   bool show_warning1;
349
350   show_warning1 = true;
351 cluster:
352   redo = false;
353   kmeans_randomize_centers (kmeans, qc);
354   for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
355        kmeans->lastiter++)
356     {
357       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
358       kmeans_recalculate_centers (kmeans, reader, qc);
359       if (show_warning1 && qc->ngroups > kmeans->n)
360         {
361           msg (MW, _("Number of clusters may not be larger than the number "
362                      "of cases."));
363           show_warning1 = false;
364         }
365       if (diffs == 0)
366         break;
367     }
368
369   for (i = 0; i < qc->ngroups; i++)
370     {
371       if (kmeans->num_elements_groups->data[i] == 0)
372         {
373           kmeans->trials++;
374           if (kmeans->trials >= 3)
375             break;
376           redo = true;
377           break;
378         }
379     }
380   if (redo)
381     goto cluster;
382
383 }
384
385 /* Reports centers of clusters.
386    Initial parameter is optional for future use.
387    If initial is true, initial cluster centers are reported.  Otherwise,
388    resulted centers are reported. */
389 static void
390 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
391 {
392   struct tab_table *t;
393   int nc, nr, currow;
394   int i, j;
395   nc = qc->ngroups + 1;
396   nr = qc->n_vars + 4;
397   t = tab_create (nc, nr);
398   tab_headers (t, 0, nc - 1, 0, 1);
399   currow = 0;
400   if (!initial)
401     {
402       tab_title (t, _("Final Cluster Centers"));
403     }
404   else
405     {
406       tab_title (t, _("Initial Cluster Centers"));
407     }
408   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
409   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
410   tab_hline (t, TAL_1, 1, nc - 1, 2);
411   currow += 2;
412
413   for (i = 0; i < qc->ngroups; i++)
414     {
415       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
416     }
417   currow++;
418   tab_hline (t, TAL_1, 1, nc - 1, currow);
419   currow++;
420   for (i = 0; i < qc->n_vars; i++)
421     {
422       tab_text (t, 0, currow + i, TAB_LEFT,
423                 var_to_string (qc->vars[i]));
424     }
425
426   for (i = 0; i < qc->ngroups; i++)
427     {
428       for (j = 0; j < qc->n_vars; j++)
429         {
430           if (!initial)
431             {
432               tab_double (t, i + 1, j + 4, TAB_CENTER,
433                           gsl_matrix_get (kmeans->centers,
434                                           kmeans->group_order->data[i], j),
435                           var_get_print_format (qc->vars[j]));
436             }
437           else
438             {
439               tab_double (t, i + 1, j + 4, TAB_CENTER,
440                           gsl_matrix_get (kmeans->initial_centers,
441                                           kmeans->group_order->data[i], j),
442                           var_get_print_format (qc->vars[j]));
443             }
444         }
445     }
446   tab_submit (t);
447 }
448
449 /* Reports number of cases of each single cluster. */
450 static void
451 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
452 {
453   struct tab_table *t;
454   int nc, nr;
455   int i, numelem;
456   long int total;
457   nc = 3;
458   nr = qc->ngroups + 1;
459   t = tab_create (nc, nr);
460   tab_headers (t, 0, nc - 1, 0, 0);
461   tab_title (t, _("Number of Cases in each Cluster"));
462   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
463   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
464
465   total = 0;
466   for (i = 0; i < qc->ngroups; i++)
467     {
468       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
469       numelem =
470         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
471       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
472       total += numelem;
473     }
474
475   tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
476   tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
477   tab_submit (t);
478 }
479
480 /* Reports. */
481 static void
482 quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
483 {
484   kmeans_order_groups (kmeans, qc);
485   /* Uncomment the line below for reporting initial centers. */
486   /* quick_cluster_show_centers (kmeans, true); */
487   quick_cluster_show_centers (kmeans, false, qc);
488   quick_cluster_show_number_cases (kmeans, qc);
489 }
490
491 int
492 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
493 {
494   struct qc qc;
495   struct Kmeans *kmeans;
496   bool ok;
497   const struct dictionary *dict = dataset_dict (ds);
498   qc.ngroups = 2;
499   qc.maxiter = 2;
500   qc.missing_type = MISS_LISTWISE;
501   qc.exclude = MV_ANY;
502
503   if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
504                               PV_NO_DUPLICATE | PV_NUMERIC))
505     {
506       return (CMD_FAILURE);
507     }
508
509   while (lex_token (lexer) != T_ENDCMD)
510     {
511       lex_match (lexer, T_SLASH);
512
513       if (lex_match_id (lexer, "MISSING"))
514         {
515           lex_match (lexer, T_EQUALS);
516           while (lex_token (lexer) != T_ENDCMD
517                  && lex_token (lexer) != T_SLASH)
518             {
519               if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
520                 {
521                   qc.missing_type = MISS_LISTWISE;
522                 }
523               else if (lex_match_id (lexer, "PAIRWISE"))
524                 {
525                   qc.missing_type = MISS_PAIRWISE;
526                 }
527               else if (lex_match_id (lexer, "INCLUDE"))
528                 {
529                   qc.exclude = MV_SYSTEM;
530                 }
531               else if (lex_match_id (lexer, "EXCLUDE"))
532                 {
533                   qc.exclude = MV_ANY;
534                 }
535               else
536                 goto error;
537             }     
538         }
539       else if (lex_match_id (lexer, "CRITERIA"))
540         {
541           lex_match (lexer, T_EQUALS);
542           while (lex_token (lexer) != T_ENDCMD
543                  && lex_token (lexer) != T_SLASH)
544             {
545               if (lex_match_id (lexer, "CLUSTERS"))
546                 {
547                   if (lex_force_match (lexer, T_LPAREN))
548                     {
549                       lex_force_int (lexer);
550                       qc.ngroups = lex_integer (lexer);
551                       if (qc.ngroups <= 0)
552                         {
553                           lex_error (lexer, _("The number of clusters must be positive"));
554                           goto error;
555                         }
556                       lex_get (lexer);
557                       lex_force_match (lexer, T_RPAREN);
558                     }
559                 }
560               else if (lex_match_id (lexer, "MXITER"))
561                 {
562                   if (lex_force_match (lexer, T_LPAREN))
563                     {
564                       lex_force_int (lexer);
565                       qc.maxiter = lex_integer (lexer);
566                       if (qc.maxiter <= 0)
567                         {
568                           lex_error (lexer, _("The number of iterations must be positive"));
569                           goto error;
570                         }
571                       lex_get (lexer);
572                       lex_force_match (lexer, T_RPAREN);
573                     }
574                 }
575               else
576                 goto error;
577             }
578         }
579     }
580
581   qc.wv = dict_get_weight (dict);
582
583   {
584     struct casereader *group;
585     struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
586
587     while (casegrouper_get_next_group (grouper, &group))
588       {
589         if ( qc.missing_type == MISS_LISTWISE )
590           {
591             group  = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
592                                                      qc.exclude,
593                                                      NULL,  NULL);
594           }
595
596         kmeans = kmeans_create (&qc);
597         kmeans_cluster (kmeans, group, &qc);
598         quick_cluster_show_results (kmeans, &qc);
599         kmeans_destroy (kmeans);
600         casereader_destroy (group);
601       }
602     ok = casegrouper_destroy (grouper);
603   }
604   ok = proc_commit (ds) && ok;
605
606   free (qc.vars);
607
608   return (ok);
609
610  error:
611   free (qc.vars);
612   return CMD_FAILURE;
613 }