QUICK CLUSTER: Adjust comment style.
[pspp-builds.git] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 /* Holds all of the information for the functions.  int n, holds the number of
50    observation and its default value is -1.  We set it in
51    kmeans_recalculate_centers in first invocation. */
52 struct Kmeans
53 {
54   gsl_matrix *centers;          /* Centers for groups. */
55   gsl_vector_long *num_elements_groups;
56   int ngroups;                  /* Number of group. (Given by the user) */
57   casenumber n;                 /* Number of observations (default -1). */
58   int m;                        /* Number of variables. (Given by the user) */
59   int maxiter;                  /* Maximum iterations (Given by the user) */
60   int lastiter;                 /* Iteration where it found the solution. */
61   int trials;                   /* If not convergence, how many times has
62                                    clustering done. */
63   gsl_matrix *initial_centers;  /* Initial random centers. */
64   const struct variable **variables;
65   gsl_permutation *group_order; /* Group order for reporting. */
66   struct casereader *original_casereader;
67   struct caseproto *proto;
68   struct casereader *index_rdr; /* Group ids for each case. */
69   const struct variable *wv;    /* Weighting variable. */
70 };
71
72 static struct Kmeans *kmeans_create (struct casereader *cs,
73                                      const struct variable **variables,
74                                      int m, int ngroups, int maxiter);
75
76 static void kmeans_randomize_centers (struct Kmeans *kmeans);
77
78 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
79
80 static void kmeans_recalculate_centers (struct Kmeans *kmeans);
81
82 static int
83 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
84
85 static void kmeans_order_groups (struct Kmeans *kmeans);
86
87 static void kmeans_cluster (struct Kmeans *kmeans);
88
89 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
90
91 static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
92
93 static void quick_cluster_show_results (struct Kmeans *kmeans);
94
95 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
96
97 static void kmeans_destroy (struct Kmeans *kmeans);
98
99 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
100    variables 'variables', number of cases 'n', number of variables 'm', number
101    of clusters and amount of maximum iterations. */
102 static struct Kmeans *
103 kmeans_create (struct casereader *cs, const struct variable **variables,
104                int m, int ngroups, int maxiter)
105 {
106   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
107   kmeans->centers = gsl_matrix_alloc (ngroups, m);
108   kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
109   kmeans->ngroups = ngroups;
110   kmeans->n = 0;
111   kmeans->m = m;
112   kmeans->maxiter = maxiter;
113   kmeans->lastiter = 0;
114   kmeans->trials = 0;
115   kmeans->variables = variables;
116   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
117   kmeans->original_casereader = cs;
118   kmeans->initial_centers = NULL;
119
120   kmeans->proto = caseproto_create ();
121   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
122   kmeans->index_rdr = NULL;
123   return (kmeans);
124 }
125
126 static void
127 kmeans_destroy (struct Kmeans *kmeans)
128 {
129   gsl_matrix_free (kmeans->centers);
130   gsl_matrix_free (kmeans->initial_centers);
131
132   gsl_vector_long_free (kmeans->num_elements_groups);
133
134   gsl_permutation_free (kmeans->group_order);
135
136   caseproto_unref (kmeans->proto);
137
138   /*
139      These reader and writer were already destroyed.
140      free (kmeans->original_casereader);
141      free (kmeans->index_rdr);
142    */
143
144   free (kmeans);
145 }
146
147 /* Creates random centers using randomly selected cases from the data. */
148 static void
149 kmeans_randomize_centers (struct Kmeans *kmeans)
150 {
151   int i, j;
152   for (i = 0; i < kmeans->ngroups; i++)
153     {
154       for (j = 0; j < kmeans->m; j++)
155         {
156           if (i == j)
157             {
158               gsl_matrix_set (kmeans->centers, i, j, 1);
159             }
160           else
161             {
162               gsl_matrix_set (kmeans->centers, i, j, 0);
163             }
164         }
165     }
166   /* If it is the first iteration, the variable kmeans->initial_centers is NULL
167      and it is created once for reporting issues. In SPSS, initial centers are
168      shown in the reports but in PSPP it is not shown now. I am leaving it
169      here. */
170   if (!kmeans->initial_centers)
171     {
172       kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
173       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
174     }
175 }
176
177 static int
178 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
179 {
180   int result = -1;
181   double x;
182   int i, j;
183   double dist;
184   double mindist;
185   mindist = INFINITY;
186   for (i = 0; i < kmeans->ngroups; i++)
187     {
188       dist = 0;
189       for (j = 0; j < kmeans->m; j++)
190         {
191           x = case_data (c, kmeans->variables[j])->f;
192           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
193         }
194       if (dist < mindist)
195         {
196           mindist = dist;
197           result = i;
198         }
199     }
200   return (result);
201 }
202
203 /* Re-calculate the cluster centers. */
204 static void
205 kmeans_recalculate_centers (struct Kmeans *kmeans)
206 {
207   casenumber i;
208   int v, j;
209   double x, curval;
210   struct ccase *c;
211   struct ccase *c_index;
212   struct casereader *cs;
213   struct casereader *cs_index;
214   int index;
215   double weight;
216
217   i = 0;
218   cs = casereader_clone (kmeans->original_casereader);
219   cs_index = casereader_clone (kmeans->index_rdr);
220
221   gsl_matrix_set_all (kmeans->centers, 0.0);
222   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
223     {
224       c_index = casereader_read (cs_index);
225       index = case_data_idx (c_index, 0)->f;
226       for (v = 0; v < kmeans->m; ++v)
227         {
228           if (kmeans->wv)
229             {
230               weight = case_data (c, kmeans->wv)->f;
231             }
232           else
233             {
234               weight = 1.0;
235             }
236           x = case_data (c, kmeans->variables[v])->f * weight;
237           curval = gsl_matrix_get (kmeans->centers, index, v);
238           gsl_matrix_set (kmeans->centers, index, v, curval + x);
239         }
240       i++;
241       case_unref (c_index);
242     }
243   casereader_destroy (cs);
244   casereader_destroy (cs_index);
245
246   /* Getting number of cases */
247   if (kmeans->n == 0)
248     kmeans->n = i;
249
250   /* We got sum of each center but we need averages.
251      We are dividing centers to numobs. This may be inefficient and
252      we should check it again. */
253   for (i = 0; i < kmeans->ngroups; i++)
254     {
255       casenumber numobs = kmeans->num_elements_groups->data[i];
256       for (j = 0; j < kmeans->m; j++)
257         {
258           if (numobs > 0)
259             {
260               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
261               *x /= numobs;
262             }
263           else
264             {
265               gsl_matrix_set (kmeans->centers, i, j, 0);
266             }
267         }
268     }
269 }
270
271 /* The variable index in struct Kmeans holds integer values that represents the
272    current groups of cases.  index[n]=a shows the nth case is belong to ath
273    cluster.  This function calculates these indexes and returns the number of
274    different cases of the new and old index variables.  If last two index
275    variables are equal, there is no any enhancement of clustering. */
276 static int
277 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
278 {
279   int totaldiff = 0;
280   double weight;
281   struct ccase *c;
282   struct casereader *cs = casereader_clone (kmeans->original_casereader);
283
284   /* A casewriter into which we will write the indexes. */
285   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
286
287   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
288
289   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
290     {
291       /* A case to hold the new index. */
292       struct ccase *index_case_new = case_create (kmeans->proto);
293       int bestindex = kmeans_get_nearest_group (kmeans, c);
294       if (kmeans->wv)
295         {
296           weight = (casenumber) case_data (c, kmeans->wv)->f;
297         }
298       else
299         {
300           weight = 1.0;
301         }
302       kmeans->num_elements_groups->data[bestindex] += weight;
303       if (kmeans->index_rdr)
304         {
305           /* A case from which the old index will be read. */
306           struct ccase *index_case_old = NULL;
307
308           /* Read the case from the index casereader. */
309           index_case_old = casereader_read (kmeans->index_rdr);
310
311           /* Set totaldiff, using the old_index. */
312           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
313
314           /* We have no use for the old case anymore, so unref it. */
315           case_unref (index_case_old);
316         }
317       else
318         {
319           /* If this is the first run, then assume index is zero. */
320           totaldiff += bestindex;
321         }
322
323       /* Set the value of the new inde.x */
324       case_data_rw_idx (index_case_new, 0)->f = bestindex;
325
326       /* and write the new index to the casewriter */
327       casewriter_write (index_wtr, index_case_new);
328     }
329   casereader_destroy (cs);
330   /* We have now read through the entire index_rdr, so it's of no use
331      anymore. */
332   casereader_destroy (kmeans->index_rdr);
333
334   /* Convert the writer into a reader, ready for the next iteration to read */
335   kmeans->index_rdr = casewriter_make_reader (index_wtr);
336
337   return (totaldiff);
338 }
339
340 static void
341 kmeans_order_groups (struct Kmeans *kmeans)
342 {
343   gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
344   gsl_matrix_get_col (v, kmeans->centers, 0);
345   gsl_sort_vector_index (kmeans->group_order, v);
346 }
347
348 /* Main algorithm.
349    Does iterations, checks convergency. */
350 static void
351 kmeans_cluster (struct Kmeans *kmeans)
352 {
353   int i;
354   bool redo;
355   int diffs;
356   bool show_warning1;
357
358   show_warning1 = true;
359 cluster:
360   redo = false;
361   kmeans_randomize_centers (kmeans);
362   for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
363        kmeans->lastiter++)
364     {
365       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
366       kmeans_recalculate_centers (kmeans);
367       if (show_warning1 && kmeans->ngroups > kmeans->n)
368         {
369           msg (MW, _("Number of clusters may not be larger than the number "
370                      "of cases."));
371           show_warning1 = false;
372         }
373       if (diffs == 0)
374         break;
375     }
376
377   for (i = 0; i < kmeans->ngroups; i++)
378     {
379       if (kmeans->num_elements_groups->data[i] == 0)
380         {
381           kmeans->trials++;
382           if (kmeans->trials >= 3)
383             break;
384           redo = true;
385           break;
386         }
387     }
388   if (redo)
389     goto cluster;
390
391 }
392
393 /* Reports centers of clusters.
394    Initial parameter is optional for future use.
395    If initial is true, initial cluster centers are reported.  Otherwise,
396    resulted centers are reported. */
397 static void
398 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
399 {
400   struct tab_table *t;
401   int nc, nr, heading_columns, currow;
402   int i, j;
403   nc = kmeans->ngroups + 1;
404   nr = kmeans->m + 4;
405   heading_columns = 1;
406   t = tab_create (nc, nr);
407   tab_headers (t, 0, nc - 1, 0, 1);
408   currow = 0;
409   if (!initial)
410     {
411       tab_title (t, _("Final Cluster Centers"));
412     }
413   else
414     {
415       tab_title (t, _("Initial Cluster Centers"));
416     }
417   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
418   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
419   tab_hline (t, TAL_1, 1, nc - 1, 2);
420   currow += 2;
421
422   for (i = 0; i < kmeans->ngroups; i++)
423     {
424       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
425     }
426   currow++;
427   tab_hline (t, TAL_1, 1, nc - 1, currow);
428   currow++;
429   for (i = 0; i < kmeans->m; i++)
430     {
431       tab_text (t, 0, currow + i, TAB_LEFT,
432                 var_to_string (kmeans->variables[i]));
433     }
434
435   for (i = 0; i < kmeans->ngroups; i++)
436     {
437       for (j = 0; j < kmeans->m; j++)
438         {
439           if (!initial)
440             {
441               tab_double (t, i + 1, j + 4, TAB_CENTER,
442                           gsl_matrix_get (kmeans->centers,
443                                           kmeans->group_order->data[i], j),
444                           var_get_print_format (kmeans->variables[j]));
445             }
446           else
447             {
448               tab_double (t, i + 1, j + 4, TAB_CENTER,
449                           gsl_matrix_get (kmeans->initial_centers,
450                                           kmeans->group_order->data[i], j),
451                           var_get_print_format (kmeans->variables[j]));
452             }
453         }
454     }
455   tab_submit (t);
456 }
457
458 /* Reports number of cases of each single cluster. */
459 static void
460 quick_cluster_show_number_cases (struct Kmeans *kmeans)
461 {
462   struct tab_table *t;
463   int nc, nr;
464   int i, numelem;
465   long int total;
466   nc = 3;
467   nr = kmeans->ngroups + 1;
468   t = tab_create (nc, nr);
469   tab_headers (t, 0, nc - 1, 0, 0);
470   tab_title (t, _("Number of Cases in each Cluster"));
471   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
472   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
473
474   total = 0;
475   for (i = 0; i < kmeans->ngroups; i++)
476     {
477       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
478       numelem =
479         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
480       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
481       total += numelem;
482     }
483
484   tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
485   tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
486   tab_submit (t);
487 }
488
489 /* Reports. */
490 static void
491 quick_cluster_show_results (struct Kmeans *kmeans)
492 {
493   kmeans_order_groups (kmeans);
494   /* Uncomment the line below for reporting initial centers. */
495   /* quick_cluster_show_centers (kmeans, true); */
496   quick_cluster_show_centers (kmeans, false);
497   quick_cluster_show_number_cases (kmeans);
498 }
499
500 int
501 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
502 {
503   struct Kmeans *kmeans;
504   bool ok;
505   const struct dictionary *dict = dataset_dict (ds);
506   const struct variable **variables;
507   struct casereader *cs;
508   int groups = 2;
509   int maxiter = 2;
510   size_t p;
511
512   if (!parse_variables_const (lexer, dict, &variables, &p,
513                               PV_NO_DUPLICATE | PV_NUMERIC))
514     {
515       msg (ME, _("Variables cannot be parsed"));
516       return (CMD_FAILURE);
517     }
518
519   if (lex_match (lexer, T_SLASH))
520     {
521       if (lex_match_id (lexer, "CRITERIA"))
522         {
523           lex_match (lexer, T_EQUALS);
524           while (lex_token (lexer) != T_ENDCMD
525                  && lex_token (lexer) != T_SLASH)
526             {
527               if (lex_match_id (lexer, "CLUSTERS"))
528                 {
529                   if (lex_force_match (lexer, T_LPAREN))
530                     {
531                       lex_force_int (lexer);
532                       groups = lex_integer (lexer);
533                       lex_get (lexer);
534                       lex_force_match (lexer, T_RPAREN);
535                     }
536                 }
537               else if (lex_match_id (lexer, "MXITER"))
538                 {
539                   if (lex_force_match (lexer, T_LPAREN))
540                     {
541                       lex_force_int (lexer);
542                       maxiter = lex_integer (lexer);
543                       lex_get (lexer);
544                       lex_force_match (lexer, T_RPAREN);
545                     }
546                 }
547               else
548                 return CMD_FAILURE;
549             }
550         }
551     }
552
553   cs = proc_open (ds);
554
555   kmeans = kmeans_create (cs, variables, p, groups, maxiter);
556
557   kmeans->wv = dict_get_weight (dict);
558   kmeans_cluster (kmeans);
559   quick_cluster_show_results (kmeans);
560   ok = proc_commit (ds);
561
562   kmeans_destroy (kmeans);
563
564   return (ok);
565 }