1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
49 /* State shared by all of the k-means functions in this file. The member n
50 holds the number of observations; its default value is -1 and it is set by
51 kmeans_recalculate_centers on its first invocation. */
54 gsl_matrix *centers; /* Centers for groups. */
55 gsl_vector_long *num_elements_groups; /* Per-group totals, rebuilt on
   every index pass. */
56 int ngroups; /* Number of groups. (Given by the user) */
57 casenumber n; /* Number of observations (default -1). */
58 int m; /* Number of variables. (Given by the user) */
59 int maxiter; /* Maximum iterations (Given by the user) */
60 int lastiter; /* Iteration where it found the solution. */
61 int trials; /* If not convergence, how many times has
63 gsl_matrix *initial_centers; /* Initial random centers. */
64 const struct variable **variables; /* Variables read for distances and sums. */
65 gsl_permutation *group_order; /* Group order for reporting. */
66 struct casereader *original_casereader; /* Input cases; cloned for each pass. */
67 struct caseproto *proto; /* Layout of the cases that hold group ids. */
68 struct casereader *index_rdr; /* Group ids for each case. */
69 const struct variable *wv; /* Weighting variable. */
/* Forward declarations for the functions defined later in this file. */
72 static struct Kmeans *kmeans_create (struct casereader *cs,
73 const struct variable **variables,
74 int m, int ngroups, int maxiter);
76 static void kmeans_randomize_centers (struct Kmeans *kmeans);
78 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
80 static void kmeans_recalculate_centers (struct Kmeans *kmeans);
83 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
85 static void kmeans_order_groups (struct Kmeans *kmeans);
87 static void kmeans_cluster (struct Kmeans *kmeans);
89 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
91 static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
93 static void quick_cluster_show_results (struct Kmeans *kmeans);
95 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
97 static void kmeans_destroy (struct Kmeans *kmeans);
99 /* Creates and returns a struct Kmeans for casereader 'cs', the 'm' parsed
100 variables 'variables', 'ngroups' clusters and at most 'maxiter' iterations.
101 The 'cs' and 'variables' pointers are stored as given, not copied. */
102 static struct Kmeans *
103 kmeans_create (struct casereader *cs, const struct variable **variables,
104 int m, int ngroups, int maxiter)
106 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
/* One row per group and one column per variable. */
107 kmeans->centers = gsl_matrix_alloc (ngroups, m);
108 kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
109 kmeans->ngroups = ngroups;
112 kmeans->maxiter = maxiter;
113 kmeans->lastiter = 0;
115 kmeans->variables = variables;
/* The permutation has one entry per row of the centers matrix. */
116 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
117 kmeans->original_casereader = cs;
/* Allocated later, by kmeans_randomize_centers. */
118 kmeans->initial_centers = NULL;
/* Prototype for the index cases: a single value of width 0. */
120 kmeans->proto = caseproto_create ();
121 kmeans->proto = caseproto_add_width (kmeans->proto, 0);
/* Set by kmeans_calculate_indexes_and_check_convergence. */
122 kmeans->index_rdr = NULL;
/* Releases the GSL matrices, the GSL vector, the permutation and the case
   prototype held by KMEANS.
   NOTE(review): the sentence and the two free () calls at the end of this
   block read like a commented-out remark ("already destroyed"); confirm
   against the full file that the casereaders are not freed here. */
127 kmeans_destroy (struct Kmeans *kmeans)
129 gsl_matrix_free (kmeans->centers);
130 gsl_matrix_free (kmeans->initial_centers);
132 gsl_vector_long_free (kmeans->num_elements_groups);
134 gsl_permutation_free (kmeans->group_order);
136 caseproto_unref (kmeans->proto);
139 These reader and writer were already destroyed.
140 free (kmeans->original_casereader);
141 free (kmeans->index_rdr);
147 /* Creates random centers using randomly selected cases from the data. */
/* NOTE(review): the statements visible here only store the constants 1 and 0
   into kmeans->centers and read no case data, so the summary above may be
   out of date; confirm against the conditions in the full file.
   kmeans->initial_centers is allocated if it is still NULL and receives a
   copy of the centers. */
