1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
26 #include "data/case.h"
27 #include "data/casegrouper.h"
28 #include "data/casereader.h"
29 #include "data/casewriter.h"
30 #include "data/dataset.h"
31 #include "data/dictionary.h"
32 #include "data/format.h"
33 #include "data/missing-values.h"
34 #include "language/command.h"
35 #include "language/lexer/lexer.h"
36 #include "language/lexer/variable-parser.h"
37 #include "libpspp/message.h"
38 #include "libpspp/misc.h"
39 #include "libpspp/assertion.h"
40 #include "libpspp/str.h"
41 #include "math/random.h"
42 #include "output/pivot-table.h"
43 #include "output/text-item.h"
/* _("text") expands to gettext ("text"): the message is translated here. */
46 #define _(msgid) gettext (msgid)
/* N_("text") expands to "text" unchanged.  It is used for strings handed to
   pivot_table_create, pivot_dimension_create and similar calls, which
   presumably translate them later -- confirm in the output library. */
47 #define N_(msgid) msgid
/* Members of struct qc: the settings that cmd_quick_cluster parses from one
   QUICK CLUSTER command.  NOTE(review): the opening 'struct qc {' line and
   some members used elsewhere (e.g. n_vars, read as qc->n_vars) are not
   present in this excerpt. */
/* Numeric variables to cluster (parse_variables_const with PV_NUMERIC). */
58 const struct variable **vars;
/* Defaults to DBL_EPSILON; set by CRITERIA=CONVERGE(x).  It is multiplied
   by the smallest distance between the initial centers to give
   Kmeans.convergence_criteria. */
61 double epsilon; /* The convergence criterion */
63 int ngroups; /* Number of groups (clusters), from CRITERIA=CLUSTERS(n). */
64 int maxiter; /* Maximum number of iterations, from CRITERIA=MXITER(n). */
65 bool print_cluster_membership; /* PRINT=CLUSTER: report each case's cluster. */
66 bool print_initial_clusters; /* PRINT=INITIAL: report the initial centers. */
67 bool no_initial; /* true => simplified initial cluster selection (CRITERIA=NOINITIAL) */
68 bool no_update; /* true => do not iterate */
/* Weighting variable from dict_get_weight; when null, each case counts
   with weight 1.0. */
70 const struct variable *wv; /* Weighting variable. */
/* MISS_LISTWISE (the default) or MISS_PAIRWISE, from the MISSING
   subcommand.  With MISS_LISTWISE, cases with missing values are filtered
   out of each split group before clustering. */
72 enum missing_type missing_type;
/* Class of values treated as missing by var_is_value_missing;
   MISSING=INCLUDE sets it to MV_SYSTEM. */
73 enum mv_class exclude;
76 /* State of one k-means run.  It holds the current, updated and initial
   cluster centers, the per-group case counts, the convergence threshold,
   and the order in which groups are reported.  It is created by
   kmeans_create and released by kmeans_destroy.
   NOTE(review): an earlier version of this comment said that member N
   holds the number of observations, defaults to -1, and is set in
   kmeans_recalculate_centers.  No such function appears in this excerpt,
   kmeans_create does not visibly initialize N, and N is read by an assert
   in quick_cluster_show_membership -- confirm that it is initialized. */
81 gsl_matrix *centers; /* Current centers: one row per group, one column per variable. */
/* Same shape as centers.  kmeans_cluster accumulates per-group sums here
   and then divides them to obtain new centers. */
82 gsl_matrix *updated_centers;
/* Per-group count of assigned cases (weighted when a weight variable is set). */
85 gsl_vector_long *num_elements_groups;
87 gsl_matrix *initial_centers; /* Copy of the centers chosen by kmeans_initial_centers; NULL until then. */
/* qc->epsilon times the smallest distance between the initial centers;
   kmeans_cluster compares the change in centers against it. */
88 double convergence_criteria;
89 gsl_permutation *group_order; /* Group order for reporting: groups sorted by their center on the first variable. */
/* Forward declarations of the static helpers defined below, plus
   cmd_quick_cluster, the externally visible entry point for the QUICK
   CLUSTER command.  In kmeans_get_nearest_group the four trailing pointers
   are, in order, G_Q, DELTA_Q, G_P and DELTA_P. */
