1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
/* Options for one QUICK CLUSTER run; cmd_quick_cluster fills these in. */
/* Clustering variables (numeric only: parsed with PV_NUMERIC). */
58 const struct variable **vars;
/* From /CRITERIA=CLUSTERS(n) and /CRITERIA=MXITER(n) respectively. */
61 int ngroups; /* Number of groups. (Given by the user) */
62 int maxiter; /* Maximum iterations (Given by the user) */
/* Dictionary weight variable, or NULL, in which case every case is given
   weight 1.0. */
64 const struct variable *wv; /* Weighting variable. */
/* MISS_LISTWISE removes cases with a missing value on any clustering
   variable before clustering starts; MISS_PAIRWISE keeps them and relies
   on the per-value var_is_value_missing checks in the k-means code. */
66 enum missing_type missing_type;
/* Which values count as missing; /MISSING=INCLUDE sets MV_SYSTEM. */
67 enum mv_class exclude;
70 /* State of one k-means run over a single split group, shared by the
   kmeans_* functions and the report functions.  N holds the number of
   observations; it is meant to default to -1 and to be set by
   kmeans_recalculate_centers on its first invocation.
   NOTE(review): the lines that initialize and set N are not visible in
   this listing -- confirm. */
/* Row i is the center of group i; one column per clustering variable. */
75 gsl_matrix *centers; /* Centers for groups. */
/* Weighted number of cases assigned to each group on the latest pass.
   NOTE(review): this is a vector of long but it is incremented by double
   case weights, so fractional weights are truncated. */
76 gsl_vector_long *num_elements_groups;
78 casenumber n; /* Number of observations (default -1). */
80 int lastiter; /* Iteration where it found the solution. */
81 int trials; /* If not convergence, how many times has
83 gsl_matrix *initial_centers; /* Initial random centers. */
/* Permutation that sorts the groups by the first coordinate of their
   centers; filled in by kmeans_order_groups. */
85 gsl_permutation *group_order; /* Group order for reporting. */
/* One numeric (width 0) column, the layout of the cases in INDEX_RDR. */
86 struct caseproto *proto;
/* One case per input case holding that case's group index; NULL until
   the first assignment pass. */
87 struct casereader *index_rdr; /* Group ids for each case. */
/* Forward declarations of the k-means helpers, the report functions and
   the command entry point defined below. */
90 static struct Kmeans *kmeans_create (const struct qc *qc);
92 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
94 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
96 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
99 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
101 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
103 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
105 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
107 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
109 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
111 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
113 static void kmeans_destroy (struct Kmeans *kmeans);
115 /* Creates and returns a new struct Kmeans sized from QC: a QC->ngroups by
   QC->n_vars matrix of centers, a vector of per-group case counts, and a
   permutation with one entry per group for report ordering.
   INITIAL_CENTERS and INDEX_RDR start out NULL.  The caller releases the
   result with kmeans_destroy. */
118 static struct Kmeans *
119 kmeans_create (const struct qc *qc)
121 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
122 kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
123 kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
/* NOTE(review): xmalloc does not zero memory, and the lines shown here do
   not assign N or TRIALS, both of which kmeans_cluster reads.  Confirm
   they are initialized. */
125 kmeans->lastiter = 0;
127 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
128 kmeans->initial_centers = NULL;
/* Prototype for the index cases: a single numeric (width 0) value that
   holds a case's group number. */
130 kmeans->proto = caseproto_create ();
131 kmeans->proto = caseproto_add_width (kmeans->proto, 0);
132 kmeans->index_rdr = NULL;
/* Releases the GSL matrices, vector and permutation, the case prototype
   and the index casereader owned by KMEANS.
   NOTE(review): whether KMEANS itself is freed is not visible in this
   listing -- confirm. */
137 kmeans_destroy (struct Kmeans *kmeans)
139 gsl_matrix_free (kmeans->centers);
140 gsl_matrix_free (kmeans->initial_centers);
142 gsl_vector_long_free (kmeans->num_elements_groups);
144 gsl_permutation_free (kmeans->group_order);
146 caseproto_unref (kmeans->proto);
148 casereader_destroy (kmeans->index_rdr);
153 /* Sets the starting centers in KMEANS->centers.  Every entry is set to
   either 1 or 0.  NOTE(review): the condition choosing between them is not
   visible in this listing; presumably it is i == j, an identity-like
   pattern -- confirm.  No cases are read (the function takes no
   casereader), so despite its name the starting centers are neither random
   nor drawn from the data.  The first call also saves a copy of the
   centers in KMEANS->initial_centers. */
