1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
/* Translation helpers: _() expands to a gettext () call on its argument,
   while N_() expands to its argument unchanged. */
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
/* Settings for one run of the quick cluster command.  The rest of this file
   reaches these members through a "const struct qc *". */
/* Variables to cluster on; this file indexes them from 0 up to (but not
   including) n_vars. */
51 const struct variable **vars;
54 int ngroups; /* Number of groups (given by the user). */
55 int maxiter; /* Maximum number of iterations (given by the user). */
57 const struct variable *wv; /* Weighting variable; when null, cases are given weight 1.0. */
60 /* Holds all of the state shared by the k-means functions. The member n holds
61 the number of observations; its default value is -1, and it is set by
62 kmeans_recalculate_centers on its first invocation. */
65 gsl_matrix *centers; /* Centers for groups: one row per group, one column per variable. */
66 gsl_vector_long *num_elements_groups; /* Per-group case totals, accumulated while cases are assigned to groups. */
68 casenumber n; /* Number of observations (default -1). */
70 int lastiter; /* Iteration where it found the solution. */
/* NOTE(review): the comment on "trials" below has no closing delimiter on its
   own line in this copy, so it runs on into the following line; confirm the
   missing text against the upstream source. */
71 int trials; /* If not convergence, how many times has
73 gsl_matrix *initial_centers; /* Initial random centers. */
75 gsl_permutation *group_order; /* Group order for reporting. */
76 struct caseproto *proto; /* Prototype for the one-value cases that carry group ids. */
77 struct casereader *index_rdr; /* Group ids for each case. */
/* Forward declarations for the functions defined later in this file. */
80 static struct Kmeans *kmeans_create (const struct qc *qc);
82 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
84 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
86 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
89 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
91 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
93 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
/* Reporting helpers. */
95 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
97 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
99 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
/* Parses and runs the command. */
101 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
103 static void kmeans_destroy (struct Kmeans *kmeans);
105 /* Creates and returns a struct Kmeans sized from QC: the centers matrix has
106 qc->ngroups rows and qc->n_vars columns, and the per-group count vector and
107 the reporting permutation each have one entry per group. */
108 static struct Kmeans *
109 kmeans_create (const struct qc *qc)
111 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
112 kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
113 kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
115 kmeans->lastiter = 0;
/* size1 is the first dimension of the centers matrix, i.e. qc->ngroups. */
117 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
/* Allocated on demand by kmeans_randomize_centers. */
118 kmeans->initial_centers = NULL;
/* Cases built from this prototype hold a single value of width 0, which this
   file uses to store a group index. */
120 kmeans->proto = caseproto_create ();
121 kmeans->proto = caseproto_add_width (kmeans->proto, 0);
/* No group indexes exist until the first assignment pass creates them. */
122 kmeans->index_rdr = NULL;
/* Releases the resources held by KMEANS: both center matrices, the per-group
   count vector, the reporting permutation, the reference to the case
   prototype, and the index reader. */
127 kmeans_destroy (struct Kmeans *kmeans)
129 gsl_matrix_free (kmeans->centers);
130 gsl_matrix_free (kmeans->initial_centers);
132 gsl_vector_long_free (kmeans->num_elements_groups);
134 gsl_permutation_free (kmeans->group_order);
136 caseproto_unref (kmeans->proto);
138 casereader_destroy (kmeans->index_rdr);
143 /* Creates random centers using randomly selected cases from the data.
   NOTE(review): the assignments visible below store the constants 1 and 0 in
   the centers matrix rather than values read from cases; confirm whether the
   sentence above is still accurate.
   When kmeans->initial_centers is still null, the centers are also copied
   into a newly allocated initial_centers matrix of the same dimensions. */
