1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
21 #include <libpspp/misc.h>
23 #include <libpspp/str.h>
24 #include <libpspp/message.h>
27 #include <data/dataset.h>
28 #include <data/missing-values.h>
29 #include <data/casereader.h>
30 #include <data/casewriter.h>
31 #include <data/casegrouper.h>
32 #include <data/dictionary.h>
33 #include <data/format.h>
34 #include <data/case.h>
36 #include <language/lexer/variable-parser.h>
37 #include <language/command.h>
38 #include <language/lexer/lexer.h>
40 #include <output/tab.h>
41 #include <output/text-item.h>
46 #include <gsl/gsl_matrix.h>
47 #include <gsl/gsl_statistics.h>
48 #include <gsl/gsl_permutation.h>
49 #include <gsl/gsl_sort_vector.h>
51 #include <math/random.h>
/* Translation helpers: _() passes MSGID to gettext, and N_() expands to its
   argument unchanged. */
54 #define _(msgid) gettext (msgid)
55 #define N_(msgid) msgid
59 Holds all of the information shared by the functions below.
60 The member n holds the number of observations; its default value is -1.
61 It is set in kmeans_recalculate_centers on the first invocation.
/* Members of struct Kmeans, the state shared by all clustering routines.
   NOTE(review): the struct header and its closing brace are not visible in
   this listing. */
65 gsl_matrix *centers; //Current group centers: one row per group, one column per variable
66 gsl_vector_long *num_elements_groups; //Weighted number of cases currently assigned to each group
67 int ngroups; //Number of groups (given by the user)
68 casenumber n; //Number of observations. By default it is -1.
69 int m; //Number of variables (given by the user)
70 int maxiter; //Maximum number of iterations (given by the user)
71 int lastiter; //Iteration counter; shows at which iteration the solution was found
72 int trials; //If there is no convergence, how many times clustering has been attempted
73 gsl_matrix *initial_centers; //Copy of the starting centers, kept for reporting
74 const struct variable **variables; //Variables used for clustering
75 gsl_permutation *group_order; //Order in which groups are reported
76 struct casereader *original_casereader; //Reader over the input cases
77 struct caseproto *proto; //Prototype of the cases that hold group indexes
78 struct casereader *index_rdr; //Holds the group index of each case
79 const struct variable *wv; //Weighting variable
/* Forward declarations of the routines defined later in this file.
   NOTE(review): the return-type line of the declaration of
   kmeans_calculate_indexes_and_check_convergence is not visible here. */
82 static struct Kmeans *kmeans_create (struct casereader *cs,
83 const struct variable **variables,
84 int m, int ngroups, int maxiter);
86 static void kmeans_randomize_centers (struct Kmeans *kmeans);
88 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
90 static void kmeans_recalculate_centers (struct Kmeans *kmeans);
93 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
95 static void kmeans_order_groups (struct Kmeans *kmeans);
97 static void kmeans_cluster (struct Kmeans *kmeans);
99 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
101 static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
103 static void quick_cluster_show_results (struct Kmeans *kmeans);
/* Command entry point; the only declaration here without static. */
105 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
107 static void kmeans_destroy (struct Kmeans *kmeans);
110 Creates and returns a struct Kmeans for the given casereader 'cs', parsed variables 'variables',
111 number of variables 'm', number of clusters 'ngroups' and maximum number of iterations 'maxiter'.
/* Allocates a struct Kmeans and fills in its members for clustering the cases
   read from CS.  VARIABLES is stored as given (the array is not copied).
   NOTE(review): the assignments to n, m and trials and the return statement
   are not visible in this listing -- confirm against the complete file. */
113 static struct Kmeans *
114 kmeans_create (struct casereader *cs, const struct variable **variables,
115 int m, int ngroups, int maxiter)
117 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
/* One row per group and one column per variable. */
118 kmeans->centers = gsl_matrix_alloc (ngroups, m);
119 kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
120 kmeans->ngroups = ngroups;
123 kmeans->maxiter = maxiter;
124 kmeans->lastiter = 0;
126 kmeans->variables = variables;
/* size1 is the row count of centers, which was allocated with ngroups rows. */
127 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
128 kmeans->original_casereader = cs;
/* Left NULL here; kmeans_randomize_centers allocates it when it finds NULL. */
129 kmeans->initial_centers = NULL;
/* Prototype with a single value of width 0; the cases created from it carry
   one group index each. */