155 kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
/* Visit every (group, variable) cell of the centers matrix. */
158 for (i = 0; i < qc->ngroups; i++)
160 for (j = 0; j < qc->n_vars; j++)
164 gsl_matrix_set (kmeans->centers, i, j, 1);
168 gsl_matrix_set (kmeans->centers, i, j, 0);
172 /* If it is the first iteration, the variable kmeans->initial_centers is NULL
173 and it is created once for reporting issues. In SPSS, initial centers are
174 shown in the reports but in PSPP it is not shown now. I am leaving it
176 if (!kmeans->initial_centers)
178 kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
179 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
// Returns the index of the group whose center is nearest to case C.
// For each group it sums the squared differences (pow2) between the
// center and the case over QC's clustering variables, i.e. the squared
// Euclidean distance.  MINDIST starts at INFINITY and presumably tracks
// the smallest distance seen so far.
// Values that var_is_value_missing reports as missing (per QC->exclude)
// are tested first; NOTE(review): presumably they are skipped -- confirm.
184 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
188 double mindist = INFINITY;
189 for (i = 0; i < qc->ngroups; i++)
192 for (j = 0; j < qc->n_vars; j++)
194 const union value *val = case_data (c, qc->vars[j]);
195 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
198 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
209 /* Re-calculate the cluster centers.
   Each center becomes the weighted mean of the cases currently assigned to
   its group.  Case values are multiplied by the case weight and summed per
   group, using the group ids in KMEANS->index_rdr, which is read in step
   with READER.  The sums are then divided by the weighted counts in
   KMEANS->num_elements_groups.
   NOTE(review): values flagged by var_is_value_missing are handled by a
   branch not visible in this listing.  If they are simply skipped, the
   divisor still counts those cases, which pulls the mean toward zero under
   MISSING=PAIRWISE -- confirm. */
211 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
/* Work on clones so that READER and KMEANS->index_rdr stay available for
   later passes. */
217 struct casereader *cs = casereader_clone (reader);
218 struct casereader *cs_index = casereader_clone (kmeans->index_rdr);
220 gsl_matrix_set_all (kmeans->centers, 0.0);
221 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
223 double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
/* The index reader holds one group id per case, in the same order as the
   data, so it is read once per data case. */
224 struct ccase *c_index = casereader_read (cs_index);
225 int index = case_data_idx (c_index, 0)->f;
226 for (v = 0; v < qc->n_vars; ++v)
228 const union value *val = case_data (c, qc->vars[v]);
229 double x = val->f * weight;
232 if ( var_is_value_missing (qc->vars[v], val, qc->exclude))
235 curval = gsl_matrix_get (kmeans->centers, index, v);
236 gsl_matrix_set (kmeans->centers, index, v, curval + x);
239 case_unref (c_index);
241 casereader_destroy (cs);
242 casereader_destroy (cs_index);
244 /* Getting number of cases */
248 /* The centers now hold weighted sums; divide each one by its group's
   weighted case count (NUMOBS) to obtain the mean.  This loop may be
   inefficient and should be revisited. */
251 for (i = 0; i < qc->ngroups; i++)
253 casenumber numobs = kmeans->num_elements_groups->data[i];
254 for (j = 0; j < qc->n_vars; j++)
258 double *x = gsl_matrix_ptr (kmeans->centers, i, j);
263 gsl_matrix_set (kmeans->centers, i, j, 0);
269 /* Assigns every case in READER to its nearest group (see
   kmeans_get_nearest_group).  The group ids, one per case, are written to a
   new KMEANS->index_rdr, which replaces the previous one.  Also recomputes
   KMEANS->num_elements_groups as the weighted number of cases in each
   group.

   TOTALDIFF measures how much the assignment changed.  It sums
   |old id - new id| over all cases.  On the first pass, when there is no
   previous index reader, it sums the new ids themselves instead.  It is
   therefore a sum of id distances rather than a count of cases that
   changed group.  It is 0 only when no case moved, or, on the first pass,
   when every case lands in group 0.
   NOTE(review): the return statement is not visible in this listing;
   presumably TOTALDIFF is the value kmeans_cluster receives as DIFFS --
   confirm. */
275 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
279 struct casereader *cs = casereader_clone (reader);
281 /* A casewriter into which we will write the indexes. */
282 struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
284 gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
286 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
288 /* A case to hold the new index. */
289 struct ccase *index_case_new = case_create (kmeans->proto);
290 int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
291 double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
292 assert (bestindex < kmeans->num_elements_groups->size);
/* NOTE(review): DATA is an array of long, so adding a double weight here
   truncates fractional weights. */