145 kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
/* Visit every (group, variable) entry of the centers matrix. */
148 for (i = 0; i < qc->ngroups; i++)
150 for (j = 0; j < qc->n_vars; j++)
154 gsl_matrix_set (kmeans->centers, i, j, 1);
158 gsl_matrix_set (kmeans->centers, i, j, 0);
162 /* If it is the first iteration, the variable kmeans->initial_centers is NULL
163 and it is created once for reporting issues. In SPSS, initial centers are
164 shown in the reports but in PSPP it is not shown now. I am leaving it
166 if (!kmeans->initial_centers)
168 kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
169 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
/* Measures how far case C lies from each group's center.  For every group the
   loops below add up pow2 () of the difference between the center's value
   and C's value over all of the clustering variables (presumably a squared
   Euclidean distance; pow2 is defined outside this file).  The caller in
   kmeans_calculate_indexes_and_check_convergence uses the return value as the
   index of the best group for C. */
174 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
182 for (i = 0; i < qc->ngroups; i++)
185 for (j = 0; j < qc->n_vars; j++)
187 x = case_data (c, qc->vars[j])->f;
188 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
199 /* Re-calculate the cluster centers.  Each center entry is first rebuilt as
   the weighted sum of the values of the cases currently assigned to that
   group, using the group ids read from kmeans->index_rdr; the comments
   further down describe turning those sums into averages. */
201 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
207 struct ccase *c_index;
208 struct casereader *cs;
209 struct casereader *cs_index;
/* Read through clones (destroyed below) rather than through READER and
   kmeans->index_rdr themselves. */
213 cs = casereader_clone (reader);
214 cs_index = casereader_clone (kmeans->index_rdr);
/* Start every sum from zero. */
216 gsl_matrix_set_all (kmeans->centers, 0.0);
217 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
219 double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
/* The two readers are stepped together: the index case read here supplies
   the group id, stored in value 0, for the data case just read. */
220 c_index = casereader_read (cs_index);
221 index = case_data_idx (c_index, 0)->f;
222 for (v = 0; v < qc->n_vars; ++v)
/* Add the weighted value to the running sum for this group and variable. */
224 x = case_data (c, qc->vars[v])->f * weight;
225 curval = gsl_matrix_get (kmeans->centers, index, v);
226 gsl_matrix_set (kmeans->centers, index, v, curval + x);
229 case_unref (c_index);
231 casereader_destroy (cs);
232 casereader_destroy (cs_index);
234 /* Getting number of cases */
238 /* We got sum of each center but we need averages.
239 We are dividing centers by numobs. This may be inefficient and
240 we should check it again. */
241 for (i = 0; i < qc->ngroups; i++)
243 casenumber numobs = kmeans->num_elements_groups->data[i];
244 for (j = 0; j < qc->n_vars; j++)
248 double *x = gsl_matrix_ptr (kmeans->centers, i, j);
/* NOTE(review): the condition guarding this reset to 0 is not visible in
   this copy; presumably it covers groups that have no cases -- confirm. */
253 gsl_matrix_set (kmeans->centers, i, j, 0);
259 /* The index reader in struct Kmeans holds integer values that represent the
260 current groups of cases: index[n]=a means that the nth case belongs to the ath
261 cluster. This function calculates these indexes and accumulates, in totaldiff,
262 the absolute differences between the new and old index of every case. If the
263 last two sets of indexes are equal, the clustering has not changed. */
265 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
/* Read through a clone (destroyed below) rather than through READER itself. */
269 struct casereader *cs = casereader_clone (reader);
271 /* A casewriter into which we will write the indexes. */
272 struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
/* The per-group totals are rebuilt from scratch on every call. */
274 gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
276 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
278 /* A case to hold the new index. */
279 struct ccase *index_case_new = case_create (kmeans->proto);
280 int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
281 double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
/* NOTE(review): num_elements_groups is a gsl_vector_long, so adding the
   double WEIGHT here converts the result back to long on every case;
   confirm this is intended for fractional weights. */