92 static struct Kmeans *kmeans_create (const struct qc *qc);
94 static void kmeans_get_nearest_group (const struct Kmeans *kmeans,
95 struct ccase *c, const struct qc *,
96 int *, double *, int *, double *);
98 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
100 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
103 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial,
106 static void quick_cluster_show_membership (struct Kmeans *kmeans,
107 const struct casereader *reader,
110 static void quick_cluster_show_number_cases (struct Kmeans *kmeans,
113 static void quick_cluster_show_results (struct Kmeans *kmeans,
114 const struct casereader *reader,
117 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
119 static void kmeans_destroy (struct Kmeans *kmeans);
121 /* Creates and returns a new struct Kmeans sized for QC.  It allocates two
   QC->ngroups by QC->n_vars center matrices, a per-group count vector, and
   a group-order permutation.  initial_centers is left NULL;
   kmeans_initial_centers allocates and fills it later.  Release the result
   with kmeans_destroy.
   NOTE(review): gsl_matrix_alloc does not initialize the matrices, and
   none of the allocations is checked for NULL in the lines shown.  Member
   N is not set in these lines either. */
124 static struct Kmeans *
125 kmeans_create (const struct qc *qc)
127 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
128 kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
129 kmeans->updated_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
130 kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
/* One permutation entry per group (centers has QC->ngroups rows). */
131 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
132 kmeans->initial_centers = NULL;
/* Releases the GSL objects owned by KMEANS.  initial_centers may still be
   NULL here, because kmeans_create sets it to NULL.  This relies on
   gsl_matrix_free ignoring a NULL argument, which current GSL releases do
   -- confirm for the minimum supported GSL version.
   NOTE(review): freeing the struct itself is not among the lines shown
   -- confirm that it happens. */
138 kmeans_destroy (struct Kmeans *kmeans)
140 gsl_matrix_free (kmeans->centers);
141 gsl_matrix_free (kmeans->updated_centers);
142 gsl_matrix_free (kmeans->initial_centers);
144 gsl_vector_long_free (kmeans->num_elements_groups);
146 gsl_permutation_free (kmeans->group_order);
/* Measures how far M1 is from M2.  The two matrices are assumed to have the
   same dimensions, since only M1's sizes bound the loops.  For each row,
   the squared element differences are summed across the columns.
   MAX_DIFF starts at -INFINITY, so the result is presumably the largest of
   these per-row sums.  The lines that update and return MAX_DIFF are not
   in this excerpt -- confirm.  kmeans_cluster compares the result with
   kmeans->convergence_criteria. */
152 diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2)
155 double max_diff = -INFINITY;
156 for (i = 0; i < m1->size1; ++i)
159 for (j = 0; j < m1->size2; ++j)
161 diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j) );
/* Finds the smallest distance between two distinct rows of M.  The distance
   is measured as the sum of squared differences across the columns.  MN
   and MM presumably receive the indices of the closest pair; the lines
   that store them are not in this excerpt.  kmeans_initial_centers passes
   NULL for both when it needs only the distance.
   NOTE(review): M->size1 is unsigned, so "m->size1 - 1" wraps around if M
   has no rows.  With a single row the outer loop body never runs, so
   MINDIST stays INFINITY and MN/MM are not written. */
173 matrix_mindist (const gsl_matrix *m, int *mn, int *mm)
176 double mindist = INFINITY;
177 for (i = 0; i < m->size1 - 1; ++i)
179 for (j = i + 1; j < m->size1; ++j)
183 for (k = 0; k < m->size2; ++k)
185 diff_sq += pow2 (gsl_matrix_get (m, j, k) - gsl_matrix_get (m, i, k));
187 if (diff_sq < mindist)
202 /* Return the distance of C from the group whose index is WHICH.  It is
   computed as the sum, over QC's variables, of the squared difference
   between C's value and row WHICH of kmeans->centers.  Variables whose
   value in C is missing (per QC->exclude) are presumably skipped; the
   statement after the missing-value test is not in this excerpt.
   NOTE(review): the return statement is not shown either.  If no square
   root is applied, this is a squared Euclidean distance. */
204 dist_from_case (const struct Kmeans *kmeans, const struct ccase *c,
205 const struct qc *qc, int which)
209 for (j = 0; j < qc->n_vars; j++)
211 const union value *val = case_data (c, qc->vars[j]);
212 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
215 dist += pow2 (gsl_matrix_get (kmeans->centers, which, j) - val->f);
221 /* Return the minimum distance between group WHICH and all other groups.
   This is the smallest sum of squared differences between row WHICH of
   kmeans->centers and any other row.
   NOTE(review): the test that skips I == WHICH is not among the lines
   shown -- confirm it exists, otherwise the result would always be 0. */
