1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include <gsl/gsl_matrix.h>
20 #include <gsl/gsl_permutation.h>
21 #include <gsl/gsl_sort_vector.h>
22 #include <gsl/gsl_statistics.h>
27 #include "data/case.h"
28 #include "data/casegrouper.h"
29 #include "data/casereader.h"
30 #include "data/casewriter.h"
31 #include "data/dataset.h"
32 #include "data/dictionary.h"
33 #include "data/format.h"
34 #include "data/missing-values.h"
35 #include "language/command.h"
36 #include "language/lexer/lexer.h"
37 #include "language/lexer/variable-parser.h"
38 #include "libpspp/message.h"
39 #include "libpspp/misc.h"
40 #include "libpspp/assertion.h"
41 #include "libpspp/str.h"
42 #include "math/random.h"
43 #include "output/tab.h"
44 #include "output/text-item.h"
/* gettext shorthands: _() looks its argument up through gettext, while N_()
   expands to its argument unchanged. */
47 #define _(msgid) gettext (msgid)
48 #define N_(msgid) msgid
/* Option fields that the functions below read through `qc'.  (The enclosing
   declaration itself is outside the visible excerpt.) */
59 const struct variable **vars; /* Variables whose values are clustered. */
62 double epsilon; /* The convergence criterion; multiplied by the minimum
   distance between the initial centers to obtain the stopping threshold. */
64 int ngroups; /* Number of groups (given by the user). */
65 int maxiter; /* Maximum number of iterations (given by the user). */
66 bool print_cluster_membership; /* true => print membership */
67 bool print_initial_clusters; /* true => print initial cluster */
68 bool no_initial; /* true => simplified initial cluster selection */
69 bool no_update; /* true => do not iterate */
71 const struct variable *wv; /* Weighting variable, or NULL if the cases are unweighted. */
73 enum missing_type missing_type; /* MISS_LISTWISE or MISS_PAIRWISE, chosen by the MISSING subcommand. */
74 enum mv_class exclude; /* Class of missing values handed to var_is_value_missing. */
77 /* Holds all of the information for the functions.  The member `n' holds the
78 number of observations and its default value is -1.  We set it in
79 kmeans_recalculate_centers on its first invocation.
   NOTE(review): no function named kmeans_recalculate_centers appears in the
   visible code, and `n' is only read in assertions there -- confirm that this
   description is still accurate. */
82 gsl_matrix *centers; /* Centers for groups: one row per group, one column per variable. */
83 gsl_matrix *updated_centers; /* Accumulator for the centers being recomputed in kmeans_cluster. */
86 gsl_vector_long *num_elements_groups; /* Per-group case count (weights are added when a weight variable is set). */
88 gsl_matrix *initial_centers; /* Copy of the starting centers, kept for reporting. */
89 double convergence_criteria; /* qc->epsilon times the minimum distance between the initial centers. */
90 gsl_permutation *group_order; /* Group order for reporting. */
/* Forward declarations for the functions defined in this file.  Everything
   is file-local except cmd_quick_cluster, which is declared without
   `static'. */
93 static struct Kmeans *kmeans_create (const struct qc *qc);
95 static void kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c, const struct qc *, int *, double *, int *, double *);
97 static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
99 static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
101 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
103 static void quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
105 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
107 static void quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
109 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
111 static void kmeans_destroy (struct Kmeans *kmeans);
113 /* Creates a struct Kmeans sized for the options in QC: the two center
114 matrices have QC->ngroups rows and QC->n_vars columns, and the count vector
115 and the reporting permutation have one entry per group. */
116 static struct Kmeans *
117 kmeans_create (const struct qc *qc)
119 struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
120 kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
121 kmeans->updated_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
122 kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
/* One permutation entry per row (i.e. per group) of the center matrix. */
123 kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
/* Allocated and filled in later by kmeans_initial_centers. */
124 kmeans->initial_centers = NULL;
/* Releases the GSL matrices, the per-group count vector and the reporting
   permutation held by KMEANS. */