282 kmeans->num_elements_groups->data[bestindex] += weight;
283 if (kmeans->index_rdr)
285 /* A case from which the old index will be read. */
286 struct ccase *index_case_old = NULL;
288 /* Read the case from the index casereader. */
289 index_case_old = casereader_read (kmeans->index_rdr);
291 /* Set totaldiff, using the old_index. */
292 totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
294 /* We have no use for the old case anymore, so unref it. */
295 case_unref (index_case_old);
299 /* If this is the first run, then assume index is zero. */
300 totaldiff += bestindex;
303 /* Set the value of the new index. */
304 case_data_rw_idx (index_case_new, 0)->f = bestindex;
306 /* and write the new index to the casewriter */
307 casewriter_write (index_wtr, index_case_new);
309 casereader_destroy (cs);
310 /* We have now read through the entire index_rdr, so it's of no use
312 casereader_destroy (kmeans->index_rdr);
314 /* Convert the writer into a reader, ready for the next iteration to read */
315 kmeans->index_rdr = casewriter_make_reader (index_wtr);
/* Chooses the order in which groups are reported.  Column 0 of the centers
   matrix (each group's center for the first clustering variable) is copied
   into a temporary vector, and the permutation that sorts that vector is
   stored in kmeans->group_order. */
321 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
323 gsl_vector *v = gsl_vector_alloc (qc->ngroups);
324 gsl_matrix_get_col (v, kmeans->centers, 0);
325 gsl_sort_vector_index (kmeans->group_order, v);
330 /* Does iterations, checks convergence. */
/* Runs the clustering for the cases in READER: sets up the starting centers,
   then for up to qc->maxiter iterations reassigns every case to a group and
   recalculates the centers. */
332 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
/* Lets the warning below be issued at most once. */
339 show_warning1 = true;
342 kmeans_randomize_centers (kmeans, qc);
343 for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
/* DIFFS receives the result of the reassignment pass; how it is used is not
   visible in this copy (presumably it ends the loop when nothing changed). */