223 min_dist_from (const struct Kmeans *kmeans, const struct qc *qc, int which)
227 double mindist = INFINITY;
228 for (i = 0; i < qc->ngroups; i++)
234 for (j = 0; j < qc->n_vars; j++)
236 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j)
237 - gsl_matrix_get (kmeans->centers, which, j));
251 /* Calculate the initial cluster centers.

   The function reads a clone of READER, so READER can be read again
   later.  A value that is missing according to QC->exclude marks its case
   as missing.  The first QC->ngroups cases are copied into the rows of
   kmeans->centers; presumably only cases without missing values count,
   but the skip itself is not in this excerpt.

   Each later case C may then replace an existing center:
     - If C is farther from its nearest center than the two closest
       centers (MN, MM) are from each other, C replaces whichever of MN and
       MM is closer to C.
     - Otherwise, if C is farther from its second-nearest center (MP) than
       its nearest center (MQ) is from any other center, C replaces MQ.

   Finally, the convergence threshold is set to QC->epsilon times the
   smallest distance between the chosen centers, and a copy of the centers
   is saved in kmeans->initial_centers for reporting.
   NOTE(review): qc->no_initial is not read in the lines shown here. */
253 kmeans_initial_centers (struct Kmeans *kmeans,
254 const struct casereader *reader,
260 struct casereader *cs = casereader_clone (reader);
261 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
263 bool missing = false;
264 for (j = 0; j < qc->n_vars; ++j)
266 const union value *val = case_data (c, qc->vars[j]);
267 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
273 if (nc < qc->ngroups)
274 gsl_matrix_set (kmeans->centers, nc, j, val->f);
280 if (nc++ < qc->ngroups)
289 double m = matrix_mindist (kmeans->centers, &mn, &mm);
291 kmeans_get_nearest_group (kmeans, c, qc, &mq, &delta, &mp, NULL);
293 /* If the distance between C and the nearest group is greater than the distance
294 between the two groups which are closest to each
295 other, then one group must be replaced. */
297 /* Out of mn and mm, which of the two groups is closest to C? */
298 int which = (dist_from_case (kmeans, c, qc, mn)
299 > dist_from_case (kmeans, c, qc, mm)) ? mm : mn;
301 for (j = 0; j < qc->n_vars; ++j)
303 const union value *val = case_data (c, qc->vars[j]);
304 gsl_matrix_set (kmeans->centers, which, j, val->f);
307 else if (dist_from_case (kmeans, c, qc, mp) > min_dist_from (kmeans, qc, mq))
308 /* If the distance between C and the second nearest group
309 (MP) is greater than the smallest distance between the
310 nearest group (MQ) and any other group, then replace
313 for (j = 0; j < qc->n_vars; ++j)
315 const union value *val = case_data (c, qc->vars[j]);
316 gsl_matrix_set (kmeans->centers, mq, j, val->f);
322 casereader_destroy (cs);
324 kmeans->convergence_criteria = qc->epsilon * matrix_mindist (kmeans->centers, NULL, NULL);
326 /* As it is the first iteration, the variable kmeans->initial_centers is NULL
327 and it is created once for reporting issues. */
328 kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
329 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
333 /* Finds the groups nearest to the case C.
   The function returns void; its results come back through the pointer
   arguments:
     - *G_Q presumably receives the index of the nearest group, and
       *DELTA_Q its distance.
     - *G_P and *DELTA_P presumably receive the second-nearest group and
       its distance.
   MINDIST0 and MINDIST1 track the nearest and second-nearest distances.
   Callers in this file pass NULL for DELTA_Q, G_P and DELTA_P when they do
   not need those results, so the stores must tolerate NULL; they are not
   among the lines shown.  Each distance is a sum of squared differences
   between C's values and a center row.  Variables whose value in C is
   missing (per QC->exclude) are presumably skipped. */
335 kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c,
336 const struct qc *qc, int *g_q, double *delta_q,
337 int *g_p, double *delta_p)
342 double mindist0 = INFINITY;
343 double mindist1 = INFINITY;
344 for (i = 0; i < qc->ngroups; i++)
347 for (j = 0; j < qc->n_vars; j++)
349 const union value *val = case_data (c, qc->vars[j]);
350 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
353 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
364 else if (dist < mindist1)
/* Fills kmeans->group_order with the group indices, sorted in ascending
   order of each group's center value for the first variable (column 0 of
   kmeans->centers).  The sort uses gsl_sort_vector_index, which is not a
   stable sort, so groups with equal first-variable centers come out in an
   unspecified order.  The reporting functions number the clusters in this
   order.
   NOTE(review): the temporary vector V is allocated here, but its release
   is not among the lines shown -- confirm that it is freed. */
