1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_permutation.h>
#include <gsl/gsl_sort_vector.h>
#include <gsl/gsl_statistics.h>
#include <stdbool.h>
#include <stdlib.h>

#include "data/case.h"
#include "data/casegrouper.h"
#include "data/casereader.h"
#include "data/casewriter.h"
#include "data/dataset.h"
#include "data/dictionary.h"
#include "data/format.h"
#include "data/missing-values.h"
#include "language/command.h"
#include "language/lexer/lexer.h"
#include "language/lexer/variable-parser.h"
#include "libpspp/message.h"
#include "libpspp/misc.h"
#include "libpspp/str.h"
#include "math/random.h"
#include "output/tab.h"
#include "output/text-item.h"
/* Translation helpers.  _() wraps its argument in a gettext() call;
   N_() expands to its argument unchanged (a marker only). */
#define _(MSGID) gettext (MSGID)
#define N_(MSGID) MSGID
51 Holds all of the information for the functions.
52 int n, holds the number of observation and its default value is -1.
53 We set it in kmeans_recalculate_centers in first invocation.
57 gsl_matrix *centers; //Centers for groups
58 gsl_vector_long *num_elements_groups;
59 int ngroups; //Number of group. (Given by the user)
60 casenumber n; //Number of observations. By default it is -1.
61 int m; //Number of variables. (Given by the user)
62 int maxiter; //Maximum number of iterations (Given by the user)
63 int lastiter; //Show at which iteration it found the solution.
64 int trials; //If not convergence, how many times has clustering done.
65 gsl_matrix *initial_centers; //Initial random centers
66 const struct variable **variables; //Variables
67 gsl_permutation *group_order; //Handles group order for reporting
68 struct casereader *original_casereader; //Casereader
69 struct caseproto *proto;
70 struct casereader *index_rdr; //We hold the group id's for each case in this structure
71 const struct variable *wv; //Weighting variable
/* Creation and destruction. */
static struct Kmeans *kmeans_create (struct casereader *cs,
                                     const struct variable **variables,
                                     int m, int ngroups, int maxiter);
static void kmeans_destroy (struct Kmeans *kmeans);

/* The k-means algorithm itself. */
static void kmeans_randomize_centers (struct Kmeans *kmeans);
static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
static void kmeans_recalculate_centers (struct Kmeans *kmeans);
static int
kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
static void kmeans_order_groups (struct Kmeans *kmeans);
static void kmeans_cluster (struct Kmeans *kmeans);

/* Reporting. */
static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
static void quick_cluster_show_results (struct Kmeans *kmeans);