346 diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
347 kmeans_recalculate_centers (kmeans, reader, qc);
/* Warn when more groups were requested than kmeans->n. */
348 if (show_warning1 && qc->ngroups > kmeans->n)
350 msg (MW, _("Number of clusters may not be larger than the number "
352 show_warning1 = false;
/* Look for groups that ended up with no cases.  NOTE(review): what happens
   for an empty group, and once kmeans->trials reaches 3, is not visible in
   this copy. */
358 for (i = 0; i < qc->ngroups; i++)
360 if (kmeans->num_elements_groups->data[i] == 0)
363 if (kmeans->trials >= 3)
374 /* Reports centers of clusters.
375 The INITIAL parameter is there for future use.
376 If INITIAL is true, the initial cluster centers are reported. Otherwise,
377 the final centers are reported. */
379 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
382 int nc, nr, heading_columns, currow;
/* One column for the variable names plus one column per group. */
384 nc = qc->ngroups + 1;
387 t = tab_create (nc, nr);
388 tab_headers (t, 0, nc - 1, 0, 1);
/* Only one of the two titles applies; the selecting condition is not visible
   in this copy (presumably it tests INITIAL). */
392 tab_title (t, _("Final Cluster Centers"));
396 tab_title (t, _("Initial Cluster Centers"));
398 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
399 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
400 tab_hline (t, TAL_1, 1, nc - 1, 2);
/* Column headings: the groups are numbered from 1. */
403 for (i = 0; i < qc->ngroups; i++)
405 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
408 tab_hline (t, TAL_1, 1, nc - 1, currow);
/* Row labels: one row per clustering variable, in column 0. */
410 for (i = 0; i < qc->n_vars; i++)
412 tab_text (t, 0, currow + i, TAB_LEFT,
413 var_to_string (qc->vars[i]));
/* Cell values.  Groups are shown in kmeans->group_order order and each value
   uses its variable's print format.
   NOTE(review): these rows are addressed as j + 4 while the labels above use
   currow + i; confirm that currow is 4 when the labels are written. */
416 for (i = 0; i < qc->ngroups; i++)
418 for (j = 0; j < qc->n_vars; j++)
422 tab_double (t, i + 1, j + 4, TAB_CENTER,
423 gsl_matrix_get (kmeans->centers,
424 kmeans->group_order->data[i], j),
425 var_get_print_format (qc->vars[j]));
429 tab_double (t, i + 1, j + 4, TAB_CENTER,
430 gsl_matrix_get (kmeans->initial_centers,
431 kmeans->group_order->data[i], j),
432 var_get_print_format (qc->vars[j]));
439 /* Reports the number of cases in each cluster, followed by a final row
   labelled "Valid" that shows TOTAL. */
441 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
/* One row per group plus the final "Valid" row. */
448 nr = qc->ngroups + 1;
449 t = tab_create (nc, nr);
450 tab_headers (t, 0, nc - 1, 0, 0);
451 tab_title (t, _("Number of Cases in each Cluster"));
452 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
453 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
/* Row i shows the 1-based group number and that group's case count, with
   groups taken in kmeans->group_order order.
   NOTE(review): the count is read from a gsl_vector_long but printed with
   "%d", while the total below uses "%ld"; numelem's declaration is not
   visible in this copy, so confirm that the format matches its type. */
456 for (i = 0; i < qc->ngroups; i++)
458 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
460 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
461 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
465 tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
466 tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
/* Orders the groups for reporting, then outputs the final cluster centers
   and the number of cases in each cluster. */
472 quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
474 kmeans_order_groups (kmeans, qc);
475 /* Uncomment the line below for reporting initial centers. */
476 /* quick_cluster_show_centers (kmeans, true, qc); */
477 quick_cluster_show_centers (kmeans, false, qc);
478 quick_cluster_show_number_cases (kmeans, qc);
/* Parses and runs the command.  The syntax read here is a variable list,
   optionally followed by a slash and a CRITERIA subcommand whose settings
   CLUSTERS and MXITER each take a parenthesized integer.  The cases obtained
   from DS are then divided into split groups, and each group is clustered
   and reported separately. */
482 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
485 struct Kmeans *kmeans;
487 const struct dictionary *dict = dataset_dict (ds);
/* Variable list, parsed with the PV_NO_DUPLICATE and PV_NUMERIC options;
   a parse failure ends the command with CMD_FAILURE. */
491 if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
492 PV_NO_DUPLICATE | PV_NUMERIC))
494 return (CMD_FAILURE);
497 if (lex_match (lexer, T_SLASH))
499 if (lex_match_id (lexer, "CRITERIA"))
/* The equals sign after CRITERIA is optional: the result is not tested. */
501 lex_match (lexer, T_EQUALS);
/* Consume settings until the end of the command or the next slash. */
502 while (lex_token (lexer) != T_ENDCMD
503 && lex_token (lexer) != T_SLASH)
505 if (lex_match_id (lexer, "CLUSTERS"))
507 if (lex_force_match (lexer, T_LPAREN))
/* NOTE(review): the result of lex_force_int is not checked before
   lex_integer is called, here and for MXITER below; confirm this is safe
   when the token is not an integer. */
509 lex_force_int (lexer);
510 qc.ngroups = lex_integer (lexer);
512 lex_force_match (lexer, T_RPAREN);
515 else if (lex_match_id (lexer, "MXITER"))
517 if (lex_force_match (lexer, T_LPAREN))
519 lex_force_int (lexer);
520 qc.maxiter = lex_integer (lexer);
522 lex_force_match (lexer, T_RPAREN);
/* Weight variable taken from the dictionary; the k-means functions test it
   for null before using it. */
531 qc.wv = dict_get_weight (dict);
534 struct casereader *group;
/* Split the cases from proc_open () into groups; each group gets its own
   struct Kmeans, which is destroyed once its results have been shown. */
535 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
537 while (casegrouper_get_next_group (grouper, &group))
539 kmeans = kmeans_create (&qc);
540 kmeans_cluster (kmeans, group, &qc);
541 quick_cluster_show_results (kmeans, &qc);
542 kmeans_destroy (kmeans);
543 casereader_destroy (group);
/* OK ends up true only if both the grouper and the commit succeeded. */
545 ok = casegrouper_destroy (grouper);
547 ok = proc_commit (ds) && ok;