388 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
390 gsl_vector *v = gsl_vector_alloc (qc->ngroups);
391 gsl_matrix_get_col (v, kmeans->centers, 0);
392 gsl_sort_vector_index (kmeans->group_order, v);
397 Does iterations, checks convergency. */
/* Runs the clustering.  It first chooses the initial centers with
   kmeans_initial_centers.  It then makes up to QC->maxiter passes, each
   over a fresh clone of READER:
     - Each case is assigned to the group whose center is nearest by
       dist_from_case.  Cases with missing values are presumably skipped,
       but the skip is not in this excerpt.
     - The group's count and its weighted value sums are accumulated into
       num_elements_groups and updated_centers.
     - updated_centers begins the pass holding a center estimate, and the
       "+ 1" in the divisor counts that estimate as one extra observation.
   A second pass then reassigns the cases with kmeans_get_nearest_group,
   recomputes the centers into updated_centers, and compares them with the
   current centers against kmeans->convergence_criteria.
   NOTE(review): the per-group counts are 'long', but the weights added to
   them are doubles.  Non-integer case weights are therefore truncated when
   each "*n += ..." result is stored, while the first pass's value sums use
   the untruncated weight. */
399 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
404 kmeans_initial_centers (kmeans, reader, qc);
406 gsl_matrix_memcpy (kmeans->updated_centers, kmeans->centers);
409 for (int xx = 0 ; xx < qc->maxiter ; ++xx)
411 gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
416 struct casereader *r = casereader_clone (reader);
418 for (; (c = casereader_read (r)) != NULL; case_unref (c))
422 bool missing = false;
424 for (j = 0; j < qc->n_vars; j++)
426 const union value *val = case_data (c, qc->vars[j]);
427 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
434 double mindist = INFINITY;
435 for (g = 0; g < qc->ngroups; ++g)
437 double d = dist_from_case (kmeans, c, qc, g);
446 long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
447 *n += qc->wv ? case_data (c, qc->wv)->f : 1.0;
450 for (j = 0; j < qc->n_vars; ++j)
452 const union value *val = case_data (c, qc->vars[j]);
453 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
455 double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
456 *x += val->f * (qc->wv ? case_data (c, qc->wv)->f : 1.0);
460 casereader_destroy (r);
465 /* Divide the cluster sums by the number of items in each cluster */
466 for (g = 0; g < qc->ngroups; ++g)
468 for (j = 0; j < qc->n_vars; ++j)
470 long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
471 double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
472 *x /= n + 1; // Plus 1 for the initial centers
477 gsl_matrix_memcpy (kmeans->centers, kmeans->updated_centers);
/* Second pass: clear the counts and sums, then reassign each case read to
   its nearest center with kmeans_get_nearest_group.  The centers are
   recomputed into updated_centers so they can be compared with the
   centers just installed. */
482 gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
483 gsl_matrix_set_all (kmeans->updated_centers, 0.0);
485 struct casereader *cs = casereader_clone (reader);
486 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
489 kmeans_get_nearest_group (kmeans, c, qc, &group, NULL, NULL, NULL);
491 for (j = 0; j < qc->n_vars; ++j)
493 const union value *val = case_data (c, qc->vars[j]);
494 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
497 double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
501 long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
502 *n += qc->wv ? case_data (c, qc->wv)->f : 1.0;
505 casereader_destroy (cs);
508 /* Divide the cluster sums by the number of items in each cluster */
509 for (g = 0; g < qc->ngroups; ++g)
511 for (j = 0; j < qc->n_vars; ++j)
513 long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
514 double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
/* Compare the recomputed centers with the current ones.  A difference
   below the threshold set by kmeans_initial_centers signals convergence;
   the statement taken on convergence is not in this excerpt. */
519 double d = diff_matrix (kmeans->updated_centers, kmeans->centers);
520 if (d < kmeans->convergence_criteria)
529 /* Reports the cluster centers as a pivot table with one column per cluster
   and one row per variable.
   If INITIAL is true, the table is titled "Initial Cluster Centers" and
   shows the centers saved in kmeans->initial_centers.  Otherwise it is
   titled "Final Cluster Centers" and shows the final centers.
   Column I, labeled I + 1, shows group kmeans->group_order->data[I], so
   kmeans_order_groups must already have been called.  Each cell is created
   with pivot_value_new_var_value for its variable. */