131 kmeans->proto = caseproto_create ();
132 kmeans->proto = caseproto_add_width (kmeans->proto, 0);
/* No group indexes exist before the first assignment pass. */
133 kmeans->index_rdr = NULL;
/* Releases the GSL objects and the case prototype referenced by KMEANS.
   NOTE(review): the sentence about the reader and writer being already
   destroyed suggests that the two free calls after it belong to a comment
   whose delimiters are not visible in this listing.  If they are live code,
   releasing a casereader with free instead of casereader_destroy needs to be
   checked.  The release of KMEANS itself is not visible here either. */
139 kmeans_destroy (struct Kmeans *kmeans)
141 gsl_matrix_free (kmeans->centers);
142 gsl_matrix_free (kmeans->initial_centers);
144 gsl_vector_long_free (kmeans->num_elements_groups);
146 gsl_permutation_free (kmeans->group_order);
/* Drops the reference taken when the prototype was created. */
148 caseproto_unref (kmeans->proto);
151 These reader and writer were already destroyed.
152 free (kmeans->original_casereader);
153 free (kmeans->index_rdr);
162 Creates random centers using randomly selected cases from the data.
/* Sets the starting value of every element of kmeans->centers and, when
   kmeans->initial_centers is still NULL, saves a copy of those starting
   centers in it.
   NOTE(review): despite the name, the visible code draws no random numbers:
   the gsl_rng_uniform call is commented out and each element is set to 1 or
   to 0.  The condition choosing between the two is not visible in this
   listing. */
165 kmeans_randomize_centers (struct Kmeans *kmeans)
/* Visit every (group, variable) element of the centers matrix. */
168 for (i = 0; i < kmeans->ngroups; i++)
170 for (j = 0; j < kmeans->m; j++)
172 //gsl_matrix_set(kmeans->centers,i,j, gsl_rng_uniform (kmeans->rng));
175 gsl_matrix_set (kmeans->centers, i, j, 1);
179 gsl_matrix_set (kmeans->centers, i, j, 0);
184 If it is the first iteration, the variable kmeans->initial_centers is NULL and
185 it is created once for reporting issues. In SPSS, initial centers are shown in the reports
186 but in PSPP it is not shown now. I am leaving it here.
188 if (!kmeans->initial_centers)
/* Same dimensions as centers, so that gsl_matrix_memcpy can copy it. */
190 kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
191 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
/* For each group i, accumulates in dist the pow2 of the difference between
   the center of group i and the value of case C, over all m variables
   (pow2 presumably squares its argument, giving a squared Euclidean
   distance -- confirm).
   NOTE(review): the code that selects the smallest distance and the return
   statement are not visible in this listing; the caller stores the result as
   a group index. */
197 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
205 for (i = 0; i < kmeans->ngroups; i++)
208 for (j = 0; j < kmeans->m; j++)
/* Numeric value of the j-th clustering variable in this case. */
210 x = case_data (c, kmeans->variables[j])->f;
211 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
226 Re-calculates the cluster centers
/* Recomputes kmeans->centers from the current group assignment.
   First pass: reads the data cases and the index cases in step from clones of
   original_casereader and index_rdr, and adds each weighted value to the
   center element of the group recorded for that case.
   Second pass: walks every center element together with the case count of
   its group in order to turn the sums into averages.
   NOTE(review): the conditions around the weight lookup and around the final
   gsl_matrix_set, and the division itself, are not visible in this listing. */