130 kmeans_destroy (struct Kmeans *kmeans)
132 gsl_matrix_free (kmeans->centers);
133 gsl_matrix_free (kmeans->updated_centers);
134 gsl_matrix_free (kmeans->initial_centers);
136 gsl_vector_long_free (kmeans->num_elements_groups);
138 gsl_permutation_free (kmeans->group_order);
/* Compares matrices M1 and M2 row by row, accumulating for each row the sum
   of the squared differences of corresponding elements.  The loop bounds are
   taken from M1 only.
   NOTE(review): max_diff starts at -INFINITY, so presumably the largest
   per-row value is what is reported; the comparison and the return statement
   are not among the visible lines -- confirm. */
144 diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2)
147 double max_diff = -INFINITY;
148 for (i = 0; i < m1->size1; ++i)
151 for (j = 0; j < m1->size2; ++j)
153 diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j) );
/* Returns the smallest Euclidean distance between any two distinct rows of M.
   Every pair of rows (i, j) with i < j is examined; squared distances are
   compared and the square root is taken only once, on return.
   NOTE(review): MN and MM presumably receive the indexes of the closest pair
   when they are non-null (callers pass either addresses or NULL); the
   assignment is not among the visible lines -- confirm. */
165 matrix_mindist (const gsl_matrix *m, int *mn, int *mm)
168 double mindist = INFINITY;
169 for (i = 0; i < m->size1 - 1; ++i)
171 for (j = i + 1; j < m->size1; ++j)
175 for (k = 0; k < m->size2; ++k)
177 diff_sq += pow2 (gsl_matrix_get (m, j, k) - gsl_matrix_get (m, i, k));
179 if (diff_sq < mindist)
190 return sqrt (mindist);
/* Prints every element of M to standard output with the "%02f " format,
   walking the rows in order and the columns within each row. */
195 dump_matrix (const gsl_matrix *m)
199 for (i = 0 ; i < m->size1; ++i)
201 for (j = 0 ; j < m->size2; ++j)
202 printf ("%02f ", gsl_matrix_get (m, i, j));
208 /* Return the distance of C from the group whose index is WHICH.
    The distance is accumulated over the variables in QC as the sum of the
    squared differences between C's value and the group's center.  Values
    for which var_is_value_missing holds are handled by a branch that is not
    among the visible lines.
    NOTE(review): no square root is visible here, so the result is presumably
    a squared distance -- confirm before comparing it with the result of
    min_dist_from or matrix_mindist, which do take a square root. */
210 dist_from_case (const struct Kmeans *kmeans, const struct ccase *c, const struct qc *qc, int which)
214 for (j = 0; j < qc->n_vars; j++)
216 const union value *val = case_data (c, qc->vars[j]);
217 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
220 dist += pow2 (gsl_matrix_get (kmeans->centers, which, j) - val->f);
226 /* Return the minimum Euclidean distance between the center of the group
    WHICH and the centers of the other groups.  Squared distances are summed
    over the variables in QC and the square root is taken on return. */
228 min_dist_from (const struct Kmeans *kmeans, const struct qc *qc, int which)
232 double mindist = INFINITY;
233 for (i = 0; i < qc->ngroups; i++)
239 for (j = 0; j < qc->n_vars; j++)
241 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - gsl_matrix_get (kmeans->centers, which, j));
250 return sqrt (mindist);
255 /* Calculate the initial cluster centers.
    A clone of READER is scanned once.  While fewer than QC->ngroups cases
    have been counted, the values of the case are stored as a row of
    KMEANS->centers.  Later cases may replace an existing center according to
    the two distance tests commented below.  Finally the convergence
    threshold is derived from QC->epsilon and the centers are copied into
    KMEANS->initial_centers. */
