Added result_class parameter to tab_double and updated all callers. Removed tab_fixed
[pspp] / src / language / stats / quick-cluster.c
index 014406f098aa53adf20ef12136e30d800b751cdb..68b50123144e2b0b10c19d494d2e4b02b32ac430 100644 (file)
@@ -1,5 +1,5 @@
 /* PSPP - a program for statistical analysis.
-   Copyright (C) 2011 Free Software Foundation, Inc.
+   Copyright (C) 2011, 2012 Free Software Foundation, Inc.
 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
 #define _(msgid) gettext (msgid)
 #define N_(msgid) msgid
 
+enum missing_type
+  {
+    MISS_LISTWISE,
+    MISS_PAIRWISE,
+  };
+
+
+struct qc
+{
+  const struct variable **vars;
+  size_t n_vars;
+
+  int ngroups;                 /* Number of group. (Given by the user) */
+  int maxiter;                 /* Maximum iterations (Given by the user) */
+
+  const struct variable *wv;   /* Weighting variable. */
+
+  enum missing_type missing_type;
+  enum mv_class exclude;
+};
+
 /* Holds all of the information for the functions.  int n, holds the number of
    observation and its default value is -1.  We set it in
    kmeans_recalculate_centers in first invocation. */
@@ -53,44 +74,39 @@ struct Kmeans
 {
   gsl_matrix *centers;         /* Centers for groups. */
   gsl_vector_long *num_elements_groups;
-  int ngroups;                 /* Number of group. (Given by the user) */
+
   casenumber n;                        /* Number of observations (default -1). */
-  int m;                       /* Number of variables. (Given by the user) */
-  int maxiter;                 /* Maximum iterations (Given by the user) */
+
   int lastiter;                        /* Iteration where it found the solution. */
   int trials;                  /* If not convergence, how many times has
                                    clustering done. */
   gsl_matrix *initial_centers; /* Initial random centers. */
-  const struct variable **variables;
+
   gsl_permutation *group_order;        /* Group order for reporting. */
-  struct casereader *original_casereader;
   struct caseproto *proto;
   struct casereader *index_rdr;        /* Group ids for each case. */
-  const struct variable *wv;   /* Weighting variable. */
 };
 
-static struct Kmeans *kmeans_create (struct casereader *cs,
-                                    const struct variable **variables,
-                                    int m, int ngroups, int maxiter);
+static struct Kmeans *kmeans_create (const struct qc *qc);
 
-static void kmeans_randomize_centers (struct Kmeans *kmeans);
+static void kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc);
 
-static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c);
+static int kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *);
 
-static void kmeans_recalculate_centers (struct Kmeans *kmeans);
+static void kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
 
 static int
-kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans);
+kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *);
 
-static void kmeans_order_groups (struct Kmeans *kmeans);
+static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
 
-static void kmeans_cluster (struct Kmeans *kmeans);
+static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *);
 
-static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial);
+static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *);
 
-static void quick_cluster_show_number_cases (struct Kmeans *kmeans);
+static void quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *);
 
-static void quick_cluster_show_results (struct Kmeans *kmeans);
+static void quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *);
 
 int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
 
@@ -100,21 +116,15 @@ static void kmeans_destroy (struct Kmeans *kmeans);
    variables 'variables', number of cases 'n', number of variables 'm', number
    of clusters and amount of maximum iterations. */
 static struct Kmeans *
-kmeans_create (struct casereader *cs, const struct variable **variables,
-              int m, int ngroups, int maxiter)
+kmeans_create (const struct qc *qc)
 {
   struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
-  kmeans->centers = gsl_matrix_alloc (ngroups, m);
-  kmeans->num_elements_groups = gsl_vector_long_alloc (ngroups);
-  kmeans->ngroups = ngroups;
+  kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
+  kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
   kmeans->n = 0;
-  kmeans->m = m;
-  kmeans->maxiter = maxiter;
   kmeans->lastiter = 0;
   kmeans->trials = 0;
-  kmeans->variables = variables;
   kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
-  kmeans->original_casereader = cs;
   kmeans->initial_centers = NULL;
 
   kmeans->proto = caseproto_create ();
@@ -135,23 +145,19 @@ kmeans_destroy (struct Kmeans *kmeans)
 
   caseproto_unref (kmeans->proto);
 
-  /*
-     These reader and writer were already destroyed.
-     free (kmeans->original_casereader);
-     free (kmeans->index_rdr);
-   */
+  casereader_destroy (kmeans->index_rdr);
 
   free (kmeans);
 }
 
 /* Creates random centers using randomly selected cases from the data. */
 static void