229 kmeans_recalculate_centers (struct Kmeans *kmeans)
235 struct ccase *c_index;
236 struct casereader *cs;
237 struct casereader *cs_index;
/* Work on clones so that the readers stored in KMEANS are not consumed. */
242 cs = casereader_clone (kmeans->original_casereader);
243 cs_index = casereader_clone (kmeans->index_rdr);
/* Start from zero because the loop below accumulates sums. */
245 gsl_matrix_set_all (kmeans->centers, 0.0);
246 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
/* The index reader is read once per data case, so both stay aligned. */
248 c_index = casereader_read (cs_index);
249 index = case_data_idx (c_index, 0)->f;
250 for (v = 0; v < kmeans->m; ++v)
254 weight = case_data (c, kmeans->wv)->f;
260 x = case_data (c, kmeans->variables[v])->f * weight;
261 curval = gsl_matrix_get (kmeans->centers, index, v);
262 gsl_matrix_set (kmeans->centers, index, v, curval + x);
265 case_unref (c_index);
267 casereader_destroy (cs);
268 casereader_destroy (cs_index);
270 /* Getting number of cases */
274 //We got sum of each center but we need averages.
275 //We are dividing centers to numobs. This may be inefficient and
276 //we should check it again.
277 for (i = 0; i < kmeans->ngroups; i++)
/* Weighted case count of group i, as accumulated by the assignment pass. */
279 casenumber numobs = kmeans->num_elements_groups->data[i];
280 for (j = 0; j < kmeans->m; j++)
284 double *x = gsl_matrix_ptr (kmeans->centers, i, j);
/* NOTE(review): presumably the branch taken when numobs is 0 -- confirm. */
289 gsl_matrix_set (kmeans->centers, i, j, 0);
297 The index casereader in struct Kmeans holds integer values that represent the current group of each case.
298 index[n] = a means that the nth case belongs to the ath cluster.
299 This function calculates these indexes and returns the total difference between the new and old
300 index values. If the last two sets of indexes are equal, the clustering has not improved any further.
/* Assigns every case to its nearest group and records the assignment.
   For each case read from a clone of original_casereader it finds the nearest
   group, adds the case weight to the count of that group, adds to totaldiff
   the difference between the old and the new group index (or the new index
   itself when no old indexes exist), and writes the new index to a
   casewriter.  Finally the old index reader is destroyed and replaced by a
   reader made from that casewriter.
   NOTE(review): the return statement is not visible in this listing; the
   caller stores the result in a variable named diffs.
   NOTE(review): abs is applied to an expression of type double, so the value
   is converted to int first -- confirm that this is intended. */
303 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
308 struct casereader *cs = casereader_clone (kmeans->original_casereader);
311 /* A casewriter into which we will write the indexes */
312 struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
/* Counts are rebuilt from scratch on every pass. */
314 gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
316 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
318 /* A case to hold the new index */
319 struct ccase *index_case_new = case_create (kmeans->proto);
320 int bestindex = kmeans_get_nearest_group (kmeans, c);
/* The weight is truncated to casenumber before it is added to the count. */
323 weight = (casenumber) case_data (c, kmeans->wv)->f;
329 kmeans->num_elements_groups->data[bestindex] += weight;
330 if (kmeans->index_rdr)
332 /* A case from which the old index will be read */
333 struct ccase *index_case_old = NULL;
335 /* Read the case from the index casereader */
336 index_case_old = casereader_read (kmeans->index_rdr);
338 /* Set totaldiff, using the old_index */
339 totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
341 /* We have no use for the old case anymore, so unref it */
342 case_unref (index_case_old);
346 /* If this is the first run, then assume index is zero */
347 totaldiff += bestindex;
350 /* Set the value of the new index */
351 case_data_rw_idx (index_case_new, 0)->f = bestindex;
353 /* and write the new index to the casewriter */
354 casewriter_write (index_wtr, index_case_new);
356 casereader_destroy (cs);
357 /* We have now read through the entire index_rdr, so it's
359 casereader_destroy (kmeans->index_rdr);
361 /* Convert the writer into a reader, ready for the next iteration to read */
362 kmeans->index_rdr = casewriter_make_reader (index_wtr);
/* Fills kmeans->group_order with the permutation that sorts the groups by
   the first column of kmeans->centers, that is, by the center value of the
   first variable.
   NOTE(review): no gsl_vector_free call for v is visible in this listing --
   confirm that the temporary vector is released. */
