1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/tab.h"
43 #include "output/text-item.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
58 const struct variable **vars; /* Variables to cluster on; filled in by parse_variables_const. */
61 int ngroups; /* Number of groups (given by the user via CLUSTERS). */
62 int maxiter; /* Maximum number of iterations (given by the user via MXITER). */
63 int print_cluster_membership; /* true => print each case's cluster membership */
64 int print_initial_clusters; /* true => print the initial cluster centers */
66 const struct variable *wv; /* Weighting variable, from dict_get_weight; the code tests qc->wv before using it. */
68 enum missing_type missing_type; /* MISS_LISTWISE (the default) or MISS_PAIRWISE, from the MISSING subcommand. */
69 enum mv_class exclude; /* Class of missing values passed to var_is_value_missing. */
72 /* Holds all of the information for the functions. The member n holds the number
73 of observations and its default value is -1. We set it in
74 kmeans_recalculate_centers on the first invocation. */
77 gsl_matrix *centers; /* Centers for groups. */
78 gsl_vector_long *num_elements_groups;
80 casenumber n; /* Number of observations (default -1). */
82 int lastiter; /* Iteration where it found the solution. */
83 int trials; /* If not convergence, how many times has
85 gsl_matrix *initial_centers; /* Initial random centers. */
87 gsl_permutation *group_order; /* Group order for reporting. */
88 struct caseproto *proto;
89 struct casereader *index_rdr; /* Group ids for each case. */
92 static struct Kmeans *kmeans_create (const struct qc *qc);
94 static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc);
96 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
98 static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
101 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
103 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
105 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
107 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
109 static void quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
111 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
113 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
115 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
117 static void kmeans_destroy (struct Kmeans *kmeans);
119 /* Creates and returns a struct Kmeans sized from the settings in 'qc': the
120 centers matrix is qc->ngroups x qc->n_vars, and the per-group count vector
121 and the reporting permutation each have qc->ngroups elements. */
122 static struct Kmeans *
123 kmeans_create (const struct qc *qc)
125 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
126 kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
127 kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
129 kmeans->lastiter = 0;
131 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
132 kmeans->initial_centers = NULL; /* Allocated later by kmeans_randomize_centers when still NULL. */
134 kmeans->proto = caseproto_create ();
135 kmeans->proto = caseproto_add_width (kmeans->proto, 0); /* One value per case, used to hold the case's group index. */
136 kmeans->index_rdr = NULL; /* No group indexes exist until the first indexing pass. */
141 kmeans_destroy (struct Kmeans *kmeans)
143 gsl_matrix_free (kmeans->centers);
144 gsl_matrix_free (kmeans->initial_centers);
146 gsl_vector_long_free (kmeans->num_elements_groups);
148 gsl_permutation_free (kmeans->group_order);
150 caseproto_unref (kmeans->proto);
152 casereader_destroy (kmeans->index_rdr);
157 /* Sets the initial cluster centers. Despite the name, nothing is drawn from the data: 'reader' is unused and each entry is set to the constant 1 or 0. */
159 kmeans_randomize_centers (struct Kmeans *kmeans, const struct casereader *reader UNUSED, const struct qc *qc)
162 for (i = 0; i < qc->ngroups; i++)
164 for (j = 0; j < qc->n_vars; j++)
168 gsl_matrix_set (kmeans->centers, i, j, 1);
172 gsl_matrix_set (kmeans->centers, i, j, 0);
176 /* If it is the first iteration, the variable kmeans->initial_centers is NULL
177 and it is created once for reporting issues. In SPSS, initial centers are
178 shown in the reports but in PSPP it is not shown now. I am leaving it
180 if (!kmeans->initial_centers)
182 kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
183 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
188 kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
192 double mindist = INFINITY;
193 for (i = 0; i < qc->ngroups; i++)
196 for (j = 0; j < qc->n_vars; j++)
198 const union value *val = case_data (c, qc->vars[j]);
199 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
202 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
213 /* Re-calculates the cluster centers from the current group assignment: each case's values, multiplied by the case weight, are added to the center of the group recorded for that case in kmeans->index_rdr. */
215 kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
221 struct casereader *cs = casereader_clone (reader);
222 struct casereader *cs_index = casereader_clone (kmeans->index_rdr); /* Read in step with cs: one group index per case. */
224 gsl_matrix_set_all (kmeans->centers, 0.0);
225 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
227 double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
228 struct ccase *c_index = casereader_read (cs_index);
229 int index = case_data_idx (c_index, 0)->f; /* Group this case is currently assigned to. */
230 for (v = 0; v < qc->n_vars; ++v)
232 const union value *val = case_data (c, qc->vars[v]);
233 double x = val->f * weight;
236 if ( var_is_value_missing (qc->vars[v], val, qc->exclude))
239 curval = gsl_matrix_get (kmeans->centers, index, v);
240 gsl_matrix_set (kmeans->centers, index, v, curval + x);
243 case_unref (c_index);
245 casereader_destroy (cs);
246 casereader_destroy (cs_index);
248 /* Get the number of cases. */
252 /* The centers now hold per-group sums, but we need averages,
253 so we divide each center by its numobs. This may be inefficient and
254 we should check it again. */
255 for (i = 0; i < qc->ngroups; i++)
257 casenumber numobs = kmeans->num_elements_groups->data[i];
258 for (j = 0; j < qc->n_vars; j++)
262 double *x = gsl_matrix_ptr (kmeans->centers, i, j);
267 gsl_matrix_set (kmeans->centers, i, j, 0);
273 /* The index reader in struct Kmeans holds one value per case: the group that
274 case currently belongs to. This function finds the nearest group for every
275 case, writes the new indexes out, and accumulates in totaldiff the sum of the
276 absolute differences between each case's old and new index. If the last two
277 sets of indexes are equal, the clustering did not change. */
279 kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
283 struct casereader *cs = casereader_clone (reader);
285 /* A casewriter into which we will write the indexes. */
286 struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
288 gsl_vector_long_set_all (kmeans->num_elements_groups, 0);
290 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
292 /* A case to hold the new index. */
293 struct ccase *index_case_new = case_create (kmeans->proto);
294 int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
295 double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
296 assert (bestindex < kmeans->num_elements_groups->size);
297 kmeans->num_elements_groups->data[bestindex] += weight; /* NOTE(review): weight is a double added into a long element, so any fractional part is lost. */
298 if (kmeans->index_rdr)
300 /* A case from which the old index will be read. */
301 struct ccase *index_case_old = NULL;
303 /* Read the case from the index casereader. */
304 index_case_old = casereader_read (kmeans->index_rdr);
306 /* Set totaldiff, using the old_index. NOTE(review): if this is the C library abs, it takes an int, so the double difference is converted to int first -- confirm. */
307 totaldiff += abs (case_data_idx (index_case_old, 0)->f - bestindex);
309 /* We have no use for the old case anymore, so unref it. */
310 case_unref (index_case_old);
314 /* If this is the first run, then assume index is zero. */
315 totaldiff += bestindex;
318 /* Set the value of the new index. */
319 case_data_rw_idx (index_case_new, 0)->f = bestindex;
321 /* and write the new index to the casewriter */
322 casewriter_write (index_wtr, index_case_new);
324 casereader_destroy (cs);
325 /* We have now read through the entire index_rdr, so it's of no use
327 casereader_destroy (kmeans->index_rdr);
329 /* Convert the writer into a reader, ready for the next iteration to read */
330 kmeans->index_rdr = casewriter_make_reader (index_wtr);
336 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
338 gsl_vector *v = gsl_vector_alloc (qc->ngroups);
339 gsl_matrix_get_col (v, kmeans->centers, 0);
340 gsl_sort_vector_index (kmeans->group_order, v);
345 Does iterations, checks convergence. */
347 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
355 show_warning1 = true;
358 kmeans_randomize_centers (kmeans, reader, qc);
359 for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
362 diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
363 kmeans_recalculate_centers (kmeans, reader, qc);
364 if (show_warning1 && qc->ngroups > kmeans->n)
366 msg (MW, _("Number of clusters may not be larger than the number "
368 show_warning1 = false;
374 for (i = 0; i < qc->ngroups; i++)
376 if (kmeans->num_elements_groups->data[i] == 0)
379 if (kmeans->trials >= 3)
389 assert (redo_count < 10);
395 /* Reports centers of clusters.
396 The caller passes initial as true only when initial clusters were requested.
397 If initial is true, initial cluster centers are reported. Otherwise,
398 the resulting (final) centers are reported. */
400 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
405 nc = qc->ngroups + 1;
407 t = tab_create (nc, nr);
408 tab_headers (t, 0, nc - 1, 0, 1);
412 tab_title (t, _("Final Cluster Centers"));
416 tab_title (t, _("Initial Cluster Centers"));
418 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
419 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
420 tab_hline (t, TAL_1, 1, nc - 1, 2);
423 for (i = 0; i < qc->ngroups; i++)
425 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
428 tab_hline (t, TAL_1, 1, nc - 1, currow);
430 for (i = 0; i < qc->n_vars; i++)
432 tab_text (t, 0, currow + i, TAB_LEFT,
433 var_to_string (qc->vars[i]));
436 for (i = 0; i < qc->ngroups; i++)
438 for (j = 0; j < qc->n_vars; j++)
442 tab_double (t, i + 1, j + 4, TAB_CENTER,
443 gsl_matrix_get (kmeans->centers,
444 kmeans->group_order->data[i], j),
445 var_get_print_format (qc->vars[j]), RC_OTHER);
449 tab_double (t, i + 1, j + 4, TAB_CENTER,
450 gsl_matrix_get (kmeans->initial_centers,
451 kmeans->group_order->data[i], j),
452 var_get_print_format (qc->vars[j]), RC_OTHER);
459 /* Reports cluster membership for each case. */
461 quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
467 struct casereader *cs = casereader_clone (reader);
470 t = tab_create (nc, nr);
471 tab_headers (t, 0, nc - 1, 0, 0);
472 tab_title (t, _("Cluster Membership"));
473 tab_text (t, 0, 0, TAB_CENTER, _("Case Number"));
474 tab_text (t, 1, 0, TAB_CENTER, _("Cluster"));
475 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
476 tab_hline (t, TAL_1, 0, nc - 1, 1);
479 for (i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
481 assert (i < kmeans->n);
482 clust = kmeans_get_nearest_group (kmeans, c, qc);
483 clust = kmeans->group_order->data[clust];
484 tab_text_format (t, 0, i+1, TAB_CENTER, "%d", (i + 1));
485 tab_text_format (t, 1, i+1, TAB_CENTER, "%d", (clust + 1));
487 assert (i == kmeans->n);
489 casereader_destroy (cs);
493 /* Reports number of cases of each single cluster. */
495 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
502 nr = qc->ngroups + 1;
503 t = tab_create (nc, nr);
504 tab_headers (t, 0, nc - 1, 0, 0);
505 tab_title (t, _("Number of Cases in each Cluster"));
506 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
507 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
510 for (i = 0; i < qc->ngroups; i++)
512 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
514 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
515 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
519 tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
520 tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
526 quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
528 kmeans_order_groups (kmeans, qc); /* Fills kmeans->group_order by sorting the groups on the first column of kmeans->centers; the report functions below index through it. */
529 if( qc->print_initial_clusters )
530 quick_cluster_show_centers (kmeans, true, qc);
531 quick_cluster_show_centers (kmeans, false, qc);
532 quick_cluster_show_number_cases (kmeans, qc);
533 if( qc->print_cluster_membership )
534 quick_cluster_show_membership(kmeans, reader, qc);
538 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
541 struct Kmeans *kmeans;
543 const struct dictionary *dict = dataset_dict (ds);
546 qc.missing_type = MISS_LISTWISE;
548 qc.print_cluster_membership = false; /* default = do not output case cluster membership */
549 qc.print_initial_clusters = false; /* default = do not print initial clusters */
551 if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
552 PV_NO_DUPLICATE | PV_NUMERIC))
554 return (CMD_FAILURE);
557 while (lex_token (lexer) != T_ENDCMD)
559 lex_match (lexer, T_SLASH);
561 if (lex_match_id (lexer, "MISSING"))
563 lex_match (lexer, T_EQUALS);
564 while (lex_token (lexer) != T_ENDCMD
565 && lex_token (lexer) != T_SLASH)
567 if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
569 qc.missing_type = MISS_LISTWISE;
571 else if (lex_match_id (lexer, "PAIRWISE"))
573 qc.missing_type = MISS_PAIRWISE;
575 else if (lex_match_id (lexer, "INCLUDE"))
577 qc.exclude = MV_SYSTEM;
579 else if (lex_match_id (lexer, "EXCLUDE"))
585 lex_error (lexer, NULL);
590 else if (lex_match_id (lexer, "PRINT"))
592 lex_match (lexer, T_EQUALS);
593 while (lex_token (lexer) != T_ENDCMD
594 && lex_token (lexer) != T_SLASH)
596 if (lex_match_id (lexer, "CLUSTER"))
597 qc.print_cluster_membership = true;
598 else if (lex_match_id (lexer, "INITIAL"))
599 qc.print_initial_clusters = true;
602 lex_error (lexer, NULL);
607 else if (lex_match_id (lexer, "CRITERIA"))
609 lex_match (lexer, T_EQUALS);
610 while (lex_token (lexer) != T_ENDCMD
611 && lex_token (lexer) != T_SLASH)
613 if (lex_match_id (lexer, "CLUSTERS"))
615 if (lex_force_match (lexer, T_LPAREN))
617 lex_force_int (lexer);
618 qc.ngroups = lex_integer (lexer);
621 lex_error (lexer, _("The number of clusters must be positive"));
625 lex_force_match (lexer, T_RPAREN);
628 else if (lex_match_id (lexer, "MXITER"))
630 if (lex_force_match (lexer, T_LPAREN))
632 lex_force_int (lexer);
633 qc.maxiter = lex_integer (lexer);
636 lex_error (lexer, _("The number of iterations must be positive"));
640 lex_force_match (lexer, T_RPAREN);
645 lex_error (lexer, NULL);
652 lex_error (lexer, NULL);
657 qc.wv = dict_get_weight (dict);
660 struct casereader *group;
661 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
663 while (casegrouper_get_next_group (grouper, &group))
665 if ( qc.missing_type == MISS_LISTWISE )
667 group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
672 kmeans = kmeans_create (&qc);
673 kmeans_cluster (kmeans, group, &qc);
674 quick_cluster_show_results (kmeans, group, &qc);
675 kmeans_destroy (kmeans);
676 casereader_destroy (group);
678 ok = casegrouper_destroy (grouper);
680 ok = proc_commit (ds) && ok;