f22658cd194012351e9cc488909276c1a4cd4808
[pspp-builds.git] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <math.h>
20
21 #include <libpspp/misc.h>
22
23 #include <libpspp/str.h>
24 #include <libpspp/message.h>
25
26
27 #include <data/dataset.h>
28 #include <data/missing-values.h>
29 #include <data/casereader.h>
30 #include <data/casewriter.h>
31 #include <data/casegrouper.h>
32 #include <data/dictionary.h>
33 #include <data/format.h>
34 #include <data/case.h>
35
36 #include <language/lexer/variable-parser.h>
37 #include <language/command.h>
38 #include <language/lexer/lexer.h>
39
40 #include <output/tab.h>
41 #include <output/text-item.h>
42
43 #include <stdio.h>
44 #include <stdlib.h>
45
46 #include <gsl/gsl_matrix.h>
47 #include <gsl/gsl_statistics.h>
48 #include <gsl/gsl_permutation.h>
49 #include <gsl/gsl_sort_vector.h>
50
51 #include <math/random.h>
52
53 #include "gettext.h"
54 #define _(msgid) gettext (msgid)
55 #define N_(msgid) msgid
56
57 /*
58 Struct KMeans:
59 Holds all of the information for the functions.
60 int n, holds the number of observation and its default value is -1.
61 We set it in kmeans_recalculate_centers in first invocation.
62 */
63 struct Kmeans
64 {
65   gsl_matrix *centers;          //Centers for groups
66   gsl_vector_long *num_elements_groups;
67   int ngroups;                  //Number of group. (Given by the user)
68   casenumber n;                 //Number of observations. By default it is -1.
69   int m;                        //Number of variables. (Given by the user)
70   int maxiter;                  //Maximum number of iterations (Given by the user)
71   int lastiter;                 //Show at which iteration it found the solution.
72   int trials;                   //If not convergence, how many times has clustering done.
73   gsl_matrix *initial_centers;  //Initial random centers
74   const struct variable **variables;    //Variables
75   gsl_permutation *group_order; //Handles group order for reporting
76   struct casereader *original_casereader;       //Casereader
77   struct caseproto *proto;
78   struct casereader *index_rdr; //We hold the group id's for each case in this structure
79   const struct variable *wv;    //Weighting variable
80 };
81
82 static struct Kmeans *kmeans_create (struct casereader *cs,
83                                      const struct variable **variables,
84                                      int m, int ngroups, int maxiter);
85
86 static void kmeans_randomize_centers (struct Kmeans *kmeans);
87
88 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
89
90 static void kmeans_recalculate_centers (struct Kmeans *kmeans);
91
92 static int
93 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
94
95 static void kmeans_order_groups (struct Kmeans *kmeans);
96
97 static void kmeans_cluster (struct Kmeans *kmeans);
98
99 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
100
101 static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
102
103 static void quick_cluster_show_results (struct Kmeans *kmeans);
104
105 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
106
107 static void kmeans_destroy (struct Kmeans *kmeans);
108
109 /*
110 Creates and returns a struct of Kmeans with given casereader 'cs', parsed variables 'variables',
111 number of cases 'n', number of variables 'm', number of clusters and amount of maximum iterations.
112 */
113 static struct Kmeans *
114 kmeans_create (struct casereader *cs, const struct variable **variables,
115                int m, int ngroups, int maxiter)
116 {
117   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
118   kmeans->centers = gsl_matrix_alloc (ngroups, m);
119   kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
120   kmeans->ngroups = ngroups;
121   kmeans->n = 0;
122   kmeans->m = m;
123   kmeans->maxiter = maxiter;
124   kmeans->lastiter = 0;
125   kmeans->trials = 0;
126   kmeans->variables = variables;
127   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
128   kmeans->original_casereader = cs;
129   kmeans->initial_centers = NULL;
130
131   kmeans->proto = caseproto_create ();
132   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
133   kmeans->index_rdr = NULL;
134   return (kmeans);
135 }
136
137
138 static void
139 kmeans_destroy (struct Kmeans *kmeans)
140 {
141   gsl_matrix_free (kmeans->centers);
142   gsl_matrix_free (kmeans->initial_centers);
143
144   gsl_vector_long_free (kmeans->num_elements_groups);
145
146   gsl_permutation_free (kmeans->group_order);
147
148   caseproto_unref (kmeans->proto);
149
150   /*
151      These reader and writer were already destroyed.
152      free (kmeans->original_casereader);
153      free (kmeans->index_rdr);
154    */
155
156   free (kmeans);
157 }
158
159
160
161 /*
162 Creates random centers using randomly selected cases from the data.
163 */
164 static void
165 kmeans_randomize_centers (struct Kmeans *kmeans)
166 {
167   int i, j;
168   for (i = 0; i < kmeans->ngroups; i++)
169     {
170       for (j = 0; j < kmeans->m; j++)
171         {
172           //gsl_matrix_set(kmeans->centers,i,j, gsl_rng_uniform (kmeans->rng));
173           if (i == j)
174             {
175               gsl_matrix_set (kmeans->centers, i, j, 1);
176             }
177           else
178             {
179               gsl_matrix_set (kmeans->centers, i, j, 0);
180             }
181         }
182     }
183 /*
184 If it is the first iteration, the variable kmeans->initial_centers is NULL and
185 it is created once for reporting issues. In SPSS, initial centers are shown in the reports
186 but in PSPP it is not shown now. I am leaving it here.
187 */
188   if (!kmeans->initial_centers)
189     {
190       kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
191       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
192     }
193 }
194
195
196 static int
197 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
198 {
199   int result = -1;
200   double x;
201   int i, j;
202   double dist;
203   double mindist;
204   mindist = INFINITY;
205   for (i = 0; i < kmeans->ngroups; i++)
206     {
207       dist = 0;
208       for (j = 0; j < kmeans->m; j++)
209         {
210           x = case_data (c, kmeans->variables[j])->f;
211           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
212         }
213       if (dist < mindist)
214         {
215           mindist = dist;
216           result = i;
217         }
218     }
219   return (result);
220 }
221
222
223
224
225 /*
226 Re-calculates the cluster centers
227 */
228 static void
229 kmeans_recalculate_centers (struct Kmeans *kmeans)
230 {
231   casenumber i;
232   int v, j;
233   double x, curval;
234   struct ccase *c;
235   struct ccase *c_index;
236   struct casereader *cs;
237   struct casereader *cs_index;
238   int index;
239   double weight;
240
241   i = 0;
242   cs = casereader_clone (kmeans->original_casereader);
243   cs_index = casereader_clone (kmeans->index_rdr);
244
245   gsl_matrix_set_all (kmeans->centers, 0.0);
246   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
247     {
248       c_index = casereader_read (cs_index);
249       index = case_data_idx (c_index, 0)->f;
250       for (v = 0; v < kmeans->m; ++v)
251         {
252           if (kmeans->wv)
253             {
254               weight = case_data (c, kmeans->wv)->f;
255             }
256           else
257             {
258               weight = 1.0;
259             }
260           x = case_data (c, kmeans->variables[v])->f * weight;
261           curval = gsl_matrix_get (kmeans->centers, index, v);
262           gsl_matrix_set (kmeans->centers, index, v, curval + x);
263         }
264       i++;
265       case_unref (c_index);
266     }
267   casereader_destroy (cs);
268   casereader_destroy (cs_index);
269
270   /* Getting number of cases */
271   if (kmeans->n == 0)
272     kmeans->n = i;
273
274   //We got sum of each center but we need averages.
275   //We are dividing centers to numobs. This may be inefficient and
276   //we should check it again.
277   for (i = 0; i < kmeans->ngroups; i++)
278     {
279       casenumber numobs = kmeans->num_elements_groups->data[i];
280       for (j = 0; j < kmeans->m; j++)
281         {
282           if (numobs > 0)
283             {
284               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
285               *x /= numobs;
286             }
287           else
288             {
289               gsl_matrix_set (kmeans->centers, i, j, 0);
290             }
291         }
292     }
293 }
294
295
296 /*
297 The variable index in struct Kmeans holds integer values that represents the current groups of cases.
298 index[n]=a shows the nth case is belong to ath cluster.
299 This function calculates these indexes and returns the number of different cases of the new and old
300 index variables. If last two index variables are equal, there is no any enhancement of clustering.
301 */
302 static int
303 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
304 {
305   int totaldiff = 0;
306   double weight;
307   struct ccase *c;
308   struct casereader *cs = casereader_clone (kmeans->original_casereader);
309
310
311   /* A casewriter into which we will write the indexes */
312   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
313
314   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
315
316   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
317     {
318       /* A case to hold the new index */
319       struct ccase *index_case_new = case_create (kmeans->proto);
320       int bestindex = kmeans_get_nearest_group (kmeans, c);
321       if (kmeans->wv)
322         {
323           weight = (casenumber) case_data (c, kmeans->wv)->f;
324         }
325       else
326         {
327           weight = 1.0;
328         }
329       kmeans->num_elements_groups->data[bestindex] += weight;
330       if (kmeans->index_rdr)
331         {
332           /* A case from which the old index will be read */
333           struct ccase *index_case_old = NULL;
334
335           /* Read the case from the index casereader */
336           index_case_old = casereader_read (kmeans->index_rdr);
337
338           /* Set totaldiff, using the old_index */
339           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
340
341           /* We have no use for the old case anymore, so unref it */
342           case_unref (index_case_old);
343         }
344       else
345         {
346           /* If this is the first run, then assume index is zero */
347           totaldiff += bestindex;
348         }
349
350       /* Set the value of the new index */
351       case_data_rw_idx (index_case_new, 0)->f = bestindex;
352
353       /* and write the new index to the casewriter */
354       casewriter_write (index_wtr, index_case_new);
355     }
356   casereader_destroy (cs);
357   /* We have now read through the entire index_rdr, so it's
358      of no use anymore */
359   casereader_destroy (kmeans->index_rdr);
360
361   /* Convert the writer into a reader, ready for the next iteration to read */
362   kmeans->index_rdr = casewriter_make_reader (index_wtr);
363
364   return (totaldiff);
365 }
366
367
368 static void
369 kmeans_order_groups (struct Kmeans *kmeans)
370 {
371   gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
372   gsl_matrix_get_col (v, kmeans->centers, 0);
373   gsl_sort_vector_index (kmeans->group_order, v);
374 }
375
376 /*
377 Main algorithm.
378 Does iterations, checks convergency
379 */
380 static void
381 kmeans_cluster (struct Kmeans *kmeans)
382 {
383   int i;
384   bool redo;
385   int diffs;
386   bool show_warning1;
387
388   show_warning1 = true;
389 cluster:
390   redo = false;
391   kmeans_randomize_centers (kmeans);
392   for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
393        kmeans->lastiter++)
394     {
395       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
396       kmeans_recalculate_centers (kmeans);
397       if (show_warning1 && kmeans->ngroups > kmeans->n)
398         {
399           msg (MW,
400                _
401                ("Number of clusters may not be larger than the number of cases."));
402           show_warning1 = false;
403         }
404       if (diffs == 0)
405         break;
406     }
407
408   for (i = 0; i < kmeans->ngroups; i++)
409     {
410       if (kmeans->num_elements_groups->data[i] == 0)
411         {
412           kmeans->trials++;
413           if (kmeans->trials >= 3)
414             break;
415           redo = true;
416           break;
417         }
418     }
419   if (redo)
420     goto cluster;
421
422 }
423
424
425 /*
426 Reports centers of clusters.
427 initial parameter is optional for future use.
428 if initial is true, initial cluster centers are reported. Otherwise, resulted centers are reported.
429 */
430 static void
431 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
432 {
433   struct tab_table *t;
434   int nc, nr, heading_columns, currow;
435   int i, j;
436   nc = kmeans->ngroups + 1;
437   nr = kmeans->m + 4;
438   heading_columns = 1;
439   t = tab_create (nc, nr);
440   tab_headers (t, 0, nc - 1, 0, 1);
441   currow = 0;
442   if (!initial)
443     {
444       tab_title (t, _("Final Cluster Centers"));
445     }
446   else
447     {
448       tab_title (t, _("Initial Cluster Centers"));
449     }
450   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
451   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
452   tab_hline (t, TAL_1, 1, nc - 1, 2);
453   currow += 2;
454
455   for (i = 0; i < kmeans->ngroups; i++)
456     {
457       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
458     }
459   currow++;
460   tab_hline (t, TAL_1, 1, nc - 1, currow);
461   currow++;
462   for (i = 0; i < kmeans->m; i++)
463     {
464       tab_text (t, 0, currow + i, TAB_LEFT,
465                 var_to_string (kmeans->variables[i]));
466     }
467
468   for (i = 0; i < kmeans->ngroups; i++)
469     {
470       for (j = 0; j < kmeans->m; j++)
471         {
472           if (!initial)
473             {
474               tab_double (t, i + 1, j + 4, TAB_CENTER,
475                           gsl_matrix_get (kmeans->centers,
476                                           kmeans->group_order->data[i], j),
477                           var_get_print_format (kmeans->variables[j]));
478             }
479           else
480             {
481               tab_double (t, i + 1, j + 4, TAB_CENTER,
482                           gsl_matrix_get (kmeans->initial_centers,
483                                           kmeans->group_order->data[i], j),
484                           var_get_print_format (kmeans->variables[j]));
485             }
486         }
487     }
488   tab_submit (t);
489 }
490
491
492 /*
493 Reports number of cases of each single cluster.
494 */
495 static void
496 quick_cluster_show_number_cases (struct Kmeans *kmeans)
497 {
498   struct tab_table *t;
499   int nc, nr;
500   int i, numelem;
501   long int total;
502   nc = 3;
503   nr = kmeans->ngroups + 1;
504   t = tab_create (nc, nr);
505   tab_headers (t, 0, nc - 1, 0, 0);
506   tab_title (t, _("Number of Cases in each Cluster"));
507   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
508   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
509
510   total = 0;
511   for (i = 0; i < kmeans->ngroups; i++)
512     {
513       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
514       numelem =
515         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
516       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
517       total += numelem;
518     }
519
520   tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
521   tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
522   tab_submit (t);
523 }
524
525 /*
526 Reports
527 */
528 static void
529 quick_cluster_show_results (struct Kmeans *kmeans)
530 {
531   kmeans_order_groups (kmeans);
532   //uncomment the line above for reporting initial centers
533   //quick_cluster_show_centers (kmeans, true);
534   quick_cluster_show_centers (kmeans, false);
535   quick_cluster_show_number_cases (kmeans);
536 }
537
538
539 int
540 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
541 {
542   struct Kmeans *kmeans;
543   bool ok;
544   const struct dictionary *dict = dataset_dict (ds);
545   const struct variable **variables;
546   struct casereader *cs;
547   int groups = 2;
548   int maxiter = 2;
549   size_t p;
550
551
552
553   if (!parse_variables_const (lexer, dict, &variables, &p,
554                               PV_NO_DUPLICATE | PV_NUMERIC))
555     {
556       msg (ME, _("Variables cannot be parsed"));
557       return (CMD_FAILURE);
558     }
559
560
561
562   if (lex_match (lexer, T_SLASH))
563     {
564       if (lex_match_id (lexer, "CRITERIA"))
565         {
566           lex_match (lexer, T_EQUALS);
567           while (lex_token (lexer) != T_ENDCMD
568                  && lex_token (lexer) != T_SLASH)
569             {
570               if (lex_match_id (lexer, "CLUSTERS"))
571                 {
572                   if (lex_force_match (lexer, T_LPAREN))
573                     {
574                       lex_force_int (lexer);
575                       groups = lex_integer (lexer);
576                       lex_get (lexer);
577                       lex_force_match (lexer, T_RPAREN);
578                     }
579                 }
580               else if (lex_match_id (lexer, "MXITER"))
581                 {
582                   if (lex_force_match (lexer, T_LPAREN))
583                     {
584                       lex_force_int (lexer);
585                       maxiter = lex_integer (lexer);
586                       lex_get (lexer);
587                       lex_force_match (lexer, T_RPAREN);
588                     }
589                 }
590               else
591                 {
592                   //further command set
593                   return (CMD_FAILURE);
594                 }
595             }
596         }
597     }
598
599
600   cs = proc_open (ds);
601
602
603   kmeans = kmeans_create (cs, variables, p, groups, maxiter);
604
605   kmeans->wv = dict_get_weight (dict);
606   kmeans_cluster (kmeans);
607   quick_cluster_show_results (kmeans);
608   ok = proc_commit (ds);
609
610   kmeans_destroy (kmeans);
611
612   return (ok);
613 }