/* Command entry point. */
int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
102 Creates and returns a struct of Kmeans with given casereader 'cs', parsed variables 'variables',
103 number of cases 'n', number of variables 'm', number of clusters and amount of maximum iterations.
105 static struct Kmeans *
106 kmeans_create (struct casereader *cs, const struct variable **variables,
107 int m, int ngroups, int maxiter)
109 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
110 kmeans->centers = gsl_matrix_alloc (ngroups, m);
111 kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
112 kmeans->ngroups = ngroups;
115 kmeans->maxiter = maxiter;
116 kmeans->lastiter = 0;
118 kmeans->variables = variables;
119 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
120 kmeans->original_casereader = cs;
121 kmeans->initial_centers = NULL;
123 kmeans->proto = caseproto_create ();
124 kmeans->proto = caseproto_add_width (kmeans->proto, 0);
125 kmeans->index_rdr = NULL;
131 kmeans_destroy (struct Kmeans *kmeans)
133 gsl_matrix_free (kmeans->centers);
134 gsl_matrix_free (kmeans->initial_centers);
136 gsl_vector_long_free (kmeans->num_elements_groups);
138 gsl_permutation_free (kmeans->group_order);
140 caseproto_unref (kmeans->proto);
143 These reader and writer were already destroyed.
144 free (kmeans->original_casereader);
145 free (kmeans->index_rdr);
154 Creates random centers using randomly selected cases from the data.
157 kmeans_randomize_centers (struct Kmeans *kmeans)
160 for (i = 0; i < kmeans->ngroups; i++)
162 for (j = 0; j < kmeans->m; j++)
164 //gsl_matrix_set(kmeans->centers,i,j, gsl_rng_uniform (kmeans->rng));
167 gsl_matrix_set (kmeans->centers, i, j, 1);
171 gsl_matrix_set (kmeans->centers, i, j, 0);
176 If it is the first iteration, the variable kmeans->initial_centers is NULL and
177 it is created once for reporting issues. In SPSS, initial centers are shown in the reports
178 but in PSPP it is not shown now. I am leaving it here.
180 if (!kmeans->initial_centers)
182 kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
183 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
189 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
197 for (i = 0; i < kmeans->ngroups; i++)
200 for (j = 0; j < kmeans->m; j++)
202 x = case_data (c, kmeans->variables[j])->f;
203 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
218 Re-calculates the cluster centers
221 kmeans_recalculate_centers (struct Kmeans *kmeans)
227 struct ccase *c_index;
228 struct casereader *cs;
229 struct casereader *cs_index;
234 cs = casereader_clone (kmeans->original_casereader);
235 cs_index = casereader_clone (kmeans->index_rdr);
237 gsl_matrix_set_all (kmeans->centers, 0.0);
238 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
240 c_index = casereader_read (cs_index);
241 index = case_data_idx (c_index, 0)->f;
242 for (v = 0; v < kmeans->m; ++v)
246 weight = case_data (c, kmeans->wv)->f;
252 x = case_data (c, kmeans->variables[v])->f * weight;
253 curval = gsl_matrix_get (kmeans->centers, index, v);
254 gsl_matrix_set (kmeans->centers, index, v, curval + x);
257 case_unref (c_index);
259 casereader_destroy (cs);
260 casereader_destroy (cs_index);
262 /* Getting number of cases */
266 //We got sum of each center but we need averages.
267 //We are dividing centers to numobs. This may be inefficient and
268 //we should check it again.
269 for (i = 0; i < kmeans->ngroups; i++)
271 casenumber numobs = kmeans->num_elements_groups->data[i];
272 for (j = 0; j < kmeans->m; j++)
276 double *x = gsl_matrix_ptr (kmeans->centers, i, j);
281 gsl_matrix_set (kmeans->centers, i, j, 0);
289 The variable index in struct Kmeans holds integer values that represents the current groups of cases.
290 index[n]=a shows the nth case is belong to ath cluster.
291 This function calculates these indexes and returns the number of different cases of the new and old
292 index variables. If last two index variables are equal, there is no any enhancement of clustering.
295 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
300 struct casereader *cs = casereader_clone (kmeans->original_casereader);
303 /* A casewriter into which we will write the indexes */
304 struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
306 gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
308 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
310 /* A case to hold the new index */
311 struct ccase *index_case_new = case_create (kmeans->proto);
312 int bestindex = kmeans_get_nearest_group (kmeans, c);
315 weight = (casenumber) case_data (c, kmeans->wv)->f;
321 kmeans->num_elements_groups->data[bestindex] += weight;
322 if (kmeans->index_rdr)
324 /* A case from which the old index will be read */
325 struct ccase *index_case_old = NULL;
327 /* Read the case from the index casereader */
328 index_case_old = casereader_read (kmeans->index_rdr);
330 /* Set totaldiff, using the old_index */
331 totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
333 /* We have no use for the old case anymore, so unref it */
334 case_unref (index_case_old);
338 /* If this is the first run, then assume index is zero */
339 totaldiff += bestindex;
342 /* Set the value of the new index */
343 case_data_rw_idx (index_case_new, 0)->f = bestindex;
345 /* and write the new index to the casewriter */
346 casewriter_write (index_wtr, index_case_new);
348 casereader_destroy (cs);
349 /* We have now read through the entire index_rdr, so it's
351 casereader_destroy (kmeans->index_rdr);
353 /* Convert the writer into a reader, ready for the next iteration to read */
354 kmeans->index_rdr = casewriter_make_reader (index_wtr);
361 kmeans_order_groups (struct Kmeans *kmeans)
363 gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
364 gsl_matrix_get_col (v, kmeans->centers, 0);
365 gsl_sort_vector_index (kmeans->group_order, v);
370 Does iterations, checks convergency
373 kmeans_cluster (struct Kmeans *kmeans)
380 show_warning1 = true;
383 kmeans_randomize_centers (kmeans);
384 for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
387 diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
388 kmeans_recalculate_centers (kmeans);
389 if (show_warning1 && kmeans->ngroups > kmeans->n)
393 ("Number of clusters may not be larger than the number of cases."));
394 show_warning1 = false;
400 for (i = 0; i < kmeans->ngroups; i++)
402 if (kmeans->num_elements_groups->data[i] == 0)
405 if (kmeans->trials >= 3)
/* Builds a table of cluster centers: one column per cluster after a
   leading label column, and one row per clustering variable.
   Column i + 1 is labelled i + 1 but shows the center of group
   kmeans->group_order->data[i], so the labels are report positions,
   not internal group numbers.  With INITIAL true, the values come from
   kmeans->initial_centers, which is only allocated by
   kmeans_randomize_centers(). */
418 Reports centers of clusters.
419 initial parameter is optional for future use.
420 if initial is true, initial cluster centers are reported. Otherwise, resulted centers are reported.
423 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
426 int nc, nr, heading_columns, currow;
428 nc = kmeans->ngroups + 1;
431 t = tab_create (nc, nr);
432 tab_headers (t, 0, nc - 1, 0, 1);
436 tab_title (t, _("Final Cluster Centers"));
440 tab_title (t, _("Initial Cluster Centers"));
442 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
443 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
444 tab_hline (t, TAL_1, 1, nc - 1, 2);
447 for (i = 0; i < kmeans->ngroups; i++)
449 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
452 tab_hline (t, TAL_1, 1, nc - 1, currow);
454 for (i = 0; i < kmeans->m; i++)
/* NOTE(review): variable names are placed at row currow + i, but the
   values below are placed at the fixed row j + 4.  These only line up
   if currow is 4 here; the updates to currow are not visible -- confirm. */