-kmeans_randomize_centers (struct Kmeans *kmeans)
+kmeans_randomize_centers (struct Kmeans *kmeans, const struct qc *qc)
 {
   int i, j;
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
-      for (j = 0; j < kmeans->m; j++)
+      for (j = 0; j < qc->n_vars; j++)
        {
          if (i == j)
            {
@@ -169,27 +175,27 @@ kmeans_randomize_centers (struct Kmeans *kmeans)
      here. */
   if (!kmeans->initial_centers)
     {
-      kmeans->initial_centers = gsl_matrix_alloc (kmeans->ngroups, kmeans->m);
+      kmeans->initial_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
       gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
     }
 }
 
 static int
-kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
+kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c, const struct qc *qc)
 {
   int result = -1;
-  double x;
   int i, j;
-  double dist;
-  double mindist;
-  mindist = INFINITY;
-  for (i = 0; i < kmeans->ngroups; i++)
+  double mindist = INFINITY;
+  for (i = 0; i < qc->ngroups; i++)
     {
-      dist = 0;
-      for (j = 0; j < kmeans->m; j++)
+      double dist = 0;
+      for (j = 0; j < qc->n_vars; j++)
        {
-         x = case_data (c, kmeans->variables[j])->f;
-         dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - x);
+         const union value *val = case_data (c, qc->vars[j]);
+         if ( var_is_value_missing (qc->vars[j], val, qc->exclude))
+           continue;
+
+         dist += pow2 (gsl_matrix_get (kmeans->centers, i, j) - val->f);
        }
       if (dist < mindist)
        {
@@ -202,38 +208,30 @@ kmeans_get_nearest_group (struct Kmeans *kmeans, struct ccase *c)
 
 /* Re-calculate the cluster centers. */
 static void
-kmeans_recalculate_centers (struct Kmeans *kmeans)
+kmeans_recalculate_centers (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
 {
-  casenumber i;
+  casenumber i = 0;
   int v, j;
-  double x, curval;
   struct ccase *c;
-  struct ccase *c_index;
-  struct casereader *cs;
-  struct casereader *cs_index;
-  int index;
-  double weight;
 
-  i = 0;
-  cs = casereader_clone (kmeans->original_casereader);
-  cs_index = casereader_clone (kmeans->index_rdr);
+  struct casereader *cs = casereader_clone (reader);
+  struct casereader *cs_index = casereader_clone (kmeans->index_rdr);
 
   gsl_matrix_set_all (kmeans->centers, 0.0);
   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
     {
-      c_index = casereader_read (cs_index);
-      index = case_data_idx (c_index, 0)->f;
-      for (v = 0; v < kmeans->m; ++v)
+      double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
+      struct ccase *c_index = casereader_read (cs_index);
+      int index = case_data_idx (c_index, 0)->f;
+      for (v = 0; v < qc->n_vars; ++v)
        {
-         if (kmeans->wv)
-           {
-             weight = case_data (c, kmeans->wv)->f;
-           }
-         else
-           {
-             weight = 1.0;
-           }
-         x = case_data (c, kmeans->variables[v])->f * weight;
+         const union value *val = case_data (c, qc->vars[v]);
+         double x = val->f * weight;
+         double curval;
+
+         if ( var_is_value_missing (qc->vars[v], val, qc->exclude))
+           continue;
+
          curval = gsl_matrix_get (kmeans->centers, index, v);
          gsl_matrix_set (kmeans->centers, index, v, curval + x);
        }
@@ -250,10 +248,10 @@ kmeans_recalculate_centers (struct Kmeans *kmeans)
   /* We got sum of each center but we need averages.
      We are dividing centers to numobs. This may be inefficient and
      we should check it again. */
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       casenumber numobs = kmeans->num_elements_groups->data[i];
-      for (j = 0; j < kmeans->m; j++)
+      for (j = 0; j < qc->n_vars; j++)
        {
          if (numobs > 0)
            {
@@ -274,12 +272,11 @@ kmeans_recalculate_centers (struct Kmeans *kmeans)
    different cases of the new and old index variables.  If last two index
    variables are equal, there is no any enhancement of clustering. */
 static int
-kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
+kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans, const struct casereader *reader, const struct qc *qc)
 {
   int totaldiff = 0;
-  double weight;
   struct ccase *c;
-  struct casereader *cs = casereader_clone (kmeans->original_casereader);
+  struct casereader *cs = casereader_clone (reader);
 
   /* A casewriter into which we will write the indexes. */
   struct casewriter *index_wtr = autopaging_writer_create (kmeans->proto);
@@ -290,15 +287,9 @@ kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
     {
       /* A case to hold the new index. */
       struct ccase *index_case_new = case_create (kmeans->proto);
-      int bestindex = kmeans_get_nearest_group (kmeans, c);
-      if (kmeans->wv)
-       {
-         weight = (casenumber) case_data (c, kmeans->wv)->f;
-       }
-      else
-       {
-         weight = 1.0;
-       }
+      int bestindex = kmeans_get_nearest_group (kmeans, c, qc);
+      double weight = qc->wv ? case_data (c, qc->wv)->f : 1.0;
+      assert (bestindex < kmeans->num_elements_groups->size);
       kmeans->num_elements_groups->data[bestindex] += weight;
       if (kmeans->index_rdr)
        {
@@ -338,17 +329,18 @@ kmeans_calculate_indexes_and_check_convergence (struct Kmeans *kmeans)
 }
 
 static void
-kmeans_order_groups (struct Kmeans *kmeans)
+kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
 {
-  gsl_vector *v = gsl_vector_alloc (kmeans->ngroups);
+  gsl_vector *v = gsl_vector_alloc (qc->ngroups);
   gsl_matrix_get_col (v, kmeans->centers, 0);
   gsl_sort_vector_index (kmeans->group_order, v);
+  gsl_vector_free (v);
 }
 
 /* Main algorithm.
    Does iterations, checks convergency. */
 static void
-kmeans_cluster (struct Kmeans *kmeans)
+kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader, const struct qc *qc)
 {
   int i;
   bool redo;
@@ -358,13 +350,13 @@ kmeans_cluster (struct Kmeans *kmeans)
   show_warning1 = true;
 cluster:
   redo = false;
-  kmeans_randomize_centers (kmeans);
-  for (kmeans->lastiter = 0; kmeans->lastiter < kmeans->maxiter;
+  kmeans_randomize_centers (kmeans, qc);
+  for (kmeans->lastiter = 0; kmeans->lastiter < qc->maxiter;
        kmeans->lastiter++)
     {
-      diffs = kmeans_calculate_indexes_and_check_convergence (kmeans);
-      kmeans_recalculate_centers (kmeans);
-      if (show_warning1 && kmeans->ngroups > kmeans->n)
+      diffs = kmeans_calculate_indexes_and_check_convergence (kmeans, reader, qc);
+      kmeans_recalculate_centers (kmeans, reader, qc);
+      if (show_warning1 && qc->ngroups > kmeans->n)
        {
          msg (MW, _("Number of clusters may not be larger than the number "
                      "of cases."));
@@ -374,7 +366,7 @@ cluster:
        break;
     }
 
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       if (kmeans->num_elements_groups->data[i] == 0)
        {
@@ -395,14 +387,13 @@ cluster:
    If initial is true, initial cluster centers are reported.  Otherwise,
    resulted centers are reported. */
 static void
-quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
+quick_cluster_show_centers (struct Kmeans *kmeans, bool initial, const struct qc *qc)
 {
   struct tab_table *t;
-  int nc, nr, heading_columns, currow;
+  int nc, nr, currow;
   int i, j;
-  nc = kmeans->ngroups + 1;
-  nr = kmeans->m + 4;
-  heading_columns = 1;
+  nc = qc->ngroups + 1;
+  nr = qc->n_vars + 4;
   t = tab_create (nc, nr);
   tab_headers (t, 0, nc - 1, 0, 1);
   currow = 0;
@@ -419,36 +410,36 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
   tab_hline (t, TAL_1, 1, nc - 1, 2);
   currow += 2;
 
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       tab_text_format (t, (i + 1), currow, TAB_CENTER, "%d", (i + 1));
     }
   currow++;
   tab_hline (t, TAL_1, 1, nc - 1, currow);
   currow++;
-  for (i = 0; i < kmeans->m; i++)
+  for (i = 0; i < qc->n_vars; i++)
     {
       tab_text (t, 0, currow + i, TAB_LEFT,
-               var_to_string (kmeans->variables[i]));
+               var_to_string (qc->vars[i]));
     }
 
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
-      for (j = 0; j < kmeans->m; j++)
+      for (j = 0; j < qc->n_vars; j++)
        {
          if (!initial)
            {
              tab_double (t, i + 1, j + 4, TAB_CENTER,
                          gsl_matrix_get (kmeans->centers,
                                          kmeans->group_order->data[i], j),
-                         var_get_print_format (kmeans->variables[j]));
+                         var_get_print_format (qc->vars[j]), RC_OTHER);
            }
          else
            {
              tab_double (t, i + 1, j + 4, TAB_CENTER,
                          gsl_matrix_get (kmeans->initial_centers,
                                          kmeans->group_order->data[i], j),
-                         var_get_print_format (kmeans->variables[j]));
+                         var_get_print_format (qc->vars[j]), RC_OTHER);
            }
        }
     }
@@ -457,14 +448,14 @@ quick_cluster_show_centers (struct Kmeans *kmeans, bool initial)
 
 /* Reports number of cases of each single cluster. */
 static void
-quick_cluster_show_number_cases (struct Kmeans *kmeans)
+quick_cluster_show_number_cases (struct Kmeans *kmeans, const struct qc *qc)
 {
   struct tab_table *t;
   int nc, nr;
   int i, numelem;
   long int total;
   nc = 3;
-  nr = kmeans->ngroups + 1;
+  nr = qc->ngroups + 1;
   t = tab_create (nc, nr);
   tab_headers (t, 0, nc - 1, 0, 0);
   tab_title (t, _("Number of Cases in each Cluster"));
@@ -472,7 +463,7 @@ quick_cluster_show_number_cases (struct Kmeans *kmeans)
   tab_text (t, 0, 0, TAB_LEFT, _("Cluster"));
 
   total = 0;
-  for (i = 0; i < kmeans->ngroups; i++)
+  for (i = 0; i < qc->ngroups; i++)
     {
       tab_text_format (t, 1, i, TAB_CENTER, "%d", (i + 1));
       numelem =
@@ -481,44 +472,71 @@ quick_cluster_show_number_cases (struct Kmeans *kmeans)
       total += numelem;
     }
 
-  tab_text (t, 0, kmeans->ngroups, TAB_LEFT, _("Valid"));
-  tab_text_format (t, 2, kmeans->ngroups, TAB_LEFT, "%ld", total);
+  tab_text (t, 0, qc->ngroups, TAB_LEFT, _("Valid"));
+  tab_text_format (t, 2, qc->ngroups, TAB_LEFT, "%ld", total);
   tab_submit (t);
 }
 
 /* Reports. */
 static void
-quick_cluster_show_results (struct Kmeans *kmeans)
+quick_cluster_show_results (struct Kmeans *kmeans, const struct qc *qc)
 {
-  kmeans_order_groups (kmeans);
+  kmeans_order_groups (kmeans, qc);
   /* Uncomment the line below for reporting initial centers. */
   /* quick_cluster_show_centers (kmeans, true); */
-  quick_cluster_show_centers (kmeans, false);
-  quick_cluster_show_number_cases (kmeans);
+  quick_cluster_show_centers (kmeans, false, qc);
+  quick_cluster_show_number_cases (kmeans, qc);
 }
 
 int
 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
 {
+  struct qc qc;
   struct Kmeans *kmeans;
   bool ok;
   const struct dictionary *dict = dataset_dict (ds);
-  const struct variable **variables;
-  struct casereader *cs;
-  int groups = 2;
-  int maxiter = 2;
-  size_t p;
+  qc.ngroups = 2;
+  qc.maxiter = 2;
+  qc.missing_type = MISS_LISTWISE;
+  qc.exclude = MV_ANY;
 
-  if (!parse_variables_const (lexer, dict, &variables, &p,
+  if (!parse_variables_const (lexer, dict, &qc.vars, &qc.n_vars,
                              PV_NO_DUPLICATE | PV_NUMERIC))
     {
-      msg (ME, _("Variables cannot be parsed"));
       return (CMD_FAILURE);
     }
 
-  if (lex_match (lexer, T_SLASH))
+  while (lex_token (lexer) != T_ENDCMD)
     {
-      if (lex_match_id (lexer, "CRITERIA"))
+      lex_match (lexer, T_SLASH);
+
+      if (lex_match_id (lexer, "MISSING"))
+       {
+         lex_match (lexer, T_EQUALS);
+         while (lex_token (lexer) != T_ENDCMD
+                && lex_token (lexer) != T_SLASH)
+           {
+             if (lex_match_id (lexer, "LISTWISE") || lex_match_id (lexer, "DEFAULT"))
+               {
+                 qc.missing_type = MISS_LISTWISE;
+               }
+             else if (lex_match_id (lexer, "PAIRWISE"))
+               {
+                 qc.missing_type = MISS_PAIRWISE;
+               }
+             else if (lex_match_id (lexer, "INCLUDE"))
+               {
+                 qc.exclude = MV_SYSTEM;
+               }
+             else if (lex_match_id (lexer, "EXCLUDE"))
+               {
+                 qc.exclude = MV_ANY;
+               }
+             else
+               goto error;
+           }     
+       }
+      else if (lex_match_id (lexer, "CRITERIA"))
        {
          lex_match (lexer, T_EQUALS);
          while (lex_token (lexer) != T_ENDCMD
@@ -529,7 +547,12 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
                  if (lex_force_match (lexer, T_LPAREN))
                    {
                      lex_force_int (lexer);
-                     groups = lex_integer (lexer);
+                     qc.ngroups = lex_integer (lexer);
+                     if (qc.ngroups <= 0)
+                       {
+                         lex_error (lexer, _("The number of clusters must be positive"));
+                         goto error;
+                       }
                      lex_get (lexer);
                      lex_force_match (lexer, T_RPAREN);
                    }
@@ -539,27 +562,52 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
                  if (lex_force_match (lexer, T_LPAREN))
                    {
                      lex_force_int (lexer);
-                     maxiter = lex_integer (lexer);
+                     qc.maxiter = lex_integer (lexer);
+                     if (qc.maxiter <= 0)
+                       {
+                         lex_error (lexer, _("The number of iterations must be positive"));
+                         goto error;
+                       }
                      lex_get (lexer);
                      lex_force_match (lexer, T_RPAREN);
                    }
                }
              else
-                return CMD_FAILURE;
+                goto error;
            }
        }
     }
 
-  cs = proc_open (ds);
-
-  kmeans = kmeans_create (cs, variables, p, groups, maxiter);
-
-  kmeans->wv = dict_get_weight (dict);
-  kmeans_cluster (kmeans);
-  quick_cluster_show_results (kmeans);
-  ok = proc_commit (ds);
-
-  kmeans_destroy (kmeans);
+  qc.wv = dict_get_weight (dict);
+
+  {
+    struct casereader *group;
+    struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), dict);
+
+    while (casegrouper_get_next_group (grouper, &group))
+      {
+       if ( qc.missing_type == MISS_LISTWISE )
+         {
+           group  = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
+                                                    qc.exclude,
+                                                    NULL,  NULL);
+         }
+
+       kmeans = kmeans_create (&qc);
+       kmeans_cluster (kmeans, group, &qc);
+       quick_cluster_show_results (kmeans, &qc);
+       kmeans_destroy (kmeans);
+       casereader_destroy (group);
+      }
+    ok = casegrouper_destroy (grouper);
+  }
+  ok = proc_commit (ds) && ok;
+
+  free (qc.vars);
 
   return (ok);
+
+ error:
+  free (qc.vars);
+  return CMD_FAILURE;
 }