Merge remote branch 'origin/sourceview'
[pspp] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 enum missing_type
50   {
51     MISS_LISTWISE,
52     MISS_PAIRWISE,
53   };
54
55
56 struct qc
57 {
58   const struct variable **vars;
59   size_t n_vars;
60
61   int ngroups;                  /* Number of group. (Given by the user) */
62   int maxiter;                  /* Maximum iterations (Given by the user) */
63
64   const struct variable *wv;    /* Weighting variable. */
65
66   enum missing_type missing_type;
67   enum mv_class exclude;
68 };
69
70 /* Holds all of the information for the functions.  int n, holds the number of
71    observation and its default value is -1.  We set it in
72    kmeans_recalculate_centers in first invocation. */
73 struct Kmeans
74 {
75   gsl_matrix *centers;          /* Centers for groups. */
76   gsl_vector_long *num_elements_groups;
77
78   casenumber n;                 /* Number of observations (default -1). */
79
80   int lastiter;                 /* Iteration where it found the solution. */
81   int trials;                   /* If not convergence, how many times has
82                                    clustering done. */
83   gsl_matrix *initial_centers;  /* Initial random centers. */
84
85   gsl_permutation *group_order; /* Group order for reporting. */
86   struct caseproto *proto;
87   struct casereader *index_rdr; /* Group ids for each case. */
88 };
89
90 static struct Kmeans *kmeans_create (const struct qc *qc);
91
92 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
93
94 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
95
96 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
97
98 static int
99 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
100
101 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
102
103 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
104
105 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
106
107 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
108
109 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
110
111 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
112
113 static void kmeans_destroy (struct Kmeans *kmeans);
114
115 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
116    variables 'variables', number of cases 'n', number of variables 'm', number
117    of clusters and amount of maximum iterations. */
118 static struct Kmeans *
119 kmeans_create (const struct qc *qc)
120 {
121   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
122   kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
123   kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
124   kmeans->n = 0;
125   kmeans->lastiter = 0;
126   kmeans->trials = 0;
127   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
128   kmeans->initial_centers = NULL;
129
130   kmeans->proto = caseproto_create ();
131   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
132   kmeans->index_rdr = NULL;
133   return (kmeans);
134 }
135
136 static void
137 kmeans_destroy (struct Kmeans *kmeans)
138 {
139   gsl_matrix_free (kmeans->centers);
140   gsl_matrix_free (kmeans->initial_centers);
141
142   gsl_vector_long_free (kmeans->num_elements_groups);
143
144   gsl_permutation_free (kmeans->group_order);
145
146   caseproto_unref (kmeans->proto);
147
148   casereader_destroy (kmeans->index_rdr);
149
150   free (kmeans);
151 }
152
153 /* Creates random centers using randomly selected cases from the data. */
154 static void
155 kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
156 {
157   int i, j;
158   for (i = 0; i < qc->ngroups; i++)
159     {
160       for (j = 0; j < qc->n_vars; j++)
161         {
162           if (i == j)
163             {
164               gsl_matrix_set (kmeans->centers, i, j, 1);
165             }
166           else
167             {
168               gsl_matrix_set (kmeans->centers, i, j, 0);
169             }
170         }
171     }
172   /* If it is the first iteration, the variable kmeans->initial_centers is NULL
173      and it is created once for reporting issues. In SPSS, initial centers are
174      shown in the reports but in PSPP it is not shown now. I am leaving it
175      here. */
176   if (!kmeans->initial_centers)
177     {
178       kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
179       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
180     }
181 }
182
183 static int
184 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
185 {
186   int result = -1;
187   int i, j;
188   double mindist = INFINITY;
189   for (i = 0; i < qc->ngroups; i++)
190     {
191       double dist = 0;
192       for (j = 0; j < qc->n_vars; j++)
193         {
194           const union value *val = case_data (c, qc->vars[j]);
195           if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
196             continue;
197
198           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
199         }
200       if (dist < mindist)
201         {
202           mindist = dist;
203           result = i;
204         }
205     }
206   return (result);
207 }
208
209 /* Re-calculate the cluster centers. */
210 static void
211 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
212 {
213   casenumber i = 0;
214   int v, j;
215   struct ccase *c;
216
217   struct casereader *cs = casereader_clone (reader);
218   struct casereader *cs_index = casereader_clone (kmeans->index_rdr);
219
220   gsl_matrix_set_all (kmeans->centers, 0.0);
221   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
222     {
223       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
224       struct ccase *c_index = casereader_read (cs_index);
225       int index = case_data_idx (c_index, 0)->f;
226       for (v = 0; v < qc->n_vars; ++v)
227         {
228           const union value *val = case_data (c, qc->vars[v]);
229           double x = val->f * weight;
230           double curval;
231
232           if ( var_is_value_missing (qc->vars[v], val, qc->exclude))
233             continue;
234
235           curval = gsl_matrix_get (kmeans->centers, index, v);
236           gsl_matrix_set (kmeans->centers, index, v, curval + x);
237         }
238       i++;
239       case_unref (c_index);
240     }
241   casereader_destroy (cs);
242   casereader_destroy (cs_index);
243
244   /* Getting number of cases */
245   if (kmeans->n == 0)
246     kmeans->n = i;
247
248   /* We got sum of each center but we need averages.
249      We are dividing centers to numobs. This may be inefficient and
250      we should check it again. */
251   for (i = 0; i < qc->ngroups; i++)
252     {
253       casenumber numobs = kmeans->num_elements_groups->data[i];
254       for (j = 0; j < qc->n_vars; j++)
255         {
256           if (numobs > 0)
257             {
258               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
259               *x /= numobs;
260             }
261           else
262             {
263               gsl_matrix_set (kmeans->centers, i, j, 0);
264             }
265         }
266     }
267 }
268
269 /* The variable index in struct Kmeans holds integer values that represents the
270    current groups of cases.  index[n]=a shows the nth case is belong to ath
271    cluster.  This function calculates these indexes and returns the number of
272    different cases of the new and old index variables.  If last two index
273    variables are equal, there is no any enhancement of clustering. */
274 static int
275 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
276 {
277   int totaldiff = 0;
278   struct ccase *c;
279   struct casereader *cs = casereader_clone (reader);
280
281   /* A casewriter into which we will write the indexes. */
282   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
283
284   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
285
286   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
287     {
288       /* A case to hold the new index. */
289       struct ccase *index_case_new = case_create (kmeans->proto);
290       int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
291       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
292       assert (bestindex < kmeans->num_elements_groups->size);
293       kmeans->num_elements_groups->data[bestindex] += weight;
294       if (kmeans->index_rdr)
295         {
296           /* A case from which the old index will be read. */
297           struct ccase *index_case_old = NULL;
298
299           /* Read the case from the index casereader. */
300           index_case_old = casereader_read (kmeans->index_rdr);
301
302           /* Set totaldiff, using the old_index. */
303           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
304
305           /* We have no use for the old case anymore, so unref it. */
306           case_unref (index_case_old);
307         }
308       else
309         {
310           /* If this is the first run, then assume index is zero. */
311           totaldiff += bestindex;
312         }
313
314       /* Set the value of the new inde.x */
315       case_data_rw_idx (index_case_new, 0)->f = bestindex;
316
317       /* and write the new index to the casewriter */
318       casewriter_write (index_wtr, index_case_new);
319     }
320   casereader_destroy (cs);
321   /* We have now read through the entire index_rdr, so it's of no use
322      anymore. */
323   casereader_destroy (kmeans->index_rdr);
324
325   /* Convert the writer into a reader, ready for the next iteration to read */
326   kmeans->index_rdr = casewriter_make_reader (index_wtr);
327
328   return (totaldiff);
329 }
330
331 static void
332 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
333 {
334   gsl_vector *v = gsl_vector_alloc (qc->ngroups);
335   gsl_matrix_get_col (v, kmeans->centers, 0);
336   gsl_sort_vector_index (kmeans->group_order, v);
337   gsl_vector_free (v);
338 }
339
340 /* Main algorithm.
341    Does iterations, checks convergency. */
342 static void
343 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
344 {
345   int i;
346   bool redo;
347   int diffs;
348   bool show_warning1;
349
350   show_warning1 = true;
351 cluster:
352   redo = false;
353   kmeans_randomize_centers (kmeans, qc);
354   for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
355        kmeans->lastiter++)
356     {
357       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
358       kmeans_recalculate_centers (kmeans, reader, qc);
359       if (show_warning1 && qc->ngroups > kmeans->n)
360         {
361           msg (MW, _("Number of clusters may not be larger than the number "
362                      "of cases."));
363           show_warning1 = false;
364         }
365       if (diffs == 0)
366         break;
367     }
368
369   for (i = 0; i < qc->ngroups; i++)
370     {
371       if (kmeans->num_elements_groups->data[i] == 0)
372         {
373           kmeans->trials++;
374           if (kmeans->trials >= 3)
375             break;
376           redo = true;
377           break;
378         }
379     }
380   if (redo)
381     goto cluster;
382
383 }
384
385 /* Reports centers of clusters.
386    Initial parameter is optional for future use.
387    If initial is true, initial cluster centers are reported.  Otherwise,
388    resulted centers are reported. */
389 static void
390 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
391 {
392   struct tab_table *t;
393   int nc, nr, heading_columns, currow;
394   int i, j;
395   nc = qc->ngroups + 1;
396   nr = qc->n_vars + 4;
397   heading_columns = 1;
398   t = tab_create (nc, nr);
399   tab_headers (t, 0, nc - 1, 0, 1);
400   currow = 0;
401   if (!initial)
402     {
403       tab_title (t, _("Final Cluster Centers"));
404     }
405   else
406     {
407       tab_title (t, _("Initial Cluster Centers"));
408     }
409   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
410   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
411   tab_hline (t, TAL_1, 1, nc - 1, 2);
412   currow += 2;
413
414   for (i = 0; i < qc->ngroups; i++)
415     {
416       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
417     }
418   currow++;
419   tab_hline (t, TAL_1, 1, nc - 1, currow);
420   currow++;
421   for (i = 0; i < qc->n_vars; i++)
422     {
423       tab_text (t, 0, currow + i, TAB_LEFT,
424                 var_to_string (qc->vars[i]));
425     }
426
427   for (i = 0; i < qc->ngroups; i++)
428     {
429       for (j = 0; j < qc->n_vars; j++)
430         {
431           if (!initial)
432             {
433               tab_double (t, i + 1, j + 4, TAB_CENTER,
434                           gsl_matrix_get (kmeans->centers,
435                                           kmeans->group_order->data[i], j),
436                           var_get_print_format (qc->vars[j]));
437             }
438           else
439             {
440               tab_double (t, i + 1, j + 4, TAB_CENTER,
441                           gsl_matrix_get (kmeans->initial_centers,
442                                           kmeans->group_order->data[i], j),
443                           var_get_print_format (qc->vars[j]));
444             }
445         }
446     }
447   tab_submit (t);
448 }
449
450 /* Reports number of cases of each single cluster. */
451 static void
452 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
453 {
454   struct tab_table *t;
455   int nc, nr;
456   int i, numelem;
457   long int total;
458   nc = 3;
459   nr = qc->ngroups + 1;
460   t = tab_create (nc, nr);
461   tab_headers (t, 0, nc - 1, 0, 0);
462   tab_title (t, _("Number of Cases in each Cluster"));
463   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
464   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
465
466   total = 0;
467   for (i = 0; i < qc->ngroups; i++)
468     {
469       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
470       numelem =
471         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
472       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
473       total += numelem;
474     }
475
476   tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
477   tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
478   tab_submit (t);
479 }
480
481 /* Reports. */
482 static void
483 quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
484 {
485   kmeans_order_groups (kmeans, qc);
486   /* Uncomment the line below for reporting initial centers. */
487   /* quick_cluster_show_centers (kmeans, true); */
488   quick_cluster_show_centers (kmeans, false, qc);
489   quick_cluster_show_number_cases (kmeans, qc);
490 }
491
492 int
493 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
494 {
495   struct qc qc;
496   struct Kmeans *kmeans;
497   bool ok;
498   const struct dictionary *dict = dataset_dict (ds);
499   qc.ngroups = 2;
500   qc.maxiter = 2;
501   qc.missing_type = MISS_LISTWISE;
502   qc.exclude = MV_ANY;
503
504   if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
505                               PV_NO_DUPLICATE | PV_NUMERIC))
506     {
507       return (CMD_FAILURE);
508     }
509
510   while (lex_token (lexer) != T_ENDCMD)
511     {
512       lex_match (lexer, T_SLASH);
513
514       if (lex_match_id (lexer, "MISSING"))
515         {
516           lex_match (lexer, T_EQUALS);
517           while (lex_token (lexer) != T_ENDCMD
518                  && lex_token (lexer) != T_SLASH)
519             {
520               if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
521                 {
522                   qc.missing_type = MISS_LISTWISE;
523                 }
524               else if (lex_match_id (lexer, "PAIRWISE"))
525                 {
526                   qc.missing_type = MISS_PAIRWISE;
527                 }
528               else if (lex_match_id (lexer, "INCLUDE"))
529                 {
530                   qc.exclude = MV_SYSTEM;
531                 }
532               else if (lex_match_id (lexer, "EXCLUDE"))
533                 {
534                   qc.exclude = MV_ANY;
535                 }
536               else
537                 goto error;
538             }     
539         }
540       else if (lex_match_id (lexer, "CRITERIA"))
541         {
542           lex_match (lexer, T_EQUALS);
543           while (lex_token (lexer) != T_ENDCMD
544                  && lex_token (lexer) != T_SLASH)
545             {
546               if (lex_match_id (lexer, "CLUSTERS"))
547                 {
548                   if (lex_force_match (lexer, T_LPAREN))
549                     {
550                       lex_force_int (lexer);
551                       qc.ngroups = lex_integer (lexer);
552                       lex_get (lexer);
553                       lex_force_match (lexer, T_RPAREN);
554                     }
555                 }
556               else if (lex_match_id (lexer, "MXITER"))
557                 {
558                   if (lex_force_match (lexer, T_LPAREN))
559                     {
560                       lex_force_int (lexer);
561                       qc.maxiter = lex_integer (lexer);
562                       lex_get (lexer);
563                       lex_force_match (lexer, T_RPAREN);
564                     }
565                 }
566               else
567                 goto error;
568             }
569         }
570     }
571
572   qc.wv = dict_get_weight (dict);
573
574   {
575     struct casereader *group;
576     struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
577
578     while (casegrouper_get_next_group (grouper, &group))
579       {
580         if ( qc.missing_type == MISS_LISTWISE )
581           {
582             group  = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
583                                                      qc.exclude,
584                                                      NULL,  NULL);
585           }
586
587         kmeans = kmeans_create (&qc);
588         kmeans_cluster (kmeans, group, &qc);
589         quick_cluster_show_results (kmeans, &qc);
590         kmeans_destroy (kmeans);
591         casereader_destroy (group);
592       }
593     ok = casegrouper_destroy (grouper);
594   }
595   ok = proc_commit (ds) && ok;
596
597   free (qc.vars);
598
599   return (ok);
600
601  error:
602   free (qc.vars);
603   return CMD_FAILURE;
604 }