257 kmeans_initial_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
262 struct casereader *cs = casereader_clone (reader);
263 for (; (c = casereader_read (cs)) != NULL; case_unref (c))
265 bool missing = false;
266 for (j = 0; j < qc->n_vars; ++j)
268 const union value *val = case_data (c, qc->vars[j]);
269 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
/* NC counts cases seen so far; it doubles as the row to fill. */
275 if (nc < qc->ngroups)
276 gsl_matrix_set (kmeans->centers, nc, j, val->f);
282 if (nc++ < qc->ngroups)
/* M is the distance between the two closest centers, MN and MM. */
291 double m = matrix_mindist (kmeans->centers, &mn, &mm);
293 kmeans_get_nearest_group (kmeans, c, qc, &mq, &delta, &mp, NULL);
295 /* If the distance between C and the nearest group is greater than the distance
296 between the two groups which are closest to each other, then one group must be replaced. */
298 /* Out of MN and MM, which is the closer of the two groups to C? */
299 int which = (dist_from_case (kmeans, c, qc, mn) > dist_from_case (kmeans, c, qc, mm)) ? mm : mn;
301 for (j = 0; j < qc->n_vars; ++j)
303 const union value *val = case_data (c, qc->vars[j]);
304 gsl_matrix_set (kmeans->centers, which, j, val->f);
307 else if (dist_from_case (kmeans, c, qc, mp) > min_dist_from (kmeans, qc, mq))
308 /* If the distance between C and the second nearest group (MP) is greater than the
309 smallest distance between the nearest group (MQ) and any other group, then replace
312 for (j = 0; j < qc->n_vars; ++j)
314 const union value *val = case_data (c, qc->vars[j]);
315 gsl_matrix_set (kmeans->centers, mq, j, val->f);
321 casereader_destroy (cs);
323 kmeans->convergence_criteria = qc->epsilon * matrix_mindist (kmeans->centers, NULL, NULL);
325 /* As this is the first iteration, kmeans->initial_centers is still NULL;
326 it is created once here and kept for reporting purposes. */
327 kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
328 gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
332 /* Find the groups which are nearest to the case C.
    For every group the sum of squared differences between C's values and
    the group's center is computed over the variables in QC.  mindist0 and
    mindist1 track the smallest and, judging by the else-if below, the second
    smallest of these sums.  Results are passed back through G_Q/DELTA_Q and
    G_P/DELTA_P; callers in this file pass NULL for outputs they do not need.
    NOTE(review): going by the comments in kmeans_initial_centers, the first
    pair describes the nearest group and the second pair the second nearest;
    the assignments themselves are not among the visible lines -- confirm. */
334 kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c, const struct qc *qc, int *g_q, double *delta_q, int *g_p, double *delta_p)
339 double mindist0 = INFINITY;
340 double mindist1 = INFINITY;
341 for (i = 0; i < qc->ngroups; i++)
344 for (j = 0; j < qc->n_vars; j++)
346 const union value *val = case_data (c, qc->vars[j]);
347 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
350 dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
362 else if (dist < mindist1)
/* Determines the order in which groups are reported: column 0 of the center
   matrix (the center values of the first variable) is copied into a
   temporary vector, and the permutation that sorts that vector is stored in
   KMEANS->group_order. */
386 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
388 gsl_vector *v = gsl_vector_alloc (qc->ngroups);
389 gsl_matrix_get_col (v, kmeans->centers, 0);
390 gsl_sort_vector_index (kmeans->group_order, v);
395 Performs the iterations and checks for convergence. */
/* Runs the clustering of the cases in READER with the options in QC.
   kmeans_initial_centers picks the starting centers, then at most
   QC->maxiter rounds are made.  The visible statements read clones of
   READER, assign every case to a group, accumulate (weighted) sums per
   group in KMEANS->updated_centers and divide them by the group counts.
   The difference between updated_centers and centers is then compared with
   KMEANS->convergence_criteria. */
397 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
401 kmeans_initial_centers (kmeans, reader, qc);
/* The accumulator starts out holding the current centers. */
403 gsl_matrix_memcpy (kmeans->updated_centers, kmeans->centers);
406 for (int xx = 0 ; xx < qc->maxiter ; ++xx)
408 gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
/* First pass over the data, on a private clone of the reader. */
413 struct casereader *r = casereader_clone (reader);
415 for (; (c = casereader_read (r)) != NULL; case_unref (c))
419 bool missing = false;
421 for (j = 0; j < qc->n_vars; j++)
423 const union value *val = case_data (c, qc->vars[j]);
424 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
/* Look for the group whose center is nearest to this case. */
431 double mindist = INFINITY;
432 for (g = 0; g < qc->ngroups; ++g)
434 double d = dist_from_case (kmeans, c, qc, g);
/* Count the case: its weight if a weight variable is set, otherwise 1.
   NOTE(review): the counter is a long and the weight a double, so a
   fractional weight is truncated by this addition -- confirm that is
   intended. */