369 kmeans_order_groups (struct Kmeans *kmeans)
371 gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
/* Copy column 0 of the centers matrix into v. */
372 gsl_matrix_get_col (v, kmeans->centers, 0);
/* Stores the sorting permutation in group_order; v itself is not reordered. */
373 gsl_sort_vector_index (kmeans->group_order, v);
378 Iterates the clustering and checks for convergence.
/* Runs the clustering: sets the starting centers, then for at most maxiter
   iterations reassigns the cases to groups and recomputes the centers.
   A warning is issued at most once when there are more groups than cases.
   After the loop the groups are scanned for empty ones.
   NOTE(review): the loop increment, the convergence test on diffs, the
   warning call and the action taken for empty groups or when trials reaches
   3 are not visible in this listing. */
381 kmeans_cluster (struct Kmeans *kmeans)
388 show_warning1 = true;
391 kmeans_randomize_centers (kmeans);
392 for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
/* Assignment step first, then the centers are updated from it. */
395 diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
396 kmeans_recalculate_centers (kmeans);
397 if (show_warning1 && kmeans->ngroups > kmeans->n)
401 ("Number of clusters may not be larger than the number of cases."));
/* Clearing the flag keeps the warning from being repeated. */
402 show_warning1 = false;
/* Look for groups that ended up with a zero case count. */
408 for (i = 0; i < kmeans->ngroups; i++)
410 if (kmeans->num_elements_groups->data[i] == 0)
413 if (kmeans->trials >= 3)
426 Reports the centers of the clusters.
427 The 'initial' parameter is provided for future use.
428 If 'initial' is true, the initial cluster centers are reported. Otherwise, the resulting centers are reported.
/* Builds the table of cluster centers.  Column 0 holds the variable names
   and columns 1 to ngroups hold one cluster each, in the order given by
   kmeans->group_order.  The title and the matrix that is printed (centers or
   initial_centers) depend on INITIAL.
   NOTE(review): the assignments to nr and currow, the tests on INITIAL and
   the call that submits the table are not visible in this listing.
   NOTE(review): the variable names are written at row currow + i but the
   values at the fixed row j + 4 -- confirm that currow is 4 at that point. */
431 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
434 int nc, nr, heading_columns, currow;
/* One column per cluster plus one for the variable names. */
436 nc = kmeans->ngroups + 1;
439 t = tab_create (nc, nr);
440 tab_headers (t, 0, nc - 1, 0, 1);
444 tab_title (t, _("Final Cluster Centers"));
448 tab_title (t, _("Initial Cluster Centers"));
450 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
451 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
452 tab_hline (t, TAL_1, 1, nc - 1, 2);
/* Header row with the cluster numbers, counted from 1. */
455 for (i = 0; i < kmeans->ngroups; i++)
457 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
460 tab_hline (t, TAL_1, 1, nc - 1, currow);
/* First column: one row per clustering variable. */
462 for (i = 0; i < kmeans->m; i++)
464 tab_text (t, 0, currow + i, TAB_LEFT,
465 var_to_string (kmeans->variables[i]));
/* Body: group_order maps the displayed column i to a row of the matrix, and
   each value uses the print format of its variable. */
468 for (i = 0; i < kmeans->ngroups; i++)
470 for (j = 0; j < kmeans->m; j++)
474 tab_double (t, i + 1, j + 4, TAB_CENTER,
475 gsl_matrix_get (kmeans->centers,
476 kmeans->group_order->data[i], j),
477 var_get_print_format (kmeans->variables[j]));
481 tab_double (t, i + 1, j + 4, TAB_CENTER,
482 gsl_matrix_get (kmeans->initial_centers,
483 kmeans->group_order->data[i], j),
484 var_get_print_format (kmeans->variables[j]));
493 Reports the number of cases in each cluster.
/* Builds the table with the number of cases in each cluster: one row per
   cluster, in the order given by kmeans->group_order, followed by a row
   labelled Valid that shows the total.
   NOTE(review): the assignment to nc, the declarations of numelem and total,
   the accumulation of total and the call that submits the table are not
   visible in this listing.
   NOTE(review): the per-cluster count is printed with %d and the total with
   %ld, while the counts come from a vector of long -- confirm that the
   declared type of numelem matches its format. */
