1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
/* NOTE(review): gnulib-based builds normally require <config.h> before any
   system header; confirm it is included ahead of this block. */
#include <math.h>
#include <stdlib.h>

#include <libpspp/message.h>
#include <libpspp/misc.h>
#include <libpspp/str.h>

#include <data/case.h>
#include <data/casegrouper.h>
#include <data/casereader.h>
#include <data/casewriter.h>
#include <data/dataset.h>
#include <data/dictionary.h>
#include <data/format.h>
#include <data/missing-values.h>

#include <language/command.h>
#include <language/lexer/lexer.h>
#include <language/lexer/variable-parser.h>

#include <output/tab.h>
#include <output/text-item.h>

#include <gsl/gsl_matrix.h>
#include <gsl/gsl_permutation.h>
#include <gsl/gsl_sort_vector.h>
#include <gsl/gsl_statistics.h>

#include <math/random.h>

#define _(msgid) gettext (msgid)
#define N_(msgid) msgid

#include "quick-cluster.h"
61 Holds all of the information for the functions.
62 int n, holds the number of observation and its default value is -1.
63 We set it in kmeans_recalculate_centers in first invocation.
67 gsl_matrix *centers; //Centers for groups
68 gsl_vector_long *num_elements_groups;
69 int ngroups; //Number of group. (Given by the user)
70 casenumber n; //Number of observations. By default it is -1.
71 int m; //Number of variables. (Given by the user)
72 int maxiter; //Maximum number of iterations (Given by the user)
73 int lastiter; //Show at which iteration it found the solution.
74 int trials; //If not convergence, how many times has clustering done.
75 gsl_matrix *initial_centers; //Initial random centers
76 const struct variable **variables; //Variables
77 gsl_permutation *group_order; //Handles group order for reporting
78 struct casereader *original_casereader; //Casereader
79 struct caseproto *proto;
80 struct casereader *index_rdr; //We hold the group id's for each case in this structure
81 const struct variable *wv; //Weighting variable
86 Creates and returns a struct of Kmeans with given casereader 'cs', parsed variables 'variables',
87 number of cases 'n', number of variables 'm', number of clusters and amount of maximum iterations.
89 static struct Kmeans *
90 kmeans_create (struct casereader *cs, const struct variable **variables,
91 int m, int ngroups, int maxiter)
93 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
94 kmeans->centers = gsl_matrix_alloc (ngroups, m);
95 kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
96 kmeans->ngroups = ngroups;
99 kmeans->maxiter = maxiter;
100 kmeans->lastiter = 0;
102 kmeans->variables = variables;
103 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
104 kmeans->original_casereader = cs;
105 kmeans->initial_centers = NULL;
107 kmeans->proto = caseproto_create ();
108 kmeans->proto = caseproto_add_width (kmeans->proto, 0);
109 kmeans->index_rdr = NULL;
115 kmeans_destroy (struct Kmeans *kmeans)
117 gsl_matrix_free (kmeans->centers);
118 gsl_matrix_free (kmeans->initial_centers);
120 gsl_vector_long_free (kmeans->num_elements_groups);
122 gsl_permutation_free (kmeans->group_order);
124 caseproto_unref (kmeans->proto);
127 These reader and writer were already destroyed.
128 free (kmeans->original_casereader);
129 free (kmeans->index_rdr);
138 Creates random centers using randomly selected cases from the data.
141 kmeans_randomize_centers (struct Kmeans *kmeans)
144 for (i = 0; i < kmeans->ngroups; i++)
146 for (j = 0; j < kmeans->m; j++)
148 //gsl_matrix_set(kmeans->centers,i,j, gsl_rng_uniform (kmeans->rng));
151 gsl_matrix_set (kmeans->centers, i, j, 1);
155 gsl_matrix_set (kmeans->centers, i, j, 0);
160 If it is the first iteration, the variable kmeans->initial_centers is NULL and
161 it is created once for reporting issues. In SPSS, initial centers are shown in the reports
162 but in PSPP it is not shown now. I am leaving it here.
164 if (!kmeans->initial_centers)
166 kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
167 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
173 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
181 for (i = 0; i < kmeans->ngroups; i++)
184 for (j = 0; j < kmeans->m; j++)
186 x = case_data (c, kmeans->variables[j])->f;
187 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
202 Re-calculates the cluster centers
205 kmeans_recalculate_centers (struct Kmeans *kmeans)
211 struct ccase *c_index;
212 struct casereader *cs;
213 struct casereader *cs_index;
218 cs = casereader_clone (kmeans->original_casereader);
219 cs_index = casereader_clone (kmeans->index_rdr);
221 gsl_matrix_set_all (kmeans->centers, 0.0);
222 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
224 c_index = casereader_read (cs_index);
225 index = case_data_idx (c_index, 0)->f;
226 for (v = 0; v < kmeans->m; ++v)
230 weight = case_data (c, kmeans->wv)->f;
236 x = case_data (c, kmeans->variables[v])->f * weight;
237 curval = gsl_matrix_get (kmeans->centers, index, v);
238 gsl_matrix_set (kmeans->centers, index, v, curval + x);
241 case_unref (c_index);
243 casereader_destroy (cs);
244 casereader_destroy (cs_index);
246 /* Getting number of cases */
250 //We got sum of each center but we need averages.
251 //We are dividing centers to numobs. This may be inefficient and
252 //we should check it again.
253 for (i = 0; i < kmeans->ngroups; i++)
255 casenumber numobs = kmeans->num_elements_groups->data[i];
256 for (j = 0; j < kmeans->m; j++)
260 double *x = gsl_matrix_ptr (kmeans->centers, i, j);
265 gsl_matrix_set (kmeans->centers, i, j, 0);
273 The variable index in struct Kmeans holds integer values that represents the current groups of cases.
274 index[n]=a shows the nth case is belong to ath cluster.
275 This function calculates these indexes and returns the number of different cases of the new and old
276 index variables. If last two index variables are equal, there is no any enhancement of clustering.
279 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
284 struct casereader *cs = casereader_clone (kmeans->original_casereader);
287 /* A casewriter into which we will write the indexes */
288 struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
290 gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
292 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
294 /* A case to hold the new index */
295 struct ccase *index_case_new = case_create (kmeans->proto);
296 int bestindex = kmeans_get_nearest_group (kmeans, c);
299 weight = (casenumber) case_data (c, kmeans->wv)->f;
305 kmeans->num_elements_groups->data[bestindex] += weight;
306 if (kmeans->index_rdr)
308 /* A case from which the old index will be read */
309 struct ccase *index_case_old = NULL;
311 /* Read the case from the index casereader */
312 index_case_old = casereader_read (kmeans->index_rdr);
314 /* Set totaldiff, using the old_index */
315 totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
317 /* We have no use for the old case anymore, so unref it */
318 case_unref (index_case_old);
322 /* If this is the first run, then assume index is zero */
323 totaldiff += bestindex;
326 /* Set the value of the new index */
327 case_data_rw_idx (index_case_new, 0)->f = bestindex;
329 /* and write the new index to the casewriter */
330 casewriter_write (index_wtr, index_case_new);
332 casereader_destroy (cs);
333 /* We have now read through the entire index_rdr, so it's
335 casereader_destroy (kmeans->index_rdr);
337 /* Convert the writer into a reader, ready for the next iteration to read */
338 kmeans->index_rdr = casewriter_make_reader (index_wtr);
345 kmeans_order_groups (struct Kmeans *kmeans)
347 gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
348 gsl_matrix_get_col (v, kmeans->centers, 0);
349 gsl_sort_vector_index (kmeans->group_order, v);
354 Does iterations, checks convergency
357 kmeans_cluster (struct Kmeans *kmeans)
364 show_warning1 = true;
367 kmeans_randomize_centers (kmeans);
368 for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
371 diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
372 kmeans_recalculate_centers (kmeans);
373 if (show_warning1 && kmeans->ngroups > kmeans->n)
377 ("Number of clusters may not be larger than the number of cases."));
378 show_warning1 = false;
384 for (i = 0; i < kmeans->ngroups; i++)
386 if (kmeans->num_elements_groups->data[i] == 0)
389 if (kmeans->trials >= 3)
402 Reports centers of clusters.
403 initial parameter is optional for future use.
404 if initial is true, initial cluster centers are reported. Otherwise, resulted centers are reported.
407 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
410 int nc, nr, heading_columns, currow;
412 nc = kmeans->ngroups + 1;
415 t = tab_create (nc, nr);
416 tab_headers (t, 0, nc - 1, 0, 1);
420 tab_title (t, _("Final Cluster Centers"));
424 tab_title (t, _("Initial Cluster Centers"));
426 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
427 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
428 tab_hline (t, TAL_1, 1, nc - 1, 2);
431 for (i = 0; i < kmeans->ngroups; i++)
433 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
436 tab_hline (t, TAL_1, 1, nc - 1, currow);
438 for (i = 0; i < kmeans->m; i++)
440 tab_text (t, 0, currow + i, TAB_LEFT,
441 var_to_string (kmeans->variables[i]));
444 for (i = 0; i < kmeans->ngroups; i++)
446 for (j = 0; j < kmeans->m; j++)
450 tab_double (t, i + 1, j + 4, TAB_CENTER,
451 gsl_matrix_get (kmeans->centers,
452 kmeans->group_order->data[i], j),
453 var_get_print_format (kmeans->variables[j]));
457 tab_double (t, i + 1, j + 4, TAB_CENTER,
458 gsl_matrix_get (kmeans->initial_centers,
459 kmeans->group_order->data[i], j),
460 var_get_print_format (kmeans->variables[j]));
469 Reports number of cases of each single cluster.
472 quick_cluster_show_number_cases (struct Kmeans *kmeans)
479 nr = kmeans->ngroups + 1;
480 t = tab_create (nc, nr);
481 tab_headers (t, 0, nc - 1, 0, 0);
482 tab_title (t, _("Number of Cases in each Cluster"));
483 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
484 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
487 for (i = 0; i < kmeans->ngroups; i++)
489 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
491 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
492 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
496 tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
497 tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
/* Produces the output for a finished clustering run: first puts the groups
   into reporting order, then shows the final cluster centers and the number
   of cases in each cluster. */