443 long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
444 *n += qc->wv ? case_data (c, qc->wv)->f : 1.0;
/* Add the (weighted) values of the case to the sums of its group. */
447 for (j = 0; j < qc->n_vars; ++j)
449 const union value *val = case_data (c, qc->vars[j]);
450 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
452 double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
453 *x += val->f * (qc->wv ? case_data (c, qc->wv)->f : 1.0);
457 casereader_destroy (r);
462 /* Divide the cluster sums by the number of items in each cluster */
463 for (g = 0; g < qc->ngroups; ++g)
465 for (j = 0; j < qc->n_vars; ++j)
467 long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
468 double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
469 *x /= n + 1; // Plus 1 for the initial centers
/* The freshly computed means become the current centers. */
474 gsl_matrix_memcpy (kmeans->centers, kmeans->updated_centers);
/* Second pass: start again from zeroed counts and sums, this time assigning
   cases with kmeans_get_nearest_group. */
480 gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
481 gsl_matrix_set_all (kmeans->updated_centers, 0.0);
483 struct casereader *cs = casereader_clone (reader);
484 for (; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
487 kmeans_get_nearest_group (kmeans, c, qc, &group, NULL, NULL, NULL);
489 for (j = 0; j < qc->n_vars; ++j)
491 const union value *val = case_data (c, qc->vars[j]);
492 if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
495 double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
499 long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
500 *n += qc->wv ? case_data (c, qc->wv)->f : 1.0;
505 casereader_destroy (cs);
508 /* Divide the cluster sums by the number of items in each cluster */
509 for (g = 0; g < qc->ngroups; ++g)
511 for (j = 0; j < qc->n_vars; ++j)
513 long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
514 double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
/* Convergence test: how far did the centers move in this round? */
519 double d = diff_matrix (kmeans->updated_centers, kmeans->centers);
520 if (d < kmeans->convergence_criteria)
529 /* Reports the centers of the clusters in a table with one column per group
530 and one row per variable.
531 If INITIAL is true, the initial cluster centers are reported.  Otherwise,
532 the final centers are reported. */
534 quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
/* One column for the variable names plus one per group. */
539 nc = qc->ngroups + 1;
541 t = tab_create (nc, nr);
542 tab_headers (t, 0, nc - 1, 0, 1);
546 tab_title (t, _("Final Cluster Centers"));
550 tab_title (t, _("Initial Cluster Centers"));
552 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
553 tab_joint_text (t, 1, 0, nc - 1, 0, TAB_CENTER, _("Cluster"));
554 tab_hline (t, TAL_1, 1, nc - 1, 2);
/* Column headings: the 1-based group numbers. */
557 for (i = 0; i < qc->ngroups; i++)
559 tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
562 tab_hline (t, TAL_1, 1, nc - 1, currow);
/* Row headings: the variables, in column 0. */
564 for (i = 0; i < qc->n_vars; i++)
566 tab_text (t, 0, currow + i, TAB_LEFT,
567 var_to_string (qc->vars[i]));
/* Table body.  Group columns follow kmeans->group_order, and each value is
   shown in the print format of its variable. */
570 for (i = 0; i < qc->ngroups; i++)
572 for (j = 0; j < qc->n_vars; j++)
576 tab_double (t, i + 1, j + 4, TAB_CENTER,
577 gsl_matrix_get (kmeans->centers,
578 kmeans->group_order->data[i], j),
579 var_get_print_format (qc->vars[j]), RC_OTHER);
583 tab_double (t, i + 1, j + 4, TAB_CENTER,
584 gsl_matrix_get (kmeans->initial_centers,
585 kmeans->group_order->data[i], j),
586 var_get_print_format (qc->vars[j]), RC_OTHER);
593 /* Reports cluster membership for each case: one table row per case read
    from a clone of READER, giving the 1-based case number and the 1-based
    number of its cluster in reporting order. */
595 quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
601 struct casereader *cs = casereader_clone (reader);
604 t = tab_create (nc, nr);
605 tab_headers (t, 0, nc - 1, 0, 0);
606 tab_title (t, _("Cluster Membership"));
607 tab_text (t, 0, 0, TAB_CENTER, _("Case Number"));
608 tab_text (t, 1, 0, TAB_CENTER, _("Cluster"));
609 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
610 tab_hline (t, TAL_1, 0, nc - 1, 1);
/* group_order maps a reporting position to a group index; its inverse maps
   a group index back to the reporting position. */
612 gsl_permutation *ip = gsl_permutation_alloc (qc->ngroups);
613 gsl_permutation_inverse (ip, kmeans->group_order);
615 for (i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
618 assert (i < kmeans->n);
619 kmeans_get_nearest_group (kmeans, c, qc, &clust, NULL, NULL, NULL);
/* Translate the group index into its position in the report. */
620 clust = ip->data[clust];
621 tab_text_format (t, 0, i+1, TAB_CENTER, "%d", (i + 1));
622 tab_text_format (t, 1, i+1, TAB_CENTER, "%d", (clust + 1));
624 gsl_permutation_free (ip);
/* Every one of the kmeans->n cases must have been reported. */
625 assert (i == kmeans->n);
627 casereader_destroy (cs);
631 /* Reports the number of cases in each cluster, one table row per group in
    reporting order, followed by a final "Valid" row showing TOTAL. */
633 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
/* One row per group plus the "Valid" row. */
640 nr = qc->ngroups + 1;
641 t = tab_create (nc, nr);
642 tab_headers (t, 0, nc - 1, 0, 0);
643 tab_title (t, _("Number of Cases in each Cluster"));
644 tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
645 tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
648 for (i = 0; i < qc->ngroups; i++)
650 tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
/* Count of the group that occupies reporting position I.
   NOTE(review): the counts are stored as long, but the value is printed
   with "%d" below; the declaration of numelem is not among the visible
   lines -- confirm that its type matches the format. */
652 kmeans->num_elements_groups->data[kmeans->group_order->data[i]];
653 tab_text_format (t, 2, i, TAB_CENTER, "%d", numelem);
657 tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
658 tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
/* Produces all of the output for one clustering run: optionally the initial
   centers, then the final centers and the number of cases per cluster, and
   optionally the cluster membership of every case in READER. */
664 quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
666 kmeans_order_groups (kmeans, qc); /* Fix the order in which groups are reported: sorted by their center on the first variable. */
668 if( qc->print_initial_clusters )
669 quick_cluster_show_centers (kmeans, true, qc);
670 quick_cluster_show_centers (kmeans, false, qc);
671 quick_cluster_show_number_cases (kmeans, qc);
672 if( qc->print_cluster_membership )
673 quick_cluster_show_membership(kmeans, reader, qc);
/* Parses and executes the QUICK CLUSTER command.  The syntax handled here is
   a list of numeric variables followed by optional MISSING, PRINT and
   CRITERIA subcommands.  Each split group of the active dataset is then
   clustered and reported separately.  Returns CMD_FAILURE if the variable
   list cannot be parsed; the remaining return paths lie beyond the visible
   lines. */
677 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
680 struct Kmeans *kmeans;
682 const struct dictionary *dict = dataset_dict (ds);
/* Defaults that apply when the corresponding subcommand is not given.
   NOTE(review): no default for ngroups, maxiter or exclude is among the
   visible lines -- confirm they are initialised. */
685 qc.epsilon = DBL_EPSILON;
686 qc.missing_type = MISS_LISTWISE;
688 qc.print_cluster_membership = false; /* default = do not output case cluster membership */
689 qc.print_initial_clusters = false; /* default = do not print initial clusters */
690 qc.no_initial = false; /* default = use well separated initial clusters */
691 qc.no_update = false; /* default = iterate until convergence or max iterations */
/* The variables to cluster on: numeric only, no duplicates. */
693 if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
694 PV_NO_DUPLICATE | PV_NUMERIC))
696 return (CMD_FAILURE);
/* Subcommands, each optionally introduced by a slash. */
699 while (lex_token (lexer) != T_ENDCMD)
701 lex_match (lexer, T_SLASH);
/* MISSING: LISTWISE (or DEFAULT), PAIRWISE, INCLUDE, EXCLUDE. */
703 if (lex_match_id (lexer, "MISSING"))
705 lex_match (lexer, T_EQUALS);
706 while (lex_token (lexer) != T_ENDCMD
707 && lex_token (lexer) != T_SLASH)
709 if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
711 qc.missing_type = MISS_LISTWISE;
713 else if (lex_match_id (lexer, "PAIRWISE"))
715 qc.missing_type = MISS_PAIRWISE;
717 else if (lex_match_id (lexer, "INCLUDE"))
719 qc.exclude = MV_SYSTEM;
721 else if (lex_match_id (lexer, "EXCLUDE"))
727 lex_error (lexer, NULL);
/* PRINT: CLUSTER and/or INITIAL select the optional tables. */
732 else if (lex_match_id (lexer, "PRINT"))
734 lex_match (lexer, T_EQUALS);
735 while (lex_token (lexer) != T_ENDCMD
736 && lex_token (lexer) != T_SLASH)
738 if (lex_match_id (lexer, "CLUSTER"))
739 qc.print_cluster_membership = true;
740 else if (lex_match_id (lexer, "INITIAL"))
741 qc.print_initial_clusters = true;
744 lex_error (lexer, NULL);
/* CRITERIA: CLUSTERS(n), CONVERGE(x), MXITER(n), NOINITIAL, NOUPDATE. */
749 else if (lex_match_id (lexer, "CRITERIA"))
751 lex_match (lexer, T_EQUALS);
752 while (lex_token (lexer) != T_ENDCMD
753 && lex_token (lexer) != T_SLASH)
755 if (lex_match_id (lexer, "CLUSTERS"))
757 if (lex_force_match (lexer, T_LPAREN))
/* NOTE(review): the result of lex_force_int is not checked before
   lex_integer is called; the same pattern recurs for CONVERGE
   (lex_force_num) and MXITER below -- confirm this is safe when the token
   is not a number. */