456 tab_text (t, 0, currow + i, TAB_LEFT,
457 var_to_string (kmeans->variables[i]));
460 for (i = 0; i < kmeans->ngroups; i++)
462 for (j = 0; j < kmeans->m; j++)
/* Final centers, reported in group_order order. */
466 tab_double (t, i + 1, j + 4, TAB_CENTER,
467 gsl_matrix_get (kmeans->centers,
468 kmeans->group_order->data[i], j),
469 var_get_print_format (kmeans->variables[j]));
/* Initial (starting) centers, reported in the same order. */
473 tab_double (t, i + 1, j + 4, TAB_CENTER,
474 gsl_matrix_get (kmeans->initial_centers,
475 kmeans->group_order->data[i], j),
476 var_get_print_format (kmeans->variables[j]));
/* Builds a table with one row per cluster plus a final "Valid" row.
   Column 1 holds the cluster label, which is the report position i + 1.
   Column 2 holds the count of group kmeans->group_order->data[i].
   The "Valid" row shows TOTAL.  The counts come from
   num_elements_groups, so they are weighted counts with fractional
   weights truncated. */
485 Reports number of cases of each single cluster.
488 quick_cluster_show_number_cases (struct Kmeans *kmeans)
495 nr = kmeans->ngroups + 1;
496 t = tab_create (nc, nr);
497 tab_headers (t, 0, nc - 1, 0, 0);
498 tab_title (t, _("Number of Cases in each Cluster"));
499 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
500 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
503 for (i = 0; i < kmeans->ngroups; i++)
505 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
507 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
/* NOTE(review): the declaration of numelem is not visible here.  Its
   source element is a long (gsl_vector_long), and TOTAL below is
   printed with "%ld".  If numelem is a long, "%d" is a mismatched
   format specifier -- confirm. */
508 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
512 tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
513 tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
521 quick_cluster_show_results (struct Kmeans *kmeans)
523 kmeans_order_groups (kmeans);
524 //uncomment the line above for reporting initial centers
525 //quick_cluster_show_centers (kmeans, true);
526 quick_cluster_show_centers (kmeans, false);
527 quick_cluster_show_number_cases (kmeans);
/* Implements the QUICK CLUSTER command.
   Parses a list of distinct numeric variables and an optional
   /CRITERIA subcommand with CLUSTERS(n) and MXITER(n).  Then runs
   k-means on the active dataset, weighted by the dictionary's weight
   variable, and reports the results.  Returns CMD_FAILURE if the
   variables cannot be parsed. */
532 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
534 struct Kmeans *kmeans;
536 const struct dictionary *dict = dataset_dict (ds);
537 const struct variable **variables;
538 struct casereader *cs;
545 if (!parse_variables_const (lexer, dict, &variables, &p,
546 PV_NO_DUPLICATE | PV_NUMERIC))
548 msg (ME, _("Variables cannot be parsed"));
549 return (CMD_FAILURE);
554 if (lex_match (lexer, T_SLASH))
556 if (lex_match_id (lexer, "CRITERIA"))
558 lex_match (lexer, T_EQUALS);
559 while (lex_token (lexer) != T_ENDCMD
560 && lex_token (lexer) != T_SLASH)
562 if (lex_match_id (lexer, "CLUSTERS"))
564 if (lex_force_match (lexer, T_LPAREN))
/* NOTE(review): the result of lex_force_int() is ignored here and for
   MXITER below, so lex_integer() is called even when the token is not
   an integer.  No range check on GROUPS or MAXITER (e.g. CLUSTERS(0))
   is visible in this chunk -- confirm. */
566 lex_force_int (lexer);
567 groups = lex_integer (lexer);
569 lex_force_match (lexer, T_RPAREN);
572 else if (lex_match_id (lexer, "MXITER"))
574 if (lex_force_match (lexer, T_LPAREN))
576 lex_force_int (lexer);
577 maxiter = lex_integer (lexer);
579 lex_force_match (lexer, T_RPAREN);
584 //further command set
585 return (CMD_FAILURE);
/* NOTE(review): CS is assigned in lines not visible here, and KMEANS
   keeps it as original_casereader.  No casereader_destroy (cs) or
   free (variables) is visible in this chunk -- confirm both are
   released. */
595 kmeans = kmeans_create (cs, variables, p, groups, maxiter);
597 kmeans->wv = dict_get_weight (dict);
598 kmeans_cluster (kmeans);
599 quick_cluster_show_results (kmeans);
600 ok = proc_commit (ds);
602 kmeans_destroy (kmeans);