Quick Cluster: Print an error instead of failing silently
[pspp] / src / language / stats / quick-cluster.c
index 946181b01a6b22e3061a4b26c961ed6f88236bd7..28ceea3e733329173c03bcbbd54cdeadead2cb35 100644 (file)
@@ -1,5 +1,5 @@
 /* PSPP - a program for statistical analysis.
-   Copyright (C) 2011, 2012 Free Software Foundation, Inc.
+   Copyright (C) 2011, 2012, 2015 Free Software Foundation, Inc.
 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
@@ -60,6 +60,8 @@ struct qc
 
   int ngroups;                 /* Number of group. (Given by the user) */
   int maxiter;                 /* Maximum iterations (Given by the user) */
+  int print_cluster_membership; /* true => print membership */
+  int print_initial_clusters;   /* true => print initial cluster */
 
   const struct variable *wv;   /* Weighting variable. */
 
@@ -89,7 +91,7 @@ struct Kmeans
 
 static struct Kmeans *kmeans_create (const struct qc *qc);
 
-static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
+static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc);
 
 static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
 
@@ -104,9 +106,11 @@ static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, co
 
 static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
 
+static void quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
+
 static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
 
-static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
+static void quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
 
 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
 
@@ -152,7 +156,7 @@ kmeans_destroy (struct Kmeans *kmeans)
 
 /* Creates random centers using randomly selected cases from the data. */
 static void
-kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
+kmeans_randomize_centers (struct Kmeans *kmeans, const struct casereader *reader UNUSED, const struct qc *qc)
 {
   int i, j;
   for (i = 0; i < qc->ngroups; i++)
@@ -346,11 +350,12 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct q
   bool redo;
   int diffs;
   bool show_warning1;
+  int redo_count = 0;
 
   show_warning1 = true;
-cluster:
+ cluster:
   redo = false;
-  kmeans_randomize_centers (kmeans, qc);
+  kmeans_randomize_centers (kmeans, reader, qc);
   for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
        kmeans->lastiter++)
     {
@@ -377,8 +382,13 @@ cluster:
          break;
        }
     }
+
   if (redo)
-    goto cluster;
+    {
+      redo_count++;
+      assert (redo_count < 10);
+      goto cluster;
+    }
 
 }
 
@@ -432,20 +442,54 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc
              tab_double (t, i + 1, j + 4, TAB_CENTER,
                          gsl_matrix_get (kmeans->centers,
                                          kmeans->group_order->data[i], j),
-                         var_get_print_format (qc->vars[j]));
+                         var_get_print_format (qc->vars[j]), RC_OTHER);
            }
          else
            {
              tab_double (t, i + 1, j + 4, TAB_CENTER,
                          gsl_matrix_get (kmeans->initial_centers,
                                          kmeans->group_order->data[i], j),
-                         var_get_print_format (qc->vars[j]));
+                         var_get_print_format (qc->vars[j]), RC_OTHER);
            }
        }
     }
   tab_submit (t);
 }
 
+/* Reports cluster membership for each case. */
+static void
+quick_cluster_show_membership (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
+{
+  struct tab_table *t;
+  int nc, nr;
+  int i, clust; 
+  struct ccase *c;
+  struct casereader *cs = casereader_clone (reader);
+  nc = 2;
+  nr = kmeans->n + 1;
+  t = tab_create (nc, nr);
+  tab_headers (t, 0, nc - 1, 0, 0);
+  tab_title (t, _("Cluster Membership"));
+  tab_text (t, 0, 0, TAB_CENTER, _("Case Number"));
+  tab_text (t, 1, 0, TAB_CENTER, _("Cluster"));
+  tab_box (t, TAL_2, TAL_2, TAL_0, TAL_1, 0, 0, nc - 1, nr - 1);
+  tab_hline (t, TAL_1, 0, nc - 1, 1);
+
+
+  for (i = 0; (c = casereader_read (cs)) != NULL; i++, case_unref (c))
+    {
+      assert (i < kmeans->n);
+      clust = kmeans_get_nearest_group (kmeans, c, qc);
+      clust = kmeans->group_order->data[clust];
+      tab_text_format (t, 0, i+1, TAB_CENTER, "%d", (i + 1));
+      tab_text_format (t, 1, i+1, TAB_CENTER, "%d", (clust + 1));
+    }
+  assert (i == kmeans->n);
+  tab_submit (t);
+  casereader_destroy (cs);
+}
+
+
 /* Reports number of cases of each single cluster. */
 static void
 quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
@@ -479,13 +523,15 @@ quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
 
 /* Reports. */
 static void
-quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
+quick_cluster_show_results (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
 {
-  kmeans_order_groups (kmeans, qc);
-  /* Uncomment the line below for reporting initial centers. */
-  /* quick_cluster_show_centers (kmeans, true); */
+  kmeans_order_groups (kmeans, qc); /* what does this do? */
+  if( qc->print_initial_clusters )
+    quick_cluster_show_centers (kmeans, true, qc);
   quick_cluster_show_centers (kmeans, false, qc);
   quick_cluster_show_number_cases (kmeans, qc);
+  if( qc->print_cluster_membership )
+    quick_cluster_show_membership(kmeans, reader, qc);
 }
 
 int
@@ -499,6 +545,8 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
   qc.maxiter = 2;
   qc.missing_type = MISS_LISTWISE;
   qc.exclude = MV_ANY;
+  qc.print_cluster_membership = false; /* default = do not output case cluster membership */
+  qc.print_initial_clusters = false;   /* default = do not print initial clusters */
 
   if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
                              PV_NO_DUPLICATE | PV_NUMERIC))
@@ -533,9 +581,29 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
                  qc.exclude = MV_ANY;
                }
              else
-               goto error;
+               {
+                 lex_error (lexer, NULL);
+                 goto error;
+               }
            }     
        }
+      else if (lex_match_id (lexer, "PRINT"))
+       {
+         lex_match (lexer, T_EQUALS);
+         while (lex_token (lexer) != T_ENDCMD
+                && lex_token (lexer) != T_SLASH)
+           {
+             if (lex_match_id (lexer, "CLUSTER"))
+                qc.print_cluster_membership = true;
+             else if (lex_match_id (lexer, "INITIAL"))
+               qc.print_initial_clusters = true;
+             else
+               {
+                 lex_error (lexer, NULL);
+                 goto error;
+               }
+           }
+       }
       else if (lex_match_id (lexer, "CRITERIA"))
        {
          lex_match (lexer, T_EQUALS);
@@ -573,9 +641,17 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
                    }
                }
              else
-                goto error;
+               {
+                 lex_error (lexer, NULL);
+                 goto error;
+               }
            }
        }
+      else
+        {
+          lex_error (lexer, NULL);
+          goto error;
+        }
     }
 
   qc.wv = dict_get_weight (dict);
@@ -589,13 +665,13 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
        if ( qc.missing_type == MISS_LISTWISE )
          {
            group  = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
-                                                    qc.exclude,
-                                                    NULL,  NULL);
+                                                      qc.exclude,
+                                                      NULL,  NULL);
          }
 
        kmeans = kmeans_create (&qc);
        kmeans_cluster (kmeans, group, &qc);
-       quick_cluster_show_results (kmeans, &qc);
+       quick_cluster_show_results (kmeans, group, &qc);
        kmeans_destroy (kmeans);
        casereader_destroy (group);
       }