149 kmeans_randomize_centers (struct Kmeans *kmeans)
152 for (i = 0; i < kmeans->ngroups; i++)
154 for (j = 0; j < kmeans->m; j++)
158 gsl_matrix_set (kmeans->centers, i, j, 1);
162 gsl_matrix_set (kmeans->centers, i, j, 0);
166 /* If it is the first iteration, the variable kmeans->initial_centers is NULL
167 and it is created once for reporting issues. In SPSS, initial centers are
168 shown in the reports but in PSPP it is not shown now. I am leaving it
170 if (!kmeans->initial_centers)
172 kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
173 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
/* For case C, visits every group center and accumulates in dist the
   pow2 () of the difference between each center coordinate and C's value
   for the matching variable.
   NOTE(review): the comparison that picks a group and the return statement
   are not part of this excerpt; presumably the index of the center with the
   smallest dist is returned -- confirm. */
178 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
186 for (i = 0; i < kmeans->ngroups; i++)
189 for (j = 0; j < kmeans->m; j++)
191 x = case_data (c, kmeans->variables[j])->f;
192 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
203 /* Re-calculate the cluster centers. */
/* Reads the data cases and the per-case group indexes side by side, adds
   each variable's value times the weight to the center row of the case's
   group, and afterwards processes every center entry together with the
   group's total from num_elements_groups (see the comment on averages
   below). */
205 kmeans_recalculate_centers (struct Kmeans *kmeans)
211 struct ccase *c_index;
212 struct casereader *cs;
213 struct casereader *cs_index;
/* Clones are read and destroyed here; the readers stored in KMEANS are
   not read directly by this function. */
218 cs = casereader_clone (kmeans->original_casereader);
219 cs_index = casereader_clone (kmeans->index_rdr);
/* Start from zero so that the centers can be used as running sums. */
221 gsl_matrix_set_all (kmeans->centers, 0.0);
222 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
/* The matching index case; its value 0 holds the group number that
   kmeans_calculate_indexes_and_check_convergence stored as a double. */
224 c_index = casereader_read (cs_index);
225 index = case_data_idx (c_index, 0)->f;
226 for (v = 0; v < kmeans->m; ++v)
230 weight = case_data (c, kmeans->wv)->f;
236 x = case_data (c, kmeans->variables[v])->f * weight;
237 curval = gsl_matrix_get (kmeans->centers, index, v);
238 gsl_matrix_set (kmeans->centers, index, v, curval + x);
241 case_unref (c_index);
243 casereader_destroy (cs);
244 casereader_destroy (cs_index);
246 /* Getting number of cases */
250 /* We got sum of each center but we need averages.
251 We are dividing centers to numobs. This may be inefficient and
252 we should check it again. */
253 for (i = 0; i < kmeans->ngroups; i++)
255 casenumber numobs = kmeans->num_elements_groups->data[i];
256 for (j = 0; j < kmeans->m; j++)
260 double *x = gsl_matrix_ptr (kmeans->centers, i, j);
/* NOTE(review): the condition guarding this reset is not visible in this
   excerpt; presumably it covers groups whose numobs is zero -- confirm. */
265 gsl_matrix_set (kmeans->centers, i, j, 0);
271 /* kmeans->index_rdr holds, for every case, the number of the group that the
272 case currently belongs to. This function gives each case the group chosen by
273 kmeans_get_nearest_group, writes the new group numbers through a casewriter
274 and accumulates in totaldiff the differences between the old and the new
275 group numbers. If the last two sets are equal, clustering stopped improving. */
/* NOTE(review): the return statement is not part of this excerpt;
   presumably totaldiff is the value returned -- confirm. */
