QUICK-CLUSTER: Seperate const from non-const data and make it handle splits
authorJohn Darrington <john@darrington.wattle.id.au>
Sat, 2 Jul 2011 14:12:57 +0000 (16:12 +0200)
committerJohn Darrington <john@darrington.wattle.id.au>
Sat, 2 Jul 2011 14:12:57 +0000 (16:12 +0200)
src/language/stats/quick-cluster.c

index 014406f098aa53adf20ef12136e30d800b751cdb..9adcc64243a70c50539909304f758d09488a4020 100644 (file)
 #define _(msgid) gettext (msgid)
 #define N_(msgid) msgid
 
+struct qc
+{
+  const struct variable **vars;
+  size_t n_vars;
+
+  int ngroups;                 /* Number of group. (Given by the user) */
+  int maxiter;                 /* Maximum iterations (Given by the user) */
+
+  const struct variable *wv;   /* Weighting variable. */
+};
+
 /* Holds all of the information for the functions.  int n, holds the number of
    observation and its default value is -1.  We set it in
    kmeans_recalculate_centers in first invocation. */
@@ -53,44 +64,39 @@ struct Kmeans
 {
   gsl_matrix *centers;         /* Centers for groups. */
   gsl_vector_long *num_elements_groups;
-  int ngroups;                 /* Number of group. (Given by the user) */
+
   casenumber n;                        /* Number of observations (default -1). */
-  int m;                       /* Number of variables. (Given by the user) */
-  int maxiter;                 /* Maximum iterations (Given by the user) */
+
   int lastiter;                        /* Iteration where it found the solution. */
   int trials;                  /* If not convergence, how many times has
                                    clustering done. */
   gsl_matrix *initial_centers; /* Initial random centers. */
-  const struct variable **variables;
+
   gsl_permutation *group_order;        /* Group order for reporting. */
-  struct casereader *original_casereader;
   struct caseproto *proto;
   struct casereader *index_rdr;        /* Group ids for each case. */
-  const struct variable *wv;   /* Weighting variable. */
 };
 
-static struct Kmeans *kmeans_create (struct casereader *cs,
-                                    const struct variable **variables,
-                                    int m, int ngroups, int maxiter);
+static struct Kmeans *kmeans_create (const struct qc *qc);
 
-static void kmeans_randomize_centers (struct Kmeans *kmeans);
+static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
 
-static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
+static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
 
-static void kmeans_recalculate_centers (struct Kmeans *kmeans);
+static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
 
 static int
-kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
+kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
 
-static void kmeans_order_groups (struct Kmeans *kmeans);
+static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
 
-static void kmeans_cluster (struct Kmeans *kmeans);
+static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
 
-static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
+static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
 
-static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
+static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
 
-static void quick_cluster_show_results (struct Kmeans *kmeans);
+static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
 
 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
 
@@ -100,21 +106,15 @@ static void kmeans_destroy (struct Kmeans *kmeans);
    variables 'variables', number of cases 'n', number of variables 'm', number
    of clusters and amount of maximum iterations. */
 static struct Kmeans *
-kmeans_create (struct casereader *cs, const struct variable **variables,
-              int m, int ngroups, int maxiter)
+kmeans_create (const struct qc *qc)
 {
   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
-  kmeans->centers = gsl_matrix_alloc (ngroups, m);
-  kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
-  kmeans->ngroups = ngroups;
+  kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
+  kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
   kmeans->n = 0;
-  kmeans->m = m;
-  kmeans->maxiter = maxiter;
   kmeans->lastiter = 0;
   kmeans->trials = 0;
-  kmeans->variables = variables;
   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
-  kmeans->original_casereader = cs;
   kmeans->initial_centers = NULL;
 
   kmeans->proto = caseproto_create ();
@@ -135,23 +135,19 @@ kmeans_destroy (struct Kmeans *kmeans)
 
   caseproto_unref (kmeans->proto);
 
-  /*
-     These reader and writer were already destroyed.
-     free (kmeans->original_casereader);
-     free (kmeans->index_rdr);
-   */
+  casereader_destroy (kmeans->index_rdr);
 
   free (kmeans);
 }
 
 /* Creates random centers using randomly selected cases from the data. */
 static void
