QUICK-CLUSTER: Seperate const from non-const data and make it handle splits
[pspp] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 struct qc
50 {
51   const struct variable **vars;
52   size_t n_vars;
53
54   int ngroups;                  /* Number of group. (Given by the user) */
55   int maxiter;                  /* Maximum iterations (Given by the user) */
56
57   const struct variable *wv;    /* Weighting variable. */
58 };
59
60 /* Holds all of the information for the functions.  int n, holds the number of
61    observation and its default value is -1.  We set it in
62    kmeans_recalculate_centers in first invocation. */
63 struct Kmeans
64 {
65   gsl_matrix *centers;          /* Centers for groups. */
66   gsl_vector_long *num_elements_groups;
67
68   casenumber n;                 /* Number of observations (default -1). */
69
70   int lastiter;                 /* Iteration where it found the solution. */
71   int trials;                   /* If not convergence, how many times has
72                                    clustering done. */
73   gsl_matrix *initial_centers;  /* Initial random centers. */
74
75   gsl_permutation *group_order; /* Group order for reporting. */
76   struct caseproto *proto;
77   struct casereader *index_rdr; /* Group ids for each case. */
78 };
79
80 static struct Kmeans *kmeans_create (const struct qc *qc);
81
82 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
83
84 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
85
86 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
87
88 static int
89 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
90
91 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
92
93 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
94
95 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
96
97 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
98
99 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
100
101 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
102
103 static void kmeans_destroy (struct Kmeans *kmeans);
104
105 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
106    variables 'variables', number of cases 'n', number of variables 'm', number
107    of clusters and amount of maximum iterations. */
108 static struct Kmeans *
109 kmeans_create (const struct qc *qc)
110 {
111   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
112   kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
113   kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
114   kmeans->n = 0;
115   kmeans->lastiter = 0;
116   kmeans->trials = 0;
117   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
118   kmeans->initial_centers = NULL;
119
120   kmeans->proto = caseproto_create ();
121   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
122   kmeans->index_rdr = NULL;
123   return (kmeans);
124 }
125
126 static void
127 kmeans_destroy (struct Kmeans *kmeans)
128 {
129   gsl_matrix_free (kmeans->centers);
130   gsl_matrix_free (kmeans->initial_centers);
131
132   gsl_vector_long_free (kmeans->num_elements_groups);
133
134   gsl_permutation_free (kmeans->group_order);
135
136   caseproto_unref (kmeans->proto);
137
138   casereader_destroy (kmeans->index_rdr);
139
140   free (kmeans);
141 }
142
143 /* Creates random centers using randomly selected cases from the data. */
144 static void
145 kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
146 {
147   int i, j;
148   for (i = 0; i < qc->ngroups; i++)
149     {
150       for (j = 0; j < qc->n_vars; j++)
151         {
152           if (i == j)
153             {
154               gsl_matrix_set (kmeans->centers, i, j, 1);
155             }
156           else
157             {
158               gsl_matrix_set (kmeans->centers, i, j, 0);
159             }
160         }
161     }
162   /* If it is the first iteration, the variable kmeans->initial_centers is NULL
163      and it is created once for reporting issues. In SPSS, initial centers are
164      shown in the reports but in PSPP it is not shown now. I am leaving it
165      here. */
166   if (!kmeans->initial_centers)
167     {
168       kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
169       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
170     }
171 }
172
173 static int
174 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
175 {
176   int result = -1;
177   double x;
178   int i, j;
179   double dist;
180   double mindist;
181   mindist = INFINITY;
182   for (i = 0; i < qc->ngroups; i++)
183     {
184       dist = 0;
185       for (j = 0; j < qc->n_vars; j++)
186         {
187           x = case_data (c, qc->vars[j])->f;
188           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
189         }
190       if (dist < mindist)
191         {
192           mindist = dist;
193           result = i;
194         }
195     }
196   return (result);
197 }
198
199 /* Re-calculate the cluster centers. */
200 static void
201 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
202 {
203   casenumber i;
204   int v, j;
205   double x, curval;
206   struct ccase *c;
207   struct ccase *c_index;
208   struct casereader *cs;
209   struct casereader *cs_index;
210   int index;
211
212   i = 0;
213   cs = casereader_clone (reader);
214   cs_index = casereader_clone (kmeans->index_rdr);
215
216   gsl_matrix_set_all (kmeans->centers, 0.0);
217   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
218     {
219       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
220       c_index = casereader_read (cs_index);
221       index = case_data_idx (c_index, 0)->f;
222       for (v = 0; v < qc->n_vars; ++v)
223         {
224           x = case_data (c, qc->vars[v])->f * weight;
225           curval = gsl_matrix_get (kmeans->centers, index, v);
226           gsl_matrix_set (kmeans->centers, index, v, curval + x);
227         }
228       i++;
229       case_unref (c_index);
230     }
231   casereader_destroy (cs);
232   casereader_destroy (cs_index);
233
234   /* Getting number of cases */
235   if (kmeans->n == 0)
236     kmeans->n = i;
237
238   /* We got sum of each center but we need averages.
239      We are dividing centers to numobs. This may be inefficient and
240      we should check it again. */
241   for (i = 0; i < qc->ngroups; i++)
242     {
243       casenumber numobs = kmeans->num_elements_groups->data[i];
244       for (j = 0; j < qc->n_vars; j++)
245         {
246           if (numobs > 0)
247             {
248               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
249               *x /= numobs;
250             }
251           else
252             {
253               gsl_matrix_set (kmeans->centers, i, j, 0);
254             }
255         }
256     }
257 }
258
259 /* The variable index in struct Kmeans holds integer values that represents the
260    current groups of cases.  index[n]=a shows the nth case is belong to ath
261    cluster.  This function calculates these indexes and returns the number of
262    different cases of the new and old index variables.  If last two index
263    variables are equal, there is no any enhancement of clustering. */
264 static int
265 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
266 {
267   int totaldiff = 0;
268   struct ccase *c;
269   struct casereader *cs = casereader_clone (reader);
270
271   /* A casewriter into which we will write the indexes. */
272   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
273
274   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
275
276   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
277     {
278       /* A case to hold the new index. */
279       struct ccase *index_case_new = case_create (kmeans->proto);
280       int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
281       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
282       kmeans->num_elements_groups->data[bestindex] += weight;
283       if (kmeans->index_rdr)
284         {
285           /* A case from which the old index will be read. */
286           struct ccase *index_case_old = NULL;
287
288           /* Read the case from the index casereader. */
289           index_case_old = casereader_read (kmeans->index_rdr);
290
291           /* Set totaldiff, using the old_index. */
292           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
293
294           /* We have no use for the old case anymore, so unref it. */
295           case_unref (index_case_old);
296         }
297       else
298         {
299           /* If this is the first run, then assume index is zero. */
300           totaldiff += bestindex;
301         }
302
303       /* Set the value of the new inde.x */
304       case_data_rw_idx (index_case_new, 0)->f = bestindex;
305
306       /* and write the new index to the casewriter */
307       casewriter_write (index_wtr, index_case_new);
308     }
309   casereader_destroy (cs);
310   /* We have now read through the entire index_rdr, so it's of no use
311      anymore. */
312   casereader_destroy (kmeans->index_rdr);
313
314   /* Convert the writer into a reader, ready for the next iteration to read */
315   kmeans->index_rdr = casewriter_make_reader (index_wtr);
316
317   return (totaldiff);
318 }
319
320 static void
321 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
322 {
323   gsl_vector *v = gsl_vector_alloc (qc->ngroups);
324   gsl_matrix_get_col (v, kmeans->centers, 0);
325   gsl_sort_vector_index (kmeans->group_order, v);
326   gsl_vector_free (v);
327 }
328
329 /* Main algorithm.
330    Does iterations, checks convergency. */
331 static void
332 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
333 {
334   int i;
335   bool redo;
336   int diffs;
337   bool show_warning1;
338
339   show_warning1 = true;
340 cluster:
341   redo = false;
342   kmeans_randomize_centers (kmeans, qc);
343   for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
344        kmeans->lastiter++)
345     {
346       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
347       kmeans_recalculate_centers (kmeans, reader, qc);
348       if (show_warning1 && qc->ngroups > kmeans->n)
349         {
350           msg (MW, _("Number of clusters may not be larger than the number "
351                      "of cases."));
352           show_warning1 = false;
353         }
354       if (diffs == 0)
355         break;
356     }
357
358   for (i = 0; i < qc->ngroups; i++)
359     {
360       if (kmeans->num_elements_groups->data[i] == 0)
361         {
362           kmeans->trials++;
363           if (kmeans->trials >= 3)
364             break;
365           redo = true;
366           break;
367         }
368     }
369   if (redo)
370     goto cluster;
371
372 }
373
374 /* Reports centers of clusters.
375    Initial parameter is optional for future use.
376    If initial is true, initial cluster centers are reported.  Otherwise,
377    resulted centers are reported. */
378 static void
379 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
380 {
381   struct tab_table *t;
382   int nc, nr, heading_columns, currow;
383   int i, j;
384   nc = qc->ngroups + 1;
385   nr = qc->n_vars + 4;
386   heading_columns = 1;
387   t = tab_create (nc, nr);
388   tab_headers (t, 0, nc - 1, 0, 1);
389   currow = 0;
390   if (!initial)
391     {
392       tab_title (t, _("Final Cluster Centers"));
393     }
394   else
395     {
396       tab_title (t, _("Initial Cluster Centers"));
397     }
398   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
399   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
400   tab_hline (t, TAL_1, 1, nc - 1, 2);
401   currow += 2;
402
403   for (i = 0; i < qc->ngroups; i++)
404     {
405       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
406     }
407   currow++;
408   tab_hline (t, TAL_1, 1, nc - 1, currow);
409   currow++;
410   for (i = 0; i < qc->n_vars; i++)
411     {
412       tab_text (t, 0, currow + i, TAB_LEFT,
413                 var_to_string (qc->vars[i]));
414     }
415
416   for (i = 0; i < qc->ngroups; i++)
417     {
418       for (j = 0; j < qc->n_vars; j++)
419         {
420           if (!initial)
421             {
422               tab_double (t, i + 1, j + 4, TAB_CENTER,
423                           gsl_matrix_get (kmeans->centers,
424                                           kmeans->group_order->data[i], j),
425                           var_get_print_format (qc->vars[j]));
426             }
427           else
428             {
429               tab_double (t, i + 1, j + 4, TAB_CENTER,
430                           gsl_matrix_get (kmeans->initial_centers,
431                                           kmeans->group_order->data[i], j),
432                           var_get_print_format (qc->vars[j]));
433             }
434         }
435     }
436   tab_submit (t);
437 }
438
439 /* Reports number of cases of each single cluster. */
440 static void
441 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
442 {
443   struct tab_table *t;
444   int nc, nr;
445   int i, numelem;
446   long int total;
447   nc = 3;
448   nr = qc->ngroups + 1;
449   t = tab_create (nc, nr);
450   tab_headers (t, 0, nc - 1, 0, 0);
451   tab_title (t, _("Number of Cases in each Cluster"));
452   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
453   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
454
455   total = 0;
456   for (i = 0; i < qc->ngroups; i++)
457     {
458       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
459       numelem =
460         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
461       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
462       total += numelem;
463     }
464
465   tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
466   tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
467   tab_submit (t);
468 }
469
470 /* Reports. */
471 static void
472 quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
473 {
474   kmeans_order_groups (kmeans, qc);
475   /* Uncomment the line below for reporting initial centers. */
476   /* quick_cluster_show_centers (kmeans, true); */
477   quick_cluster_show_centers (kmeans, false, qc);
478   quick_cluster_show_number_cases (kmeans, qc);
479 }
480
481 int
482 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
483 {
484   struct qc qc;
485   struct Kmeans *kmeans;
486   bool ok;
487   const struct dictionary *dict = dataset_dict (ds);
488   qc.ngroups = 2;
489   qc.maxiter = 2;
490
491   if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
492                               PV_NO_DUPLICATE | PV_NUMERIC))
493     {
494       return (CMD_FAILURE);
495     }
496
497   if (lex_match (lexer, T_SLASH))
498     {
499       if (lex_match_id (lexer, "CRITERIA"))
500         {
501           lex_match (lexer, T_EQUALS);
502           while (lex_token (lexer) != T_ENDCMD
503                  && lex_token (lexer) != T_SLASH)
504             {
505               if (lex_match_id (lexer, "CLUSTERS"))
506                 {
507                   if (lex_force_match (lexer, T_LPAREN))
508                     {
509                       lex_force_int (lexer);
510                       qc.ngroups = lex_integer (lexer);
511                       lex_get (lexer);
512                       lex_force_match (lexer, T_RPAREN);
513                     }
514                 }
515               else if (lex_match_id (lexer, "MXITER"))
516                 {
517                   if (lex_force_match (lexer, T_LPAREN))
518                     {
519                       lex_force_int (lexer);
520                       qc.maxiter = lex_integer (lexer);
521                       lex_get (lexer);
522                       lex_force_match (lexer, T_RPAREN);
523                     }
524                 }
525               else
526                 goto error;
527             }
528         }
529     }
530
531   qc.wv = dict_get_weight (dict);
532
533   {
534     struct casereader *group;
535     struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
536
537     while (casegrouper_get_next_group (grouper, &group))
538       {
539         kmeans = kmeans_create (&qc);
540         kmeans_cluster (kmeans, group, &qc);
541         quick_cluster_show_results (kmeans, &qc);
542         kmeans_destroy (kmeans);
543         casereader_destroy (group);
544       }
545     ok = casegrouper_destroy (grouper);
546   }
547   ok = proc_commit (ds) && ok;
548
549   free (qc.vars);
550
551   return (ok);
552
553  error:
554   free (qc.vars);
555   return CMD_FAILURE;
556 }