QUICK CLUSTER: Update #include directives to match current style.
[pspp-builds.git] / src / language / stats / quick-cluster.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49 /*
50 Struct KMeans:
51 Holds all of the information for the functions.
52 int n, holds the number of observation and its default value is -1.
53 We set it in kmeans_recalculate_centers in first invocation.
54 */
55 struct Kmeans
56 {
57   gsl_matrix *centers;          //Centers for groups
58   gsl_vector_long *num_elements_groups;
59   int ngroups;                  //Number of group. (Given by the user)
60   casenumber n;                 //Number of observations. By default it is -1.
61   int m;                        //Number of variables. (Given by the user)
62   int maxiter;                  //Maximum number of iterations (Given by the user)
63   int lastiter;                 //Show at which iteration it found the solution.
64   int trials;                   //If not convergence, how many times has clustering done.
65   gsl_matrix *initial_centers;  //Initial random centers
66   const struct variable **variables;    //Variables
67   gsl_permutation *group_order; //Handles group order for reporting
68   struct casereader *original_casereader;       //Casereader
69   struct caseproto *proto;
70   struct casereader *index_rdr; //We hold the group id's for each case in this structure
71   const struct variable *wv;    //Weighting variable
72 };
73
74 static struct Kmeans *kmeans_create (struct casereader *cs,
75                                      const struct variable **variables,
76                                      int m, int ngroups, int maxiter);
77
78 static void kmeans_randomize_centers (struct Kmeans *kmeans);
79
80 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
81
82 static void kmeans_recalculate_centers (struct Kmeans *kmeans);
83
84 static int
85 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
86
87 static void kmeans_order_groups (struct Kmeans *kmeans);
88
89 static void kmeans_cluster (struct Kmeans *kmeans);
90
91 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
92
93 static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
94
95 static void quick_cluster_show_results (struct Kmeans *kmeans);
96
97 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
98
99 static void kmeans_destroy (struct Kmeans *kmeans);
100
101 /*
102 Creates and returns a struct of Kmeans with given casereader 'cs', parsed variables 'variables',
103 number of cases 'n', number of variables 'm', number of clusters and amount of maximum iterations.
104 */
105 static struct Kmeans *
106 kmeans_create (struct casereader *cs, const struct variable **variables,
107                int m, int ngroups, int maxiter)
108 {
109   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
110   kmeans->centers = gsl_matrix_alloc (ngroups, m);
111   kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
112   kmeans->ngroups = ngroups;
113   kmeans->n = 0;
114   kmeans->m = m;
115   kmeans->maxiter = maxiter;
116   kmeans->lastiter = 0;
117   kmeans->trials = 0;
118   kmeans->variables = variables;
119   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
120   kmeans->original_casereader = cs;
121   kmeans->initial_centers = NULL;
122
123   kmeans->proto = caseproto_create ();
124   kmeans->proto = caseproto_add_width (kmeans->proto, 0);
125   kmeans->index_rdr = NULL;
126   return (kmeans);
127 }
128
129
130 static void
131 kmeans_destroy (struct Kmeans *kmeans)
132 {
133   gsl_matrix_free (kmeans->centers);
134   gsl_matrix_free (kmeans->initial_centers);
135
136   gsl_vector_long_free (kmeans->num_elements_groups);
137
138   gsl_permutation_free (kmeans->group_order);
139
140   caseproto_unref (kmeans->proto);
141
142   /*
143      These reader and writer were already destroyed.
144      free (kmeans->original_casereader);
145      free (kmeans->index_rdr);
146    */
147
148   free (kmeans);
149 }
150
151
152
153 /*
154 Creates random centers using randomly selected cases from the data.
155 */
156 static void
157 kmeans_randomize_centers (struct Kmeans *kmeans)
158 {
159   int i, j;
160   for (i = 0; i < kmeans->ngroups; i++)
161     {
162       for (j = 0; j < kmeans->m; j++)
163         {
164           //gsl_matrix_set(kmeans->centers,i,j, gsl_rng_uniform (kmeans->rng));
165           if (i == j)
166             {
167               gsl_matrix_set (kmeans->centers, i, j, 1);
168             }
169           else
170             {
171               gsl_matrix_set (kmeans->centers, i, j, 0);
172             }
173         }
174     }
175 /*
176 If it is the first iteration, the variable kmeans->initial_centers is NULL and
177 it is created once for reporting issues. In SPSS, initial centers are shown in the reports
178 but in PSPP it is not shown now. I am leaving it here.
179 */
180   if (!kmeans->initial_centers)
181     {
182       kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
183       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
184     }
185 }
186
187
188 static int
189 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
190 {
191   int result = -1;
192   double x;
193   int i, j;
194   double dist;
195   double mindist;
196   mindist = INFINITY;
197   for (i = 0; i < kmeans->ngroups; i++)
198     {
199       dist = 0;
200       for (j = 0; j < kmeans->m; j++)
201         {
202           x = case_data (c, kmeans->variables[j])->f;
203           dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
204         }
205       if (dist < mindist)
206         {
207           mindist = dist;
208           result = i;
209         }
210     }
211   return (result);
212 }
213
214
215
216
217 /*
218 Re-calculates the cluster centers
219 */
220 static void
221 kmeans_recalculate_centers (struct Kmeans *kmeans)
222 {
223   casenumber i;
224   int v, j;
225   double x, curval;
226   struct ccase *c;
227   struct ccase *c_index;
228   struct casereader *cs;
229   struct casereader *cs_index;
230   int index;
231   double weight;
232
233   i = 0;
234   cs = casereader_clone (kmeans->original_casereader);
235   cs_index = casereader_clone (kmeans->index_rdr);
236
237   gsl_matrix_set_all (kmeans->centers, 0.0);
238   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
239     {
240       c_index = casereader_read (cs_index);
241       index = case_data_idx (c_index, 0)->f;
242       for (v = 0; v < kmeans->m; ++v)
243         {
244           if (kmeans->wv)
245             {
246               weight = case_data (c, kmeans->wv)->f;
247             }
248           else
249             {
250               weight = 1.0;
251             }
252           x = case_data (c, kmeans->variables[v])->f * weight;
253           curval = gsl_matrix_get (kmeans->centers, index, v);
254           gsl_matrix_set (kmeans->centers, index, v, curval + x);
255         }
256       i++;
257       case_unref (c_index);
258     }
259   casereader_destroy (cs);
260   casereader_destroy (cs_index);
261
262   /* Getting number of cases */
263   if (kmeans->n == 0)
264     kmeans->n = i;
265
266   //We got sum of each center but we need averages.
267   //We are dividing centers to numobs. This may be inefficient and
268   //we should check it again.
269   for (i = 0; i < kmeans->ngroups; i++)
270     {
271       casenumber numobs = kmeans->num_elements_groups->data[i];
272       for (j = 0; j < kmeans->m; j++)
273         {
274           if (numobs > 0)
275             {
276               double *x = gsl_matrix_ptr (kmeans->centers, i, j);
277               *x /= numobs;
278             }
279           else
280             {
281               gsl_matrix_set (kmeans->centers, i, j, 0);
282             }
283         }
284     }
285 }
286
287
288 /*
289 The variable index in struct Kmeans holds integer values that represents the current groups of cases.
290 index[n]=a shows the nth case is belong to ath cluster.
291 This function calculates these indexes and returns the number of different cases of the new and old
292 index variables. If last two index variables are equal, there is no any enhancement of clustering.
293 */
294 static int
295 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
296 {
297   int totaldiff = 0;
298   double weight;
299   struct ccase *c;
300   struct casereader *cs = casereader_clone (kmeans->original_casereader);
301
302
303   /* A casewriter into which we will write the indexes */
304   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
305
306   gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
307
308   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
309     {
310       /* A case to hold the new index */
311       struct ccase *index_case_new = case_create (kmeans->proto);
312       int bestindex = kmeans_get_nearest_group (kmeans, c);
313       if (kmeans->wv)
314         {
315           weight = (casenumber) case_data (c, kmeans->wv)->f;
316         }
317       else
318         {
319           weight = 1.0;
320         }
321       kmeans->num_elements_groups->data[bestindex] += weight;
322       if (kmeans->index_rdr)
323         {
324           /* A case from which the old index will be read */
325           struct ccase *index_case_old = NULL;
326
327           /* Read the case from the index casereader */
328           index_case_old = casereader_read (kmeans->index_rdr);
329
330           /* Set totaldiff, using the old_index */
331           totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
332
333           /* We have no use for the old case anymore, so unref it */
334           case_unref (index_case_old);
335         }
336       else
337         {
338           /* If this is the first run, then assume index is zero */
339           totaldiff += bestindex;
340         }
341
342       /* Set the value of the new index */
343       case_data_rw_idx (index_case_new, 0)->f = bestindex;
344
345       /* and write the new index to the casewriter */
346       casewriter_write (index_wtr, index_case_new);
347     }
348   casereader_destroy (cs);
349   /* We have now read through the entire index_rdr, so it's
350      of no use anymore */
351   casereader_destroy (kmeans->index_rdr);
352
353   /* Convert the writer into a reader, ready for the next iteration to read */
354   kmeans->index_rdr = casewriter_make_reader (index_wtr);
355
356   return (totaldiff);
357 }
358
359
360 static void
361 kmeans_order_groups (struct Kmeans *kmeans)
362 {
363   gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
364   gsl_matrix_get_col (v, kmeans->centers, 0);
365   gsl_sort_vector_index (kmeans->group_order, v);
366 }
367
368 /*
369 Main algorithm.
370 Does iterations, checks convergency
371 */
372 static void
373 kmeans_cluster (struct Kmeans *kmeans)
374 {
375   int i;
376   bool redo;
377   int diffs;
378   bool show_warning1;
379
380   show_warning1 = true;
381 cluster:
382   redo = false;
383   kmeans_randomize_centers (kmeans);
384   for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
385        kmeans->lastiter++)
386     {
387       diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
388       kmeans_recalculate_centers (kmeans);
389       if (show_warning1 && kmeans->ngroups > kmeans->n)
390         {
391           msg (MW,
392                _
393                ("Number of clusters may not be larger than the number of cases."));
394           show_warning1 = false;
395         }
396       if (diffs == 0)
397         break;
398     }
399
400   for (i = 0; i < kmeans->ngroups; i++)
401     {
402       if (kmeans->num_elements_groups->data[i] == 0)
403         {
404           kmeans->trials++;
405           if (kmeans->trials >= 3)
406             break;
407           redo = true;
408           break;
409         }
410     }
411   if (redo)
412     goto cluster;
413
414 }
415
416
417 /*
418 Reports centers of clusters.
419 initial parameter is optional for future use.
420 if initial is true, initial cluster centers are reported. Otherwise, resulted centers are reported.
421 */
422 static void
423 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
424 {
425   struct tab_table *t;
426   int nc, nr, heading_columns, currow;
427   int i, j;
428   nc = kmeans->ngroups + 1;
429   nr = kmeans->m + 4;
430   heading_columns = 1;
431   t = tab_create (nc, nr);
432   tab_headers (t, 0, nc - 1, 0, 1);
433   currow = 0;
434   if (!initial)
435     {
436       tab_title (t, _("Final Cluster Centers"));
437     }
438   else
439     {
440       tab_title (t, _("Initial Cluster Centers"));
441     }
442   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
443   tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
444   tab_hline (t, TAL_1, 1, nc - 1, 2);
445   currow += 2;
446
447   for (i = 0; i < kmeans->ngroups; i++)
448     {
449       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
450     }
451   currow++;
452   tab_hline (t, TAL_1, 1, nc - 1, currow);
453   currow++;
454   for (i = 0; i < kmeans->m; i++)
455     {
456       tab_text (t, 0, currow + i, TAB_LEFT,
457                 var_to_string (kmeans->variables[i]));
458     }
459
460   for (i = 0; i < kmeans->ngroups; i++)
461     {
462       for (j = 0; j < kmeans->m; j++)
463         {
464           if (!initial)
465             {
466               tab_double (t, i + 1, j + 4, TAB_CENTER,
467                           gsl_matrix_get (kmeans->centers,
468                                           kmeans->group_order->data[i], j),
469                           var_get_print_format (kmeans->variables[j]));
470             }
471           else
472             {
473               tab_double (t, i + 1, j + 4, TAB_CENTER,
474                           gsl_matrix_get (kmeans->initial_centers,
475                                           kmeans->group_order->data[i], j),
476                           var_get_print_format (kmeans->variables[j]));
477             }
478         }
479     }
480   tab_submit (t);
481 }
482
483
484 /*
485 Reports number of cases of each single cluster.
486 */
487 static void
488 quick_cluster_show_number_cases (struct Kmeans *kmeans)
489 {
490   struct tab_table *t;
491   int nc, nr;
492   int i, numelem;
493   long int total;
494   nc = 3;
495   nr = kmeans->ngroups + 1;
496   t = tab_create (nc, nr);
497   tab_headers (t, 0, nc - 1, 0, 0);
498   tab_title (t, _("Number of Cases in each Cluster"));
499   tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
500   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
501
502   total = 0;
503   for (i = 0; i < kmeans->ngroups; i++)
504     {
505       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
506       numelem =
507         kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
508       tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
509       total += numelem;
510     }
511
512   tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
513   tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
514   tab_submit (t);
515 }
516
517 /*
518 Reports
519 */
520 static void
521 quick_cluster_show_results (struct Kmeans *kmeans)
522 {
523   kmeans_order_groups (kmeans);
524   //uncomment the line above for reporting initial centers
525   //quick_cluster_show_centers (kmeans, true);
526   quick_cluster_show_centers (kmeans, false);
527   quick_cluster_show_number_cases (kmeans);
528 }
529
530
531 int
532 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
533 {
534   struct Kmeans *kmeans;
535   bool ok;
536   const struct dictionary *dict = dataset_dict (ds);
537   const struct variable **variables;
538   struct casereader *cs;
539   int groups = 2;
540   int maxiter = 2;
541   size_t p;
542
543
544
545   if (!parse_variables_const (lexer, dict, &variables, &p,
546                               PV_NO_DUPLICATE | PV_NUMERIC))
547     {
548       msg (ME, _("Variables cannot be parsed"));
549       return (CMD_FAILURE);
550     }
551
552
553
554   if (lex_match (lexer, T_SLASH))
555     {
556       if (lex_match_id (lexer, "CRITERIA"))
557         {
558           lex_match (lexer, T_EQUALS);
559           while (lex_token (lexer) != T_ENDCMD
560                  && lex_token (lexer) != T_SLASH)
561             {
562               if (lex_match_id (lexer, "CLUSTERS"))
563                 {
564                   if (lex_force_match (lexer, T_LPAREN))
565                     {
566                       lex_force_int (lexer);
567                       groups = lex_integer (lexer);
568                       lex_get (lexer);
569                       lex_force_match (lexer, T_RPAREN);
570                     }
571                 }
572               else if (lex_match_id (lexer, "MXITER"))
573                 {
574                   if (lex_force_match (lexer, T_LPAREN))
575                     {
576                       lex_force_int (lexer);
577                       maxiter = lex_integer (lexer);
578                       lex_get (lexer);
579                       lex_force_match (lexer, T_RPAREN);
580                     }
581                 }
582               else
583                 {
584                   //further command set
585                   return (CMD_FAILURE);
586                 }
587             }
588         }
589     }
590
591
592   cs = proc_open (ds);
593
594
595   kmeans = kmeans_create (cs, variables, p, groups, maxiter);
596
597   kmeans->wv = dict_get_weight (dict);
598   kmeans_cluster (kmeans);
599   quick_cluster_show_results (kmeans);
600   ok = proc_commit (ds);
601
602   kmeans_destroy (kmeans);
603
604   return (ok);
605 }