293 kmeans->num_elements_groups->data[bestindex] += weight;
294 if (kmeans->index_rdr)
296 /* A case from which the old index will be read. */
297 struct ccase *index_case_old = NULL;
299 /* Read the case from the index casereader. */
300 index_case_old = casereader_read (kmeans->index_rdr);
302 /* Set totaldiff, using the old_index. */
/* NOTE(review): abs() takes an int, so the double difference below is
   converted implicitly.  The ids are whole numbers, so the result is the
   same, but fabs() or an explicit cast would state the intent. */
303 totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
305 /* We have no use for the old case anymore, so unref it. */
306 case_unref (index_case_old);
310 /* If this is the first run, then assume index is zero. */
311 totaldiff += bestindex;
314 /* Set the value of the new index. */
315 case_data_rw_idx (index_case_new, 0)->f = bestindex;
317 /* and write the new index to the casewriter */
318 casewriter_write (index_wtr, index_case_new);
320 casereader_destroy (cs);
321 /* We have now read through the entire index_rdr, so it's of no use
323 casereader_destroy (kmeans->index_rdr);
325 /* Convert the writer into a reader, ready for the next iteration to read */
326 kmeans->index_rdr = casewriter_make_reader (index_wtr);
/* Stores in KMEANS->group_order the permutation that sorts the groups into
   ascending order of their centers' first coordinate (column 0 of
   KMEANS->centers).  The report functions list clusters in this order.
   NOTE(review): V is allocated here; make sure it is released with
   gsl_vector_free once the permutation has been computed. */
332 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
334 gsl_vector *v = gsl_vector_alloc (qc->ngroups);
335 gsl_matrix_get_col (v, kmeans->centers, 0);
336 gsl_sort_vector_index (kmeans->group_order, v);
341 Does iterations, checks convergency. */
/* Runs k-means on the cases in READER.  It starts from the centers set by
   kmeans_randomize_centers.  Each iteration (counted in KMEANS->lastiter,
   at most QC->maxiter of them) reassigns every case to its nearest group,
   storing the change measure in DIFFS, and then recomputes the centers.
   A warning is issued at most once when more clusters were requested than
   there are observations.  Groups left with no cases are also checked,
   with KMEANS->trials compared against 3.
   NOTE(review): the convergence test and the retry logic are only partly
   visible in this listing -- confirm. */