277 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
282 struct casereader *cs = casereader_clone (kmeans->original_casereader);
284 /* A casewriter into which we will write the indexes. */
285 struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
/* The per-group totals are rebuilt from scratch on every call. */
287 gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
289 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
291 /* A case to hold the new index. */
292 struct ccase *index_case_new = case_create (kmeans->proto);
293 int bestindex = kmeans_get_nearest_group (kmeans, c);
296 weight = (casenumber) case_data (c, kmeans->wv)->f;
302 kmeans->num_elements_groups->data[bestindex] += weight;
303 if (kmeans->index_rdr)
305 /* A case from which the old index will be read. */
306 struct ccase *index_case_old = NULL;
308 /* Read the case from the index casereader. */
309 index_case_old = casereader_read (kmeans->index_rdr);
311 /* Set totaldiff, using the old_index. */
/* NOTE(review): abs () takes an int, so the double-valued difference is
   converted to int before the absolute value is taken. */
312 totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
314 /* We have no use for the old case anymore, so unref it. */
315 case_unref (index_case_old);
319 /* If this is the first run, then assume index is zero. */
320 totaldiff += bestindex;
323 /* Set the value of the new index. */
324 case_data_rw_idx (index_case_new, 0)->f = bestindex;
326 /* and write the new index to the casewriter */
327 casewriter_write (index_wtr, index_case_new);
329 casereader_destroy (cs);
330 /* We have now read through the entire index_rdr, so it's of no use
332 casereader_destroy (kmeans->index_rdr);
334 /* Convert the writer into a reader, ready for the next iteration to read */
335 kmeans->index_rdr = casewriter_make_reader (index_wtr);
/* Fills kmeans->group_order with the permutation that sorts the groups by
   column 0 of the centers matrix, i.e. by each center's value for the first
   variable. The column is first copied into the temporary vector V.
   NOTE(review): the release of V is not visible in this excerpt -- confirm
   that it is freed. */
341 kmeans_order_groups (struct Kmeans *kmeans)
343 gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
344 gsl_matrix_get_col (v, kmeans->centers, 0);
345 gsl_sort_vector_index (kmeans->group_order, v);
349 Does iterations, checks convergency. */
/* Sets up the centers with kmeans_randomize_centers, then for at most
   kmeans->maxiter iterations recomputes the case indexes and the centers.
   A warning is issued at most once when there are more groups than
   observations. Groups whose element count is zero are looked for, with
   kmeans->trials compared against 3.
   NOTE(review): the loop exit test on diffs and the handling of empty
   groups are only partly visible in this excerpt. */