496 quick_cluster_show_number_cases (struct Kmeans *kmeans)
/* One row per cluster plus the row for the total. */
503 nr = kmeans->ngroups + 1;
504 t = tab_create (nc, nr);
505 tab_headers (t, 0, nc - 1, 0, 0);
506 tab_title (t, _("Number of Cases in each Cluster"));
507 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
508 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
511 for (i = 0; i < kmeans->ngroups; i++)
/* Column 1 shows the cluster number counted from 1, column 2 its count. */
513 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
515 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
516 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
/* The last row, at index ngroups, holds the total. */
520 tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
521 tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
/* Produces the output of the command: computes the reporting order of the
   groups, then shows the final centers and the number of cases per cluster.
   The ordering has to come first because both tables read
   kmeans->group_order. */
529 quick_cluster_show_results (struct Kmeans *kmeans)
531 kmeans_order_groups (kmeans);
532 //uncomment the line below for reporting initial centers
533 //quick_cluster_show_centers (kmeans, true);
534 quick_cluster_show_centers (kmeans, false);
535 quick_cluster_show_number_cases (kmeans);
/* Parses and runs the command.  The visible syntax is a list of numeric
   variables without duplicates, optionally followed by a slash and CRITERIA
   (with an optional equals sign) containing CLUSTERS (integer) and/or
   MXITER (integer).  After parsing, it creates the Kmeans state, sets the
   weighting variable of the dictionary, runs the clustering, shows the
   results, commits the procedure and destroys the state.
   Returns CMD_FAILURE on the visible error paths.
   NOTE(review): the declarations and default values of groups, maxiter, p
   and ok, the creation of cs and the final return are not visible in this
   listing.
   NOTE(review): the result of lex_force_int is ignored before lex_integer is
   called, both for CLUSTERS and for MXITER -- confirm what lex_integer does
   when the current token is not an integer. */
540 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
542 struct Kmeans *kmeans;
544 const struct dictionary *dict = dataset_dict (ds);
545 const struct variable **variables;
546 struct casereader *cs;
/* On success, variables points to the parsed array and p holds its length;
   p is later passed to kmeans_create as the number of variables. */
553 if (!parse_variables_const (lexer, dict, &variables, &p,
554 PV_NO_DUPLICATE | PV_NUMERIC))
556 msg (ME, _("Variables cannot be parsed"));
557 return (CMD_FAILURE);
562 if (lex_match (lexer, T_SLASH))
564 if (lex_match_id (lexer, "CRITERIA"))
566 lex_match (lexer, T_EQUALS);
/* Consume settings until the end of the command or the next slash. */
567 while (lex_token (lexer) != T_ENDCMD
568 && lex_token (lexer) != T_SLASH)
570 if (lex_match_id (lexer, "CLUSTERS"))
572 if (lex_force_match (lexer, T_LPAREN))
574 lex_force_int (lexer);
575 groups = lex_integer (lexer);
577 lex_force_match (lexer, T_RPAREN);
580 else if (lex_match_id (lexer, "MXITER"))
582 if (lex_force_match (lexer, T_LPAREN))
584 lex_force_int (lexer);
585 maxiter = lex_integer (lexer);
587 lex_force_match (lexer, T_RPAREN);
592 //further command set
/* NOTE(review): presumably the branch for an unrecognized setting -- the
   enclosing else is not visible here. */
593 return (CMD_FAILURE);
603 kmeans = kmeans_create (cs, variables, p, groups, maxiter);
/* dict_get_weight result is stored as the weighting variable. */
605 kmeans->wv = dict_get_weight (dict);
606 kmeans_cluster (kmeans);
607 quick_cluster_show_results (kmeans);
608 ok = proc_commit (ds);
610 kmeans_destroy (kmeans);