4dc3a2ff5a923f82c94205bd55c42cd734d3e69a
[pspp-builds.git] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <math.h>
20
21 #include <libpspp/misc.h>
22
23 #include <libpspp/str.h>
24 #include <libpspp/message.h>
25
26
27 #include <data/dataset.h>
28 #include <data/missing-values.h>
29 #include <data/casereader.h>
30 #include <data/casewriter.h>
31 #include <data/casegrouper.h>
32 #include <data/dictionary.h>
33 #include <data/format.h>
34 #include <data/case.h>
35
36 #include <language/lexer/variable-parser.h>
37 #include <language/command.h>
38 #include <language/lexer/lexer.h>
39
40 #include <output/tab.h>
41 #include <output/text-item.h>
42
43 #include <stdio.h>
44 #include <stdlib.h>
45
46 #include <gsl/gsl_matrix.h>
47 #include <gsl/gsl_statistics.h>
48 #include <gsl/gsl_permutation.h>
49 #include <gsl/gsl_sort_vector.h>
50
51 #include <math/random.h>
52
53 #include "gettext.h"
54 #define _(msgid) gettext (msgid)
55 #define N_(msgid) msgid
56
57 #include "quick-cluster.h"
58
59 /*
60 Struct KMeans:
61 Holds all of the information for the functions.
62 int n, holds the number of observation and its default value is -1.
63 We set it in kmeans_recalculate_centers in first invocation.
64 */
65 struct Kmeans
66 {
67   gsl_matrix *centers;          //Centers for groups
68   gsl_vector_long *num_elements_groups;
69   int ngroups;                  //Number of group. (Given by the user)
70   casenumber n;                 //Number of observations. By default it is -1.
71   int m;                        //Number of variables. (Given by the user)
72   int maxiter;                  //Maximum number of iterations (Given by the user)
73   int lastiter;                 //Show at which iteration it found the solution.
74   int trials;                   //If not convergence, how many times has clustering done.
75   gsl_matrix *initial_centers;  //Initial random centers
76   const struct variable **variables;    //Variables
77   gsl_permutation *group_order; //Handles group order for reporting
78   struct casereader *original_casereader;       //Casereader
79   struct caseproto *proto;
80   struct casereader *index_rdr; //We hold the group id's for each case in this structure
81   const struct variable *wv;    //Weighting variable
82 };
83
84
85 /*
86 Creates and returns a struct of Kmeans with given casereader 'cs', parsed variables 'variables',
87 number of cases 'n', number of variables 'm', number of clusters and amount of maximum iterations.
88 */
89 static struct Kmeans *
90 kmeans_create (struct casereader *cs, const struct variable **variables,
91                int m, int ngroups, int maxiter)
92 {
93   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
94   kmeans->centers = gsl_matrix_alloc (ngroups, m);
95   kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
96   kmeans->ngroups = ngroups;
97   kmeans->n = 0;
98   kmeans->m = m;
99   kmeans->maxiter = maxiter;
100   kmeans->lastiter = 0;
101   kmeans->trials = 0;
102   kmeans->variables = variables;
103   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
104   kmeans->original_casereader = cs;
105   kmeans->initial_centers = NULL;
106
107   kmeans->proto = caseproto_create ();
108   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
109   kmeans->index_rdr = NULL;
110   return (kmeans);
111 }
112
113
114 static void
115 kmeans_destroy (struct Kmeans *kmeans)
116 {
117   gsl_matrix_free (kmeans->centers);
118   gsl_matrix_free (kmeans->initial_centers);
119
120   gsl_vector_long_free (kmeans->num_elements_groups);
121
122   gsl_permutation_free (kmeans->group_order);
123
124   caseproto_unref (kmeans->proto);
125
126   /*
127      These reader and writer were already destroyed.
128      free (kmeans->original_casereader);
129      free (kmeans->index_rdr);
130    */
131
132   free (kmeans);
133 }
134
135
136
137 /*
138 Creates random centers using randomly selected cases from the data.
139 */
140 static void
141 kmeans_randomize_centers (struct Kmeans *kmeans)
142 {
143   int i, j;
144   for (i = 0; i < kmeans->ngroups; i++)
145     {
146       for (j = 0; j < kmeans->m; j++)
147         {
148           //gsl_matrix_set(kmeans->centers,i,j, gsl_rng_uniform (kmeans->rng));
149           if (i == j)
150             {
151               gsl_matrix_set (kmeans->centers, i, j, 1);
152             }
153           else
154             {
155               gsl_matrix_set (kmeans->centers, i, j, 0);
156             }
157         }
158     }
159 /*
160 If it is the first iteration, the variable kmeans->initial_centers is NULL and
161 it is created once for reporting issues. In SPSS, initial centers are shown in the reports
162 but in PSPP it is not shown now. I am leaving it here.
163 */
164   if (!kmeans->initial_centers)
165     {
166       kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
167       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
168     }
169 }
170
171
172 static int
173 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
174 {
175   int result = -1;
176   double x;
177   int i, j;
178   double dist;
179   double mindist;
180   mindist = INFINITY;
181   for (i = 0; i < kmeans->ngroups; i++)
182     {
183       dist = 0;
184       for (j = 0; j < kmeans->m; j++)
185         {
186           x = case_data (c, kmeans->variables[j])->f;
187           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
188         }
189       if (dist < mindist)
190         {
191           mindist = dist;
192           result = i;
193         }
194     }
195   return (result);
196 }
197
198
199
200
201 /*
202 Re-calculates the cluster centers
203 */
204 static void
205 kmeans_recalculate_centers (struct Kmeans *kmeans)
206 {
207   casenumber i;
208   int v, j;
209   double x, curval;
210   struct ccase *c;
211   struct ccase *c_index;
212   struct casereader *cs;
213   struct casereader *cs_index;
214   int index;
215   double weight;
216
217   i = 0;
218   cs = casereader_clone (kmeans->original_casereader);
219   cs_index = casereader_clone (kmeans->index_rdr);
220
221   gsl_matrix_set_all (kmeans->centers, 0.0);
222   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
223     {
224       c_index = casereader_read (cs_index);
225       index = case_data_idx (c_index, 0)->f;
226       for (v = 0; v < kmeans->m; ++v)
227         {
228           if (kmeans->wv)
229             {
230               weight = case_data (c, kmeans->wv)->f;
231             }
232           else
233             {
234               weight = 1.0;
235             }
236           x = case_data (c, kmeans->variables[v])->f * weight;
237           curval = gsl_matrix_get (kmeans->centers, index, v);
238           gsl_matrix_set (kmeans->centers, index, v, curval + x);
239         }
240       i++;
241       case_unref (c_index);
242     }
243   casereader_destroy (cs);
244   casereader_destroy (cs_index);
245
246   /* Getting number of cases */
247   if (kmeans->n == 0)
248     kmeans->n = i;
249
250   //We got sum of each center but we need averages.
251   //We are dividing centers to numobs. This may be inefficient and
252   //we should check it again.
253   for (i = 0; i < kmeans->ngroups; i++)
254     {
255       casenumber numobs = kmeans->num_elements_groups->data[i];
256       for (j = 0; j < kmeans->m; j++)
257         {
258           if (numobs > 0)
259             {
260               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
261               *x /= numobs;
262             }
263           else
264             {
265               gsl_matrix_set (kmeans->centers, i, j, 0);
266             }
267         }
268     }
269 }
270
271
272 /*
273 The variable index in struct Kmeans holds integer values that represents the current groups of cases.
274 index[n]=a shows the nth case is belong to ath cluster.
275 This function calculates these indexes and returns the number of different cases of the new and old
276 index variables. If last two index variables are equal, there is no any enhancement of clustering.
277 */
278 static int
279 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
280 {
281   int totaldiff = 0;
282   double weight;
283   struct ccase *c;
284   struct casereader *cs = casereader_clone (kmeans->original_casereader);
285
286
287   /* A casewriter into which we will write the indexes */
288   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
289
290   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
291
292   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
293     {
294       /* A case to hold the new index */
295       struct ccase *index_case_new = case_create (kmeans->proto);
296       int bestindex = kmeans_get_nearest_group (kmeans, c);
297       if (kmeans->wv)
298         {
299           weight = (casenumber) case_data (c, kmeans->wv)->f;
300         }
301       else
302         {
303           weight = 1.0;
304         }
305       kmeans->num_elements_groups->data[bestindex] += weight;
306       if (kmeans->index_rdr)
307         {
308           /* A case from which the old index will be read */
309           struct ccase *index_case_old = NULL;
310
311           /* Read the case from the index casereader */
312           index_case_old = casereader_read (kmeans->index_rdr);
313
314           /* Set totaldiff, using the old_index */
315           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
316
317           /* We have no use for the old case anymore, so unref it */
318           case_unref (index_case_old);
319         }
320       else
321         {
322           /* If this is the first run, then assume index is zero */
323           totaldiff += bestindex;
324         }
325
326       /* Set the value of the new index */
327       case_data_rw_idx (index_case_new, 0)->f = bestindex;
328
329       /* and write the new index to the casewriter */
330       casewriter_write (index_wtr, index_case_new);
331     }
332   casereader_destroy (cs);
333   /* We have now read through the entire index_rdr, so it's
334      of no use anymore */
335   casereader_destroy (kmeans->index_rdr);
336
337   /* Convert the writer into a reader, ready for the next iteration to read */
338   kmeans->index_rdr = casewriter_make_reader (index_wtr);
339
340   return (totaldiff);
341 }
342
343
344 static void
345 kmeans_order_groups (struct Kmeans *kmeans)
346 {
347   gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
348   gsl_matrix_get_col (v, kmeans->centers, 0);
349   gsl_sort_vector_index (kmeans->group_order, v);
350 }
351
352 /*
353 Main algorithm.
354 Does iterations, checks convergency
355 */
356 static void
357 kmeans_cluster (struct Kmeans *kmeans)
358 {
359   int i;
360   bool redo;
361   int diffs;
362   bool show_warning1;
363
364   show_warning1 = true;
365 cluster:
366   redo = false;
367   kmeans_randomize_centers (kmeans);
368   for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
369        kmeans->lastiter++)
370     {
371       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
372       kmeans_recalculate_centers (kmeans);
373       if (show_warning1 && kmeans->ngroups > kmeans->n)
374         {
375           msg (MW,
376                _
377                ("Number of clusters may not be larger than the number of cases."));
378           show_warning1 = false;
379         }
380       if (diffs == 0)
381         break;
382     }
383
384   for (i = 0; i < kmeans->ngroups; i++)
385     {
386       if (kmeans->num_elements_groups->data[i] == 0)
387         {
388           kmeans->trials++;
389           if (kmeans->trials >= 3)
390             break;
391           redo = true;
392           break;
393         }
394     }
395   if (redo)
396     goto cluster;
397
398 }
399
400
401 /*
402 Reports centers of clusters.
403 initial parameter is optional for future use.
404 if initial is true, initial cluster centers are reported. Otherwise, resulted centers are reported.
405 */
406 static void
407 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
408 {
409   struct tab_table *t;
410   int nc, nr, heading_columns, currow;
411   int i, j;
412   nc = kmeans->ngroups + 1;
413   nr = kmeans->m + 4;
414   heading_columns = 1;
415   t = tab_create (nc, nr);
416   tab_headers (t, 0, nc - 1, 0, 1);
417   currow = 0;
418   if (!initial)
419     {
420       tab_title (t, _("Final Cluster Centers"));
421     }
422   else
423     {
424       tab_title (t, _("Initial Cluster Centers"));
425     }
426   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
427   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
428   tab_hline (t, TAL_1, 1, nc - 1, 2);
429   currow += 2;
430
431   for (i = 0; i < kmeans->ngroups; i++)
432     {
433       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
434     }
435   currow++;
436   tab_hline (t, TAL_1, 1, nc - 1, currow);
437   currow++;
438   for (i = 0; i < kmeans->m; i++)
439     {
440       tab_text (t, 0, currow + i, TAB_LEFT,
441                 var_to_string (kmeans->variables[i]));
442     }
443
444   for (i = 0; i < kmeans->ngroups; i++)
445     {
446       for (j = 0; j < kmeans->m; j++)
447         {
448           if (!initial)
449             {
450               tab_double (t, i + 1, j + 4, TAB_CENTER,
451                           gsl_matrix_get (kmeans->centers,
452                                           kmeans->group_order->data[i], j),
453                           var_get_print_format (kmeans->variables[j]));
454             }
455           else
456             {
457               tab_double (t, i + 1, j + 4, TAB_CENTER,
458                           gsl_matrix_get (kmeans->initial_centers,
459                                           kmeans->group_order->data[i], j),
460                           var_get_print_format (kmeans->variables[j]));
461             }
462         }
463     }
464   tab_submit (t);
465 }
466
467
468 /*
469 Reports number of cases of each single cluster.
470 */
471 static void
472 quick_cluster_show_number_cases (struct Kmeans *kmeans)
473 {
474   struct tab_table *t;
475   int nc, nr;
476   int i, numelem;
477   long int total;
478   nc = 3;
479   nr = kmeans->ngroups + 1;
480   t = tab_create (nc, nr);
481   tab_headers (t, 0, nc - 1, 0, 0);
482   tab_title (t, _("Number of Cases in each Cluster"));
483   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
484   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
485
486   total = 0;
487   for (i = 0; i < kmeans->ngroups; i++)
488     {
489       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
490       numelem =
491         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
492       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
493       total += numelem;
494     }
495
496   tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
497   tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
498   tab_submit (t);
499 }
500
501 /*
502 Reports
503 */
504 static void
505 quick_cluster_show_results (struct Kmeans *kmeans)
506 {
507   kmeans_order_groups (kmeans);
508   //uncomment the line above for reporting initial centers
509   //quick_cluster_show_centers (kmeans, true);
510   quick_cluster_show_centers (kmeans, false);
511   quick_cluster_show_number_cases (kmeans);
512 }
513
514
515 int
516 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
517 {
518   struct Kmeans *kmeans;
519   bool ok;
520   const struct dictionary *dict = dataset_dict (ds);
521   const struct variable **variables;
522   struct casereader *cs;
523   int groups = 2;
524   int maxiter = 2;
525   size_t p;
526
527
528
529   if (!parse_variables_const (lexer, dict, &variables, &p,
530                               PV_NO_DUPLICATE | PV_NUMERIC))
531     {
532       msg (ME, _("Variables cannot be parsed"));
533       return (CMD_FAILURE);
534     }
535
536
537
538   if (lex_match (lexer, T_SLASH))
539     {
540       if (lex_match_id (lexer, "CRITERIA"))
541         {
542           lex_match (lexer, T_EQUALS);
543           while (lex_token (lexer) != T_ENDCMD
544                  && lex_token (lexer) != T_SLASH)
545             {
546               if (lex_match_id (lexer, "CLUSTERS"))
547                 {
548                   if (lex_force_match (lexer, T_LPAREN))
549                     {
550                       lex_force_int (lexer);
551                       groups = lex_integer (lexer);
552                       lex_get (lexer);
553                       lex_force_match (lexer, T_RPAREN);
554                     }
555                 }
556               else if (lex_match_id (lexer, "MXITER"))
557                 {
558                   if (lex_force_match (lexer, T_LPAREN))
559                     {
560                       lex_force_int (lexer);
561                       maxiter = lex_integer (lexer);
562                       lex_get (lexer);
563                       lex_force_match (lexer, T_RPAREN);
564                     }
565                 }
566               else
567                 {
568                   //further command set
569                   return (CMD_FAILURE);
570                 }
571             }
572         }
573     }
574
575
576   cs = proc_open (ds);
577
578
579   kmeans = kmeans_create (cs, variables, p, groups, maxiter);
580
581   kmeans->wv = dict_get_weight (dict);
582   kmeans_cluster (kmeans);
583   quick_cluster_show_results (kmeans);
584   ok = proc_commit (ds);
585
586   kmeans_destroy (kmeans);
587
588   return (ok);
589 }