351 kmeans_cluster (struct Kmeans *kmeans)
358 show_warning1 = true;
361 kmeans_randomize_centers (kmeans);
362 for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
365 diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
366 kmeans_recalculate_centers (kmeans);
367 if (show_warning1 && kmeans->ngroups > kmeans->n)
369 msg (MW, _("Number of clusters may not be larger than the number "
371 show_warning1 = false;
377 for (i = 0; i < kmeans->ngroups; i++)
379 if (kmeans->num_elements_groups->data[i] == 0)
382 if (kmeans->trials >= 3)
393 /* Reports centers of clusters.
394 Initial parameter is optional for future use.
395 If initial is true, initial cluster centers are reported. Otherwise,
396 resulted centers are reported. */
/* The table has one column per group plus a leading column for the
   variable names. Groups are shown in the order given by
   kmeans->group_order. */
398 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
401 int nc, nr, heading_columns, currow;
403 nc = kmeans->ngroups + 1;
406 t = tab_create (nc, nr);
407 tab_headers (t, 0, nc - 1, 0, 1);
411 tab_title (t, _("Final Cluster Centers"));
415 tab_title (t, _("Initial Cluster Centers"));
417 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
418 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
419 tab_hline (t, TAL_1, 1, nc - 1, 2);
/* Column headings: the 1-based number of each displayed cluster. */
422 for (i = 0; i < kmeans->ngroups; i++)
424 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
427 tab_hline (t, TAL_1, 1, nc - 1, currow);
/* Row headings: one row per variable, starting at currow. */
429 for (i = 0; i < kmeans->m; i++)
431 tab_text (t, 0, currow + i, TAB_LEFT,
432 var_to_string (kmeans->variables[i]));
/* NOTE(review): the cells below are addressed with the literal row offset
   j + 4, while the row headings above use currow + i; the two agree only
   if currow is 4 at that point -- confirm. */
435 for (i = 0; i < kmeans->ngroups; i++)
437 for (j = 0; j < kmeans->m; j++)
441 tab_double (t, i + 1, j + 4, TAB_CENTER,
442 gsl_matrix_get (kmeans->centers,
443 kmeans->group_order->data[i], j),
444 var_get_print_format (kmeans->variables[j]));
448 tab_double (t, i + 1, j + 4, TAB_CENTER,
449 gsl_matrix_get (kmeans->initial_centers,
450 kmeans->group_order->data[i], j),
451 var_get_print_format (kmeans->variables[j]));
458 /* Reports number of cases of each single cluster. */
/* The table has one row per group plus a final row labelled "Valid" that
   shows total. Groups are listed in the order given by
   kmeans->group_order. */
460 quick_cluster_show_number_cases (struct Kmeans *kmeans)
467 nr = kmeans->ngroups + 1;
468 t = tab_create (nc, nr);
469 tab_headers (t, 0, nc - 1, 0, 0);
470 tab_title (t, _("Number of Cases in each Cluster"));
471 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
472 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
475 for (i = 0; i < kmeans->ngroups; i++)
477 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
/* Element count of the group displayed in position i. */
479 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
/* NOTE(review): numelem is printed with "%d" although the count comes from
   a gsl_vector_long; confirm that its declared type matches the format. */
480 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
484 tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
485 tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
/* Reports the results for KMEANS: computes the display order of the groups,
   then shows the final cluster centers and the number of cases in each
   cluster. */
491 quick_cluster_show_results (struct Kmeans *kmeans)
/* The order must be computed first: both reports index through
   kmeans->group_order. */
493 kmeans_order_groups (kmeans);
494 /* Uncomment the line below for reporting initial centers. */
495 /* quick_cluster_show_centers (kmeans, true); */
496 quick_cluster_show_centers (kmeans, false);
497 quick_cluster_show_number_cases (kmeans);
/* Command entry point. Parses a list of numeric variables (duplicates are
   rejected), optionally followed by a slash and a CRITERIA subcommand with
   CLUSTERS (n) and MXITER (n) settings, then builds a struct Kmeans, runs
   the clustering, reports the results and destroys the struct.
   Returns CMD_FAILURE when the variable list cannot be parsed.
   NOTE(review): the defaults for groups and maxiter, the creation of cs
   and the final return value are not part of this excerpt. */
501 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
503 struct Kmeans *kmeans;
505 const struct dictionary *dict = dataset_dict (ds);
506 const struct variable **variables;
507 struct casereader *cs;
512 if (!parse_variables_const (lexer, dict, &variables, &p,
513 PV_NO_DUPLICATE | PV_NUMERIC))
515 msg (ME, _("Variables cannot be parsed"));
516 return (CMD_FAILURE);
519 if (lex_match (lexer, T_SLASH))
521 if (lex_match_id (lexer, "CRITERIA"))
/* The equals sign after CRITERIA is optional. */
523 lex_match (lexer, T_EQUALS);
524 while (lex_token (lexer) != T_ENDCMD
525 && lex_token (lexer) != T_SLASH)
527 if (lex_match_id (lexer, "CLUSTERS"))
529 if (lex_force_match (lexer, T_LPAREN))
/* NOTE(review): the results of lex_force_int and of the closing
   lex_force_match are not tested before the value is used. */
531 lex_force_int (lexer);
532 groups = lex_integer (lexer);
534 lex_force_match (lexer, T_RPAREN);
537 else if (lex_match_id (lexer, "MXITER"))
539 if (lex_force_match (lexer, T_LPAREN))
/* NOTE(review): same unchecked pattern as for CLUSTERS above. */
541 lex_force_int (lexer);
542 maxiter = lex_integer (lexer);
544 lex_force_match (lexer, T_RPAREN);
/* p is the number of variables filled in by parse_variables_const. */
555 kmeans = kmeans_create (cs, variables, p, groups, maxiter);
557 kmeans->wv = dict_get_weight (dict);
558 kmeans_cluster (kmeans);
559 quick_cluster_show_results (kmeans);
560 ok = proc_commit (ds);
562 kmeans_destroy (kmeans);