static void
quick_cluster_show_results (struct Kmeans *kmeans)
{
  kmeans_order_groups (kmeans);

  /* To report the initial centers as well, call
     quick_cluster_show_centers (kmeans, true) at this point. */
  quick_cluster_show_centers (kmeans, false);
  quick_cluster_show_number_cases (kmeans);
}
516 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
518 struct Kmeans *kmeans;
520 const struct dictionary *dict = dataset_dict (ds);
521 const struct variable **variables;
522 struct casereader *cs;
529 if (!parse_variables_const (lexer, dict, &variables, &p,
530 PV_NO_DUPLICATE | PV_NUMERIC))
532 msg (ME, _("Variables cannot be parsed"));
533 return (CMD_FAILURE);
538 if (lex_match (lexer, T_SLASH))
540 if (lex_match_id (lexer, "CRITERIA"))
542 lex_match (lexer, T_EQUALS);
543 while (lex_token (lexer) != T_ENDCMD
544 && lex_token (lexer) != T_SLASH)
546 if (lex_match_id (lexer, "CLUSTERS"))
548 if (lex_force_match (lexer, T_LPAREN))
550 lex_force_int (lexer);
551 groups = lex_integer (lexer);
553 lex_force_match (lexer, T_RPAREN);
556 else if (lex_match_id (lexer, "MXITER"))
558 if (lex_force_match (lexer, T_LPAREN))
560 lex_force_int (lexer);
561 maxiter = lex_integer (lexer);
563 lex_force_match (lexer, T_RPAREN);
568 //further command set
569 return (CMD_FAILURE);
579 kmeans = kmeans_create (cs, variables, p, groups, maxiter);
581 kmeans->wv = dict_get_weight (dict);
582 kmeans_cluster (kmeans);
583 quick_cluster_show_results (kmeans);
584 ok = proc_commit (ds);
586 kmeans_destroy (kmeans);