759 lex_force_int (lexer);
760 qc.ngroups = lex_integer (lexer);
763 lex_error (lexer, _("The number of clusters must be positive"));
767 lex_force_match (lexer, T_RPAREN);
770 else if (lex_match_id (lexer, "CONVERGE"))
772 if (lex_force_match (lexer, T_LPAREN))
774 lex_force_num (lexer);
775 qc.epsilon = lex_number (lexer);
778 lex_error (lexer, _("The convergence criterium must be positive"));
782 lex_force_match (lexer, T_RPAREN);
785 else if (lex_match_id (lexer, "MXITER"))
787 if (lex_force_match (lexer, T_LPAREN))
789 lex_force_int (lexer);
790 qc.maxiter = lex_integer (lexer);
793 lex_error (lexer, _("The number of iterations must be positive"));
797 lex_force_match (lexer, T_RPAREN);
800 else if (lex_match_id (lexer, "NOINITIAL"))
802 qc.no_initial = true;
804 else if (lex_match_id (lexer, "NOUPDATE"))
810 lex_error (lexer, NULL);
817 lex_error (lexer, NULL);
/* The weight variable of the dictionary, if any. */
822 qc.wv = dict_get_weight (dict);
/* Process the active dataset one split group at a time. */
825 struct casereader *group;
826 struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
828 while (casegrouper_get_next_group (grouper, &group))
/* With listwise deletion the group's reader is wrapped in a missing-value
   filter over the clustering variables. */
830 if ( qc.missing_type == MISS_LISTWISE )
832 group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
/* Create, run, report and tear down one clustering per group. */
837 kmeans = kmeans_create (&qc);
838 kmeans_cluster (kmeans, group, &qc);
839 quick_cluster_show_results (kmeans, group, &qc);
840 kmeans_destroy (kmeans);
841 casereader_destroy (group);
/* Combine the status of the grouper and of the procedure commit. */
843 ok = casegrouper_destroy (grouper);
845 ok = proc_commit (ds) && ok;