534 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
536 struct pivot_table *table
537 = pivot_table_create (initial
538 ? N_("Initial Cluster Centers")
539 : N_("Final Cluster Centers"));
541 struct pivot_dimension *clusters
542 = pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"));
544 clusters->root->show_label = true;
545 for (size_t i = 0; i < qc->ngroups; i++)
546 pivot_category_create_leaf (clusters->root,
547 pivot_value_new_integer (i + 1));
549 struct pivot_dimension *variables
550 = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Variable"));
552 for (size_t i = 0; i < qc->n_vars; i++)
553 pivot_category_create_leaf (variables->root,
554 pivot_value_new_variable (qc->vars[i]));
556 const gsl_matrix *matrix = (initial
557 ? kmeans->initial_centers
/* Cell (i, j): display position i of the cluster dimension (which is
   dimension 0, created first) and variable j. */
559 for (size_t i = 0; i < qc->ngroups; i++)
560 for (size_t j = 0; j < qc->n_vars; j++)
562 double x = gsl_matrix_get (matrix, kmeans->group_order->data[i], j);
563 union value v = { .f = x };
564 pivot_table_put2 (table, i, j,
565 pivot_value_new_var_value (qc->vars[j], &v));
568 pivot_table_submit (table);
571 /* Reports the cluster membership of each case in a "Cluster Membership"
   pivot table.  There is one row per case read from a clone of READER,
   numbered from 1, giving the number of the cluster whose center is
   nearest to that case.  Group indices are mapped through the inverse of
   kmeans->group_order, so the numbers match the columns of the center
   tables.
   NOTE(review): the loop asserts that the case index stays below
   kmeans->n.  Confirm that kmeans->n is set to at least the number of
   cases this reader yields.  Note that with MISSING=PAIRWISE, cases with
   missing values are not filtered out beforehand. */
573 quick_cluster_show_membership (struct Kmeans *kmeans,
574 const struct casereader *reader,
577 struct pivot_table *table = pivot_table_create (N_("Cluster Membership"));
579 pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Cluster"),
582 struct pivot_dimension *cases
583 = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Case Number"));
585 cases->root->show_label = true;
/* IP maps an internal group index to its reporting position. */
587 gsl_permutation *ip = gsl_permutation_alloc (qc->ngroups);
588 gsl_permutation_inverse (ip, kmeans->group_order);
590 struct casereader *cs = casereader_clone (reader);
592 for (int i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
594 assert (i < kmeans->n);
596 kmeans_get_nearest_group (kmeans, c, qc, &clust, NULL, NULL, NULL);
597 int cluster = ip->data[clust];
599 int case_idx = pivot_category_create_leaf (cases->root,
600 pivot_value_new_integer (i + 1));
601 pivot_table_put2 (table, 0, case_idx,
602 pivot_value_new_integer (cluster + 1));
605 gsl_permutation_free (ip);
606 pivot_table_submit (table);
607 casereader_destroy (cs);
611 /* Reports the number of cases in each cluster, listed in
   kmeans->group_order, followed by a "Valid" row.  The per-cluster counts
   are read from kmeans->num_elements_groups as left by kmeans_cluster.
   NOTE(review): TOTAL is declared and accumulated on lines not shown here;
   presumably it is the sum of the per-cluster counts.  Each count is
   narrowed from long to int. */
613 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
615 struct pivot_table *table
616 = pivot_table_create (N_("Number of Cases in each Cluster"));
618 pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
621 struct pivot_dimension *clusters
622 = pivot_dimension_create (table, PIVOT_AXIS_ROW, N_("Clusters"));
624 struct pivot_category *group
625 = pivot_category_create_group (clusters->root, N_("Cluster"));
628 for (int i = 0; i < qc->ngroups; i++)
631 = pivot_category_create_leaf (group, pivot_value_new_integer (i + 1));
632 int count = kmeans->num_elements_groups->data [kmeans->group_order->data[i]];
633 pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (count));
637 int cluster_idx = pivot_category_create_leaf (clusters->root,
638 pivot_value_new_text (N_("Valid")));
639 pivot_table_put2 (table, 0, cluster_idx, pivot_value_new_integer (total));
640 pivot_table_submit (table);
/* Emits the output tables for one clustering run, in this order:
     - the initial centers, if PRINT=INITIAL was given;
     - the final centers;
     - the number of cases in each cluster;
     - each case's cluster, if PRINT=CLUSTER was given. */
645 quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader,
/* Sort the groups by their center on the first variable.  The tables below
   all number the clusters in this order via kmeans->group_order. */
648 kmeans_order_groups (kmeans, qc);
650 if (qc->print_initial_clusters)
651 quick_cluster_show_centers (kmeans, true, qc);
652 quick_cluster_show_centers (kmeans, false, qc);
653 quick_cluster_show_number_cases (kmeans, qc);
654 if (qc->print_cluster_membership)
655 quick_cluster_show_membership (kmeans, reader, qc);
659 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
662 struct Kmeans *kmeans;
664 const struct dictionary *dict = dataset_dict (ds);
667 qc.epsilon = DBL_EPSILON;
668 qc.missing_type = MISS_LISTWISE;
670 qc.print_cluster_membership = false; /* default = do not output case cluster membership */
671 qc.print_initial_clusters = false; /* default = do not print initial clusters */
672 qc.no_initial = false; /* default = use well separated initial clusters */
673 qc.no_update = false; /* default = iterate until convergence or max iterations */
675 if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
676 PV_NO_DUPLICATE | PV_NUMERIC))
678 return (CMD_FAILURE);
681 while (lex_token (lexer) != T_ENDCMD)
683 lex_match (lexer, T_SLASH);
685 if (lex_match_id (lexer, "MISSING"))
687 lex_match (lexer, T_EQUALS);
688 while (lex_token (lexer) != T_ENDCMD
689 && lex_token (lexer) != T_SLASH)
691 if (lex_match_id (lexer, "LISTWISE")
692 || lex_match_id (lexer, "DEFAULT"))
694 qc.missing_type = MISS_LISTWISE;
696 else if (lex_match_id (lexer, "PAIRWISE"))
698 qc.missing_type = MISS_PAIRWISE;
700 else if (lex_match_id (lexer, "INCLUDE"))
702 qc.exclude = MV_SYSTEM;
704 else if (lex_match_id (lexer, "EXCLUDE"))
710 lex_error (lexer, NULL);
715 else if (lex_match_id (lexer, "PRINT"))
717 lex_match (lexer, T_EQUALS);
718 while (lex_token (lexer) != T_ENDCMD
719 && lex_token (lexer) != T_SLASH)
721 if (lex_match_id (lexer, "CLUSTER"))
722 qc.print_cluster_membership = true;
723 else if (lex_match_id (lexer, "INITIAL"))
724 qc.print_initial_clusters = true;
727 lex_error (lexer, NULL);
732 else if (lex_match_id (lexer, "CRITERIA"))
734 lex_match (lexer, T_EQUALS);
735 while (lex_token (lexer) != T_ENDCMD
736 && lex_token (lexer) != T_SLASH)
738 if (lex_match_id (lexer, "CLUSTERS"))
740 if (lex_force_match (lexer, T_LPAREN) &&
741 lex_force_int (lexer))
743 qc.ngroups = lex_integer (lexer);
746 lex_error (lexer, _("The number of clusters must be positive"));
750 if (!lex_force_match (lexer, T_RPAREN))
754 else if (lex_match_id (lexer, "CONVERGE"))
756 if (lex_force_match (lexer, T_LPAREN) &&
757 lex_force_num (lexer))
759 qc.epsilon = lex_number (lexer);
762 lex_error (lexer, _("The convergence criterion must be positive"));
766 if (!lex_force_match (lexer, T_RPAREN))
770 else if (lex_match_id (lexer, "MXITER"))
772 if (lex_force_match (lexer, T_LPAREN) &&
773 lex_force_int (lexer))
775 qc.maxiter = lex_integer (lexer);
778 lex_error (lexer, _("The number of iterations must be positive"));
782 if (!lex_force_match (lexer, T_RPAREN))
786 else if (lex_match_id (lexer, "NOINITIAL"))
788 qc.no_initial = true;
790 else if (lex_match_id (lexer, "NOUPDATE"))
796 lex_error (lexer, NULL);
803 lex_error (lexer, NULL);
808 qc.wv = dict_get_weight (dict);
811 struct casereader *group;
812 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
814 while (casegrouper_get_next_group (grouper, &group))
816 if ( qc.missing_type == MISS_LISTWISE )
818 group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
823 kmeans = kmeans_create (&qc);
824 kmeans_cluster (kmeans, group, &qc);
825 quick_cluster_show_results (kmeans, group, &qc);
826 kmeans_destroy (kmeans);
827 casereader_destroy (group);
829 ok = casegrouper_destroy (grouper);
831 ok = proc_commit (ds) && ok;