343 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
350 show_warning1 = true;
353 kmeans_randomize_centers (kmeans, qc);
354 for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
357 diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
358 kmeans_recalculate_centers (kmeans, reader, qc);
359 if (show_warning1 && qc->ngroups > kmeans->n)
361 msg (MW, _("Number of clusters may not be larger than the number "
363 show_warning1 = false;
369 for (i = 0; i < qc->ngroups; i++)
371 if (kmeans->num_elements_groups->data[i] == 0)
374 if (kmeans->trials >= 3)
385 /* Reports the cluster centers as a table.  The table has one column per
   cluster and one row per clustering variable, with clusters in the order
   given by KMEANS->group_order.  If INITIAL is true, the initial centers
   (KMEANS->initial_centers) are reported; otherwise the final centers
   (KMEANS->centers) are reported.  quick_cluster_show_results currently
   passes false only; INITIAL is kept for future use. */
390 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
393 int nc, nr, heading_columns, currow;
395 nc = qc->ngroups + 1;
398 t = tab_create (nc, nr);
/* NOTE(review): if tab_headers takes (left, right, top, bottom) header
   counts, passing nc - 1 as the right-hand count marks every cluster
   column as a header.  Confirm this is intended. */
399 tab_headers (t, 0, nc - 1, 0, 1);
403 tab_title (t, _("Final Cluster Centers"));
407 tab_title (t, _("Initial Cluster Centers"));
409 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
410 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
411 tab_hline (t, TAL_1, 1, nc - 1, 2);
/* Cluster numbers 1..ngroups across row CURROW. */
414 for (i = 0; i < qc->ngroups; i++)
416 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
419 tab_hline (t, TAL_1, 1, nc - 1, currow);
/* Row labels: one clustering variable per row, starting at CURROW. */
421 for (i = 0; i < qc->n_vars; i++)
423 tab_text (t, 0, currow + i, TAB_LEFT,
424 var_to_string (qc->vars[i]));
/* Cell values, in KMEANS->group_order order.  NOTE(review): these use the
   fixed row j + 4 while the labels above use CURROW + i; confirm that the
   two offsets agree. */
427 for (i = 0; i < qc->ngroups; i++)
429 for (j = 0; j < qc->n_vars; j++)
433 tab_double (t, i + 1, j + 4, TAB_CENTER,
434 gsl_matrix_get (kmeans->centers,
435 kmeans->group_order->data[i], j),
436 var_get_print_format (qc->vars[j]));
440 tab_double (t, i + 1, j + 4, TAB_CENTER,
441 gsl_matrix_get (kmeans->initial_centers,
442 kmeans->group_order->data[i], j),
443 var_get_print_format (qc->vars[j]));
450 /* Reports the number of cases in each cluster, listed in
   KMEANS->group_order order, followed by a Valid row giving the total.
   The counts come from KMEANS->num_elements_groups, so they are weighted
   counts. */
452 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
459 nr = qc->ngroups + 1;
460 t = tab_create (nc, nr);
461 tab_headers (t, 0, nc - 1, 0, 0);
462 tab_title (t, _("Number of Cases in each Cluster"));
463 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
464 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
/* Row i: the cluster number in column 1 and its case count in column 2. */
467 for (i = 0; i < qc->ngroups; i++)
469 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
471 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
/* NOTE(review): the counts are stored as long; confirm that NUMELEM's
   declared type matches the %d conversion used here. */
472 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
/* Final row: the total.  NOTE(review): it is left-aligned (TAB_LEFT)
   while the per-cluster counts above are centered. */
476 tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
477 tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
/* Emits the QUICK CLUSTER output for KMEANS: orders the groups, then shows
   the final cluster centers and the number of cases in each cluster. */
483 quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
485 kmeans_order_groups (kmeans, qc);
486 /* Uncomment the line below for reporting initial centers. */
487 /* quick_cluster_show_centers (kmeans, true, qc); */
488 quick_cluster_show_centers (kmeans, false, qc);
489 quick_cluster_show_number_cases (kmeans, qc);
/* Parses and runs the QUICK CLUSTER command.  The syntax is a list of
   numeric variables, then optional MISSING (LISTWISE, DEFAULT, PAIRWISE,
   INCLUDE, EXCLUDE) and CRITERIA (CLUSTERS(n), MXITER(n)) subcommands,
   each optionally preceded by a slash.  The data are then clustered
   separately for each split-file group, each with its own struct Kmeans,
   and the results are reported. */
493 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
496 struct Kmeans *kmeans;
498 const struct dictionary *dict = dataset_dict (ds);
501 qc.missing_type = MISS_LISTWISE;
504 if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
505 PV_NO_DUPLICATE | PV_NUMERIC))
507 return (CMD_FAILURE);
510 while (lex_token (lexer) != T_ENDCMD)
512 lex_match (lexer, T_SLASH);
514 if (lex_match_id (lexer, "MISSING"))
516 lex_match (lexer, T_EQUALS);
517 while (lex_token (lexer) != T_ENDCMD
518 && lex_token (lexer) != T_SLASH)
520 if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
522 qc.missing_type = MISS_LISTWISE;
524 else if (lex_match_id (lexer, "PAIRWISE"))
526 qc.missing_type = MISS_PAIRWISE;
528 else if (lex_match_id (lexer, "INCLUDE"))
530 qc.exclude = MV_SYSTEM;
532 else if (lex_match_id (lexer, "EXCLUDE"))
540 else if (lex_match_id (lexer, "CRITERIA"))
542 lex_match (lexer, T_EQUALS);
543 while (lex_token (lexer) != T_ENDCMD
544 && lex_token (lexer) != T_SLASH)
546 if (lex_match_id (lexer, "CLUSTERS"))
548 if (lex_force_match (lexer, T_LPAREN))
/* NOTE(review): the results of lex_force_int and of the closing
   lex_force_match are ignored.  If the token is not an integer,
   lex_integer is still called on it and parsing carries on.  There is
   also no check that the value is positive. */
550 lex_force_int (lexer);
551 qc.ngroups = lex_integer (lexer);
553 lex_force_match (lexer, T_RPAREN);
556 else if (lex_match_id (lexer, "MXITER"))
558 if (lex_force_match (lexer, T_LPAREN))
/* NOTE(review): same unchecked lex_force_int and lex_force_match as
   for CLUSTERS. */
560 lex_force_int (lexer);
561 qc.maxiter = lex_integer (lexer);
563 lex_force_match (lexer, T_RPAREN);
572 qc.wv = dict_get_weight (dict);
/* Cluster each split-file group independently. */
575 struct casereader *group;
576 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
578 while (casegrouper_get_next_group (grouper, &group))
580 if ( qc.missing_type == MISS_LISTWISE )
582 group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
587 kmeans = kmeans_create (&qc);
588 kmeans_cluster (kmeans, group, &qc);
589 quick_cluster_show_results (kmeans, &qc);
590 kmeans_destroy (kmeans);
591 casereader_destroy (group);
593 ok = casegrouper_destroy (grouper);
595 ok = proc_commit (ds) && ok;