-kmeans_randomize_centers (struct Kmeans *kmeans)
+kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
 {
   int i, j;
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
-      for (j = 0; j < kmeans->m; j++)
+      for (j = 0; j < qc->n_vars; j++)
        {
          if (i == j)
            {
@@ -169,13 +165,13 @@ kmeans_randomize_centers (struct Kmeans *kmeans)
      here. */
   if (!kmeans->initial_centers)
     {
-      kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
+      kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
     }
 }
 
 static int
-kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
+kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
 {
   int result = -1;
   double x;
@@ -183,12 +179,12 @@ kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
   double dist;
   double mindist;
   mindist = INFINITY;
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       dist = 0;
-      for (j = 0; j < kmeans->m; j++)
+      for (j = 0; j < qc->n_vars; j++)
        {
-         x = case_data (c, kmeans->variables[j])->f;
+         x = case_data (c, qc->vars[j])->f;
          dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
        }
       if (dist < mindist)
@@ -202,7 +198,7 @@ kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
 
 /* Re-calculate the cluster centers. */
 static void
-kmeans_recalculate_centers (struct Kmeans *kmeans)
+kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
 {
   casenumber i;
   int v, j;
@@ -212,28 +208,20 @@ kmeans_recalculate_centers (struct Kmeans *kmeans)
   struct casereader *cs;
   struct casereader *cs_index;
   int index;
-  double weight;
 
   i = 0;
-  cs = casereader_clone (kmeans->original_casereader);
+  cs = casereader_clone (reader);
   cs_index = casereader_clone (kmeans->index_rdr);
 
   gsl_matrix_set_all (kmeans->centers, 0.0);
   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
     {
+      double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
       c_index = casereader_read (cs_index);
       index = case_data_idx (c_index, 0)->f;
-      for (v = 0; v < kmeans->m; ++v)
+      for (v = 0; v < qc->n_vars; ++v)
        {
-         if (kmeans->wv)
-           {
-             weight = case_data (c, kmeans->wv)->f;
-           }
-         else
-           {
-             weight = 1.0;
-           }
-         x = case_data (c, kmeans->variables[v])->f * weight;
+         x = case_data (c, qc->vars[v])->f * weight;
          curval = gsl_matrix_get (kmeans->centers, index, v);
          gsl_matrix_set (kmeans->centers, index, v, curval + x);
        }
@@ -250,10 +238,10 @@ kmeans_recalculate_centers (struct Kmeans *kmeans)
   /* We got sum of each center but we need averages.
      We are dividing centers to numobs. This may be inefficient and
      we should check it again. */
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       casenumber numobs = kmeans->num_elements_groups->data[i];
-      for (j = 0; j < kmeans->m; j++)
+      for (j = 0; j < qc->n_vars; j++)
        {
          if (numobs > 0)
            {
@@ -274,12 +262,11 @@ kmeans_recalculate_centers (struct Kmeans *kmeans)
    different cases of the new and old index variables.  If last two index
    variables are equal, there is no any enhancement of clustering. */
 static int
-kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
+kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
 {
   int totaldiff = 0;
-  double weight;
   struct ccase *c;
-  struct casereader *cs = casereader_clone (kmeans->original_casereader);
+  struct casereader *cs = casereader_clone (reader);
 
   /* A casewriter into which we will write the indexes. */
   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
@@ -290,15 +277,8 @@ kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
     {
       /* A case to hold the new index. */
       struct ccase *index_case_new = case_create (kmeans->proto);
-      int bestindex = kmeans_get_nearest_group (kmeans, c);
-      if (kmeans->wv)
-       {
-         weight = (casenumber) case_data (c, kmeans->wv)->f;
-       }
-      else
-       {
-         weight = 1.0;
-       }
+      int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
+      double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
       kmeans->num_elements_groups->data[bestindex] += weight;
       if (kmeans->index_rdr)
        {
@@ -338,17 +318,18 @@ kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
 }
 
 static void
-kmeans_order_groups (struct Kmeans *kmeans)
+kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
 {
-  gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
+  gsl_vector *v = gsl_vector_alloc (qc->ngroups);
   gsl_matrix_get_col (v, kmeans->centers, 0);
   gsl_sort_vector_index (kmeans->group_order, v);
+  gsl_vector_free (v);
 }
 
 /* Main algorithm.
    Does iterations, checks convergency. */
 static void
-kmeans_cluster (struct Kmeans *kmeans)
+kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
 {
   int i;
   bool redo;
@@ -358,13 +339,13 @@ kmeans_cluster (struct Kmeans *kmeans)
   show_warning1 = true;
 cluster:
   redo = false;
-  kmeans_randomize_centers (kmeans);
-  for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
+  kmeans_randomize_centers (kmeans, qc);
+  for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
        kmeans->lastiter++)
     {
-      diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
-      kmeans_recalculate_centers (kmeans);
-      if (show_warning1 && kmeans->ngroups > kmeans->n)
+      diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
+      kmeans_recalculate_centers (kmeans, reader, qc);
+      if (show_warning1 && qc->ngroups > kmeans->n)
        {
          msg (MW, _("Number of clusters may not be larger than the number "
                      "of cases."));
@@ -374,7 +355,7 @@ cluster:
        break;
     }
 
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       if (kmeans->num_elements_groups->data[i] == 0)
        {
@@ -395,13 +376,13 @@ cluster:
    If initial is true, initial cluster centers are reported.  Otherwise,
    resulted centers are reported. */
 static void
-quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
+quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
 {
   struct tab_table *t;
   int nc, nr, heading_columns, currow;
   int i, j;
-  nc = kmeans->ngroups + 1;
-  nr = kmeans->m + 4;
+  nc = qc->ngroups + 1;
+  nr = qc->n_vars + 4;
   heading_columns = 1;
   t = tab_create (nc, nr);
   tab_headers (t, 0, nc - 1, 0, 1);
@@ -419,36 +400,36 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
   tab_hline (t, TAL_1, 1, nc - 1, 2);
   currow += 2;
 
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
     }
   currow++;
   tab_hline (t, TAL_1, 1, nc - 1, currow);
   currow++;
-  for (i = 0; i < kmeans->m; i++)
+  for (i = 0; i < qc->n_vars; i++)
     {
       tab_text (t, 0, currow + i, TAB_LEFT,
-               var_to_string (kmeans->variables[i]));
+               var_to_string (qc->vars[i]));
     }
 
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
-      for (j = 0; j < kmeans->m; j++)
+      for (j = 0; j < qc->n_vars; j++)
        {
          if (!initial)
            {
              tab_double (t, i + 1, j + 4, TAB_CENTER,
                          gsl_matrix_get (kmeans->centers,
                                          kmeans->group_order->data[i], j),
-                         var_get_print_format (kmeans->variables[j]));
+                         var_get_print_format (qc->vars[j]));
            }
          else
            {
              tab_double (t, i + 1, j + 4, TAB_CENTER,
                          gsl_matrix_get (kmeans->initial_centers,
                                          kmeans->group_order->data[i], j),
-                         var_get_print_format (kmeans->variables[j]));
+                         var_get_print_format (qc->vars[j]));
            }
        }
     }
@@ -457,14 +438,14 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
 
 /* Reports number of cases of each single cluster. */
 static void
-quick_cluster_show_number_cases (struct Kmeans *kmeans)
+quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
 {
   struct tab_table *t;
   int nc, nr;
   int i, numelem;
   long int total;
   nc = 3;
-  nr = kmeans->ngroups + 1;
+  nr = qc->ngroups + 1;
   t = tab_create (nc, nr);
   tab_headers (t, 0, nc - 1, 0, 0);
   tab_title (t, _("Number of Cases in each Cluster"));
@@ -472,7 +453,7 @@ quick_cluster_show_number_cases (struct Kmeans *kmeans)
   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
 
   total = 0;
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
       numelem =
@@ -481,38 +462,35 @@ quick_cluster_show_number_cases (struct Kmeans *kmeans)
       total += numelem;
     }
 
-  tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
-  tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
+  tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
+  tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
   tab_submit (t);
 }
 
 /* Reports. */
 static void
-quick_cluster_show_results (struct Kmeans *kmeans)
+quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
 {
-  kmeans_order_groups (kmeans);
+  kmeans_order_groups (kmeans, qc);
   /* Uncomment the line below for reporting initial centers. */
   /* quick_cluster_show_centers (kmeans, true); */
-  quick_cluster_show_centers (kmeans, false);
-  quick_cluster_show_number_cases (kmeans);
+  quick_cluster_show_centers (kmeans, false, qc);
+  quick_cluster_show_number_cases (kmeans, qc);
 }
 
 int
 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
 {
+  struct qc qc;
   struct Kmeans *kmeans;
   bool ok;
   const struct dictionary *dict = dataset_dict (ds);
-  const struct variable **variables;
-  struct casereader *cs;
-  int groups = 2;
-  int maxiter = 2;
-  size_t p;
+  qc.ngroups = 2;
+  qc.maxiter = 2;
 
-  if (!parse_variables_const (lexer, dict, &variables, &p,
+  if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
                              PV_NO_DUPLICATE | PV_NUMERIC))
     {
-      msg (ME, _("Variables cannot be parsed"));
       return (CMD_FAILURE);
     }
 
@@ -529,7 +507,7 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
                  if (lex_force_match (lexer, T_LPAREN))
                    {
                      lex_force_int (lexer);
-                     groups = lex_integer (lexer);
+                     qc.ngroups = lex_integer (lexer);
                      lex_get (lexer);
                      lex_force_match (lexer, T_RPAREN);
                    }
@@ -539,27 +517,40 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
                  if (lex_force_match (lexer, T_LPAREN))
                    {
                      lex_force_int (lexer);
-                     maxiter = lex_integer (lexer);
+                     qc.maxiter = lex_integer (lexer);
                      lex_get (lexer);
                      lex_force_match (lexer, T_RPAREN);
                    }
                }
              else
-                return CMD_FAILURE;
+                goto error;
            }
        }
     }
 
-  cs = proc_open (ds);
+  qc.wv = dict_get_weight (dict);
 
-  kmeans = kmeans_create (cs, variables, p, groups, maxiter);
+  {
+    struct casereader *group;
+    struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
 
-  kmeans->wv = dict_get_weight (dict);
-  kmeans_cluster (kmeans);
-  quick_cluster_show_results (kmeans);
-  ok = proc_commit (ds);
+    while (casegrouper_get_next_group (grouper, &group))
+      {
+       kmeans = kmeans_create (&qc);
+       kmeans_cluster (kmeans, group, &qc);
+       quick_cluster_show_results (kmeans, &qc);
+       kmeans_destroy (kmeans);
+       casereader_destroy (group);
+      }
+    ok = casegrouper_destroy (grouper);
+  }
+  ok = proc_commit (ds) && ok;
 
-  kmeans_destroy (kmeans);
+  free (qc.vars);
 
   return (ok);
+
+ error:
+  free (qc.vars);
+  return CMD_FAILURE;
 }