QUICK CLUSTER: Implement pairwise missing option
[pspp] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 enum missing_type
50   {
51     MISS_LISTWISE,
52     MISS_PAIRWISE,
53   };
54
55
56 struct qc
57 {
58   const struct variable **vars;
59   size_t n_vars;
60
61   int ngroups;                  /* Number of group. (Given by the user) */
62   int maxiter;                  /* Maximum iterations (Given by the user) */
63
64   const struct variable *wv;    /* Weighting variable. */
65
66   enum missing_type missing_type;
67   enum mv_class exclude;
68 };
69
70 /* Holds all of the information for the functions.  int n, holds the number of
71    observation and its default value is -1.  We set it in
72    kmeans_recalculate_centers in first invocation. */
73 struct Kmeans
74 {
75   gsl_matrix *centers;          /* Centers for groups. */
76   gsl_vector_long *num_elements_groups;
77
78   casenumber n;                 /* Number of observations (default -1). */
79
80   int lastiter;                 /* Iteration where it found the solution. */
81   int trials;                   /* If not convergence, how many times has
82                                    clustering done. */
83   gsl_matrix *initial_centers;  /* Initial random centers. */
84
85   gsl_permutation *group_order; /* Group order for reporting. */
86   struct caseproto *proto;
87   struct casereader *index_rdr; /* Group ids for each case. */
88 };
89
90 static struct Kmeans *kmeans_create (const struct qc *qc);
91
92 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
93
94 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
95
96 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
97
98 static int
99 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
100
101 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
102
103 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
104
105 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
106
107 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
108
109 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
110
111 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
112
113 static void kmeans_destroy (struct Kmeans *kmeans);
114
115 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
116    variables 'variables', number of cases 'n', number of variables 'm', number
117    of clusters and amount of maximum iterations. */
118 static struct Kmeans *
119 kmeans_create (const struct qc *qc)
120 {
121   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
122   kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
123   kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
124   kmeans->n = 0;
125   kmeans->lastiter = 0;
126   kmeans->trials = 0;
127   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
128   kmeans->initial_centers = NULL;
129
130   kmeans->proto = caseproto_create ();
131   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
132   kmeans->index_rdr = NULL;
133   return (kmeans);
134 }
135
136 static void
137 kmeans_destroy (struct Kmeans *kmeans)
138 {
139   gsl_matrix_free (kmeans->centers);
140   gsl_matrix_free (kmeans->initial_centers);
141
142   gsl_vector_long_free (kmeans->num_elements_groups);
143
144   gsl_permutation_free (kmeans->group_order);
145
146   caseproto_unref (kmeans->proto);
147
148   casereader_destroy (kmeans->index_rdr);
149
150   free (kmeans);
151 }
152
153 /* Creates random centers using randomly selected cases from the data. */
154 static void
155 kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
156 {
157   int i, j;
158   for (i = 0; i < qc->ngroups; i++)
159     {
160       for (j = 0; j < qc->n_vars; j++)
161         {
162           if (i == j)
163             {
164               gsl_matrix_set (kmeans->centers, i, j, 1);
165             }
166           else
167             {
168               gsl_matrix_set (kmeans->centers, i, j, 0);
169             }
170         }
171     }
172   /* If it is the first iteration, the variable kmeans->initial_centers is NULL
173      and it is created once for reporting issues. In SPSS, initial centers are
174      shown in the reports but in PSPP it is not shown now. I am leaving it
175      here. */
176   if (!kmeans->initial_centers)
177     {
178       kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
179       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
180     }
181 }
182
183 static int
184 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
185 {
186   int result = -1;
187   int i, j;
188   double mindist = INFINITY;
189   for (i = 0; i < qc->ngroups; i++)
190     {
191       double dist = 0;
192       for (j = 0; j < qc->n_vars; j++)
193         {
194           const union value *val = case_data (c, qc->vars[j]);
195           if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
196             continue;
197
198           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
199         }
200       if (dist < mindist)
201         {
202           mindist = dist;
203           result = i;
204         }
205     }
206   return (result);
207 }
208
209 /* Re-calculate the cluster centers. */
210 static void
211 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
212 {
213   casenumber i = 0;
214   int v, j;
215   struct ccase *c;
216
217   struct casereader *cs = casereader_clone (reader);
218   struct casereader *cs_index = casereader_clone (kmeans->index_rdr);
219
220   gsl_matrix_set_all (kmeans->centers, 0.0);
221   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
222     {
223       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
224       struct ccase *c_index = casereader_read (cs_index);
225       int index = case_data_idx (c_index, 0)->f;
226       for (v = 0; v < qc->n_vars; ++v)
227         {
228           const union value *val = case_data (c, qc->vars[v]);
229           double x = val->f * weight;
230
231           if ( var_is_value_missing (qc->vars[v], val, qc->exclude))
232             continue;
233
234           double curval = gsl_matrix_get (kmeans->centers, index, v);
235           gsl_matrix_set (kmeans->centers, index, v, curval + x);
236         }
237       i++;
238       case_unref (c_index);
239     }
240   casereader_destroy (cs);
241   casereader_destroy (cs_index);
242
243   /* Getting number of cases */
244   if (kmeans->n == 0)
245     kmeans->n = i;
246
247   /* We got sum of each center but we need averages.
248      We are dividing centers to numobs. This may be inefficient and
249      we should check it again. */
250   for (i = 0; i < qc->ngroups; i++)
251     {
252       casenumber numobs = kmeans->num_elements_groups->data[i];
253       for (j = 0; j < qc->n_vars; j++)
254         {
255           if (numobs > 0)
256             {
257               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
258               *x /= numobs;
259             }
260           else
261             {
262               gsl_matrix_set (kmeans->centers, i, j, 0);
263             }
264         }
265     }
266 }
267
268 /* The variable index in struct Kmeans holds integer values that represents the
269    current groups of cases.  index[n]=a shows the nth case is belong to ath
270    cluster.  This function calculates these indexes and returns the number of
271    different cases of the new and old index variables.  If last two index
272    variables are equal, there is no any enhancement of clustering. */
273 static int
274 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
275 {
276   int totaldiff = 0;
277   struct ccase *c;
278   struct casereader *cs = casereader_clone (reader);
279
280   /* A casewriter into which we will write the indexes. */
281   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
282
283   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
284
285   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
286     {
287       /* A case to hold the new index. */
288       struct ccase *index_case_new = case_create (kmeans->proto);
289       int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
290       double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
291       assert (bestindex < kmeans->num_elements_groups->size);
292       kmeans->num_elements_groups->data[bestindex] += weight;
293       if (kmeans->index_rdr)
294         {
295           /* A case from which the old index will be read. */
296           struct ccase *index_case_old = NULL;
297
298           /* Read the case from the index casereader. */
299           index_case_old = casereader_read (kmeans->index_rdr);
300
301           /* Set totaldiff, using the old_index. */
302           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
303
304           /* We have no use for the old case anymore, so unref it. */
305           case_unref (index_case_old);
306         }
307       else
308         {
309           /* If this is the first run, then assume index is zero. */
310           totaldiff += bestindex;
311         }
312
313       /* Set the value of the new inde.x */
314       case_data_rw_idx (index_case_new, 0)->f = bestindex;
315
316       /* and write the new index to the casewriter */
317       casewriter_write (index_wtr, index_case_new);
318     }
319   casereader_destroy (cs);
320   /* We have now read through the entire index_rdr, so it's of no use
321      anymore. */
322   casereader_destroy (kmeans->index_rdr);
323
324   /* Convert the writer into a reader, ready for the next iteration to read */
325   kmeans->index_rdr = casewriter_make_reader (index_wtr);
326
327   return (totaldiff);
328 }
329
330 static void
331 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
332 {
333   gsl_vector *v = gsl_vector_alloc (qc->ngroups);
334   gsl_matrix_get_col (v, kmeans->centers, 0);
335   gsl_sort_vector_index (kmeans->group_order, v);
336   gsl_vector_free (v);
337 }
338
339 /* Main algorithm.
340    Does iterations, checks convergency. */
341 static void
342 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
343 {
344   int i;
345   bool redo;
346   int diffs;
347   bool show_warning1;
348
349   show_warning1 = true;
350 cluster:
351   redo = false;
352   kmeans_randomize_centers (kmeans, qc);
353   for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
354        kmeans->lastiter++)
355     {
356       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
357       kmeans_recalculate_centers (kmeans, reader, qc);
358       if (show_warning1 && qc->ngroups > kmeans->n)
359         {
360           msg (MW, _("Number of clusters may not be larger than the number "
361                      "of cases."));
362           show_warning1 = false;
363         }
364       if (diffs == 0)
365         break;
366     }
367
368   for (i = 0; i < qc->ngroups; i++)
369     {
370       if (kmeans->num_elements_groups->data[i] == 0)
371         {
372           kmeans->trials++;
373           if (kmeans->trials >= 3)
374             break;
375           redo = true;
376           break;
377         }
378     }
379   if (redo)
380     goto cluster;
381
382 }
383
384 /* Reports centers of clusters.
385    Initial parameter is optional for future use.
386    If initial is true, initial cluster centers are reported.  Otherwise,
387    resulted centers are reported. */
388 static void
389 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
390 {
391   struct tab_table *t;
392   int nc, nr, heading_columns, currow;
393   int i, j;
394   nc = qc->ngroups + 1;
395   nr = qc->n_vars + 4;
396   heading_columns = 1;
397   t = tab_create (nc, nr);
398   tab_headers (t, 0, nc - 1, 0, 1);
399   currow = 0;
400   if (!initial)
401     {
402       tab_title (t, _("Final Cluster Centers"));
403     }
404   else
405     {
406       tab_title (t, _("Initial Cluster Centers"));
407     }
408   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
409   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
410   tab_hline (t, TAL_1, 1, nc - 1, 2);
411   currow += 2;
412
413   for (i = 0; i < qc->ngroups; i++)
414     {
415       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
416     }
417   currow++;
418   tab_hline (t, TAL_1, 1, nc - 1, currow);
419   currow++;
420   for (i = 0; i < qc->n_vars; i++)
421     {
422       tab_text (t, 0, currow + i, TAB_LEFT,
423                 var_to_string (qc->vars[i]));
424     }
425
426   for (i = 0; i < qc->ngroups; i++)
427     {
428       for (j = 0; j < qc->n_vars; j++)
429         {
430           if (!initial)
431             {
432               tab_double (t, i + 1, j + 4, TAB_CENTER,
433                           gsl_matrix_get (kmeans->centers,
434                                           kmeans->group_order->data[i], j),
435                           var_get_print_format (qc->vars[j]));
436             }
437           else
438             {
439               tab_double (t, i + 1, j + 4, TAB_CENTER,
440                           gsl_matrix_get (kmeans->initial_centers,
441                                           kmeans->group_order->data[i], j),
442                           var_get_print_format (qc->vars[j]));
443             }
444         }
445     }
446   tab_submit (t);
447 }
448
449 /* Reports number of cases of each single cluster. */
450 static void
451 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
452 {
453   struct tab_table *t;
454   int nc, nr;
455   int i, numelem;
456   long int total;
457   nc = 3;
458   nr = qc->ngroups + 1;
459   t = tab_create (nc, nr);
460   tab_headers (t, 0, nc - 1, 0, 0);
461   tab_title (t, _("Number of Cases in each Cluster"));
462   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
463   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
464
465   total = 0;
466   for (i = 0; i < qc->ngroups; i++)
467     {
468       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
469       numelem =
470         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
471       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
472       total += numelem;
473     }
474
475   tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
476   tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
477   tab_submit (t);
478 }
479
480 /* Reports. */
481 static void
482 quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
483 {
484   kmeans_order_groups (kmeans, qc);
485   /* Uncomment the line below for reporting initial centers. */
486   /* quick_cluster_show_centers (kmeans, true); */
487   quick_cluster_show_centers (kmeans, false, qc);
488   quick_cluster_show_number_cases (kmeans, qc);
489 }
490
491 int
492 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
493 {
494   struct qc qc;
495   struct Kmeans *kmeans;
496   bool ok;
497   const struct dictionary *dict = dataset_dict (ds);
498   qc.ngroups = 2;
499   qc.maxiter = 2;
500   qc.missing_type = MISS_LISTWISE;
501   qc.exclude = MV_ANY;
502
503   if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
504                               PV_NO_DUPLICATE | PV_NUMERIC))
505     {
506       return (CMD_FAILURE);
507     }
508
509   while (lex_token (lexer) != T_ENDCMD)
510     {
511       lex_match (lexer, T_SLASH);
512
513       if (lex_match_id (lexer, "MISSING"))
514         {
515           lex_match (lexer, T_EQUALS);
516           while (lex_token (lexer) != T_ENDCMD
517                  && lex_token (lexer) != T_SLASH)
518             {
519               if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
520                 {
521                   qc.missing_type = MISS_LISTWISE;
522                 }
523               else if (lex_match_id (lexer, "PAIRWISE"))
524                 {
525                   qc.missing_type = MISS_PAIRWISE;
526                 }
527               else if (lex_match_id (lexer, "INCLUDE"))
528                 {
529                   qc.exclude = MV_SYSTEM;
530                 }
531               else
532                 goto error;
533             }     
534         }
535       else if (lex_match_id (lexer, "CRITERIA"))
536         {
537           lex_match (lexer, T_EQUALS);
538           while (lex_token (lexer) != T_ENDCMD
539                  && lex_token (lexer) != T_SLASH)
540             {
541               if (lex_match_id (lexer, "CLUSTERS"))
542                 {
543                   if (lex_force_match (lexer, T_LPAREN))
544                     {
545                       lex_force_int (lexer);
546                       qc.ngroups = lex_integer (lexer);
547                       lex_get (lexer);
548                       lex_force_match (lexer, T_RPAREN);
549                     }
550                 }
551               else if (lex_match_id (lexer, "MXITER"))
552                 {
553                   if (lex_force_match (lexer, T_LPAREN))
554                     {
555                       lex_force_int (lexer);
556                       qc.maxiter = lex_integer (lexer);
557                       lex_get (lexer);
558                       lex_force_match (lexer, T_RPAREN);
559                     }
560                 }
561               else
562                 goto error;
563             }
564         }
565     }
566
567   qc.wv = dict_get_weight (dict);
568
569   {
570     struct casereader *group;
571     struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
572
573     while (casegrouper_get_next_group (grouper, &group))
574       {
575         if ( qc.missing_type == MISS_LISTWISE )
576           {
577             group  = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
578                                                      qc.exclude,
579                                                      NULL,  NULL);
580           }
581
582         kmeans = kmeans_create (&qc);
583         kmeans_cluster (kmeans, group, &qc);
584         quick_cluster_show_results (kmeans, &qc);
585         kmeans_destroy (kmeans);
586         casereader_destroy (group);
587       }
588     ok = casegrouper_destroy (grouper);
589   }
590   ok = proc_commit (ds) && ok;
591
592   free (qc.vars);
593
594   return (ok);
595
596  error:
597   free (qc.vars);
598   return CMD_FAILURE;
599 }