QUICK CLUSTER: Improve error messages and coding style.
authorBen Pfaff <blp@cs.stanford.edu>
Sun, 20 Nov 2022 01:00:57 +0000 (17:00 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sun, 20 Nov 2022 01:19:00 +0000 (17:19 -0800)
src/language/stats/quick-cluster.c
tests/language/stats/quick-cluster.at

index 161b9768b1fc431d8cd71ad9169062e13431e543..4f512b475661819ca1b08a6de551595fccf21bbf 100644 (file)
@@ -54,107 +54,105 @@ enum missing_type
 
 
 struct save_trans_data
-{
-  /* A writer which contains the values (if any) to be appended to
-     each case in the active dataset   */
-  struct casewriter *writer;
-
-  /* A reader created from the writer above. */
-  struct casereader *appending_reader;
+  {
+    /* A writer which contains the values (if any) to be appended to
+       each case in the active dataset   */
+    struct casewriter *writer;
 
-  /* The indices to be used to access values in the above,
-     reader/writer  */
-  int CASE_IDX_MEMBERSHIP;
-  int CASE_IDX_DISTANCE;
+    /* A reader created from the writer above. */
+    struct casereader *appending_reader;
 
-  /* The variables created to hold the values appended to the dataset  */
-  struct variable *membership;
-  struct variable *distance;
-};
+    /* The indices to be used to access values in the above,
+       reader/writer  */
+    int membership_case_idx;
+    int distance_case_idx;
 
+    /* The variables created to hold the values appended to the dataset  */
+    struct variable *membership;
+    struct variable *distance;
+  };
 
-#define SAVE_MEMBERSHIP 0x1
-#define SAVE_DISTANCE   0x2
 
 struct qc
-{
-  struct dataset *dataset;
-  struct dictionary *dict;
+  {
+    struct dataset *dataset;
+    struct dictionary *dict;
 
-  const struct variable **vars;
-  size_t n_vars;
+    const struct variable **vars;
+    size_t n_vars;
 
-  double epsilon;               /* The convergence criterion */
+    double epsilon;               /* The convergence criterion */
 
-  int ngroups;                 /* Number of group. (Given by the user) */
-  int maxiter;                 /* Maximum iterations (Given by the user) */
-  bool print_cluster_membership; /* true => print membership */
-  bool print_initial_clusters;   /* true => print initial cluster */
-  bool no_initial;              /* true => simplified initial cluster selection */
-  bool no_update;               /* true => do not iterate  */
+    int ngroups;                       /* Number of group. (Given by the user) */
+    int maxiter;                       /* Maximum iterations (Given by the user) */
+    bool print_cluster_membership; /* true => print membership */
+    bool print_initial_clusters;   /* true => print initial cluster */
+    bool initial;             /* false => simplified initial cluster selection */
+    bool update;               /* false => do not iterate  */
 
-  const struct variable *wv;   /* Weighting variable. */
+    const struct variable *wv; /* Weighting variable. */
 
-  enum missing_type missing_type;
-  enum mv_class exclude;
+    enum missing_type missing_type;
+    enum mv_class exclude;
 
-  /* Which values are to be saved?  */
-  int save_values;
+    /* Which values are to be saved?  */
+    bool save_membership;
+    bool save_distance;
 
-  /* The name of the new variable to contain the cluster of each case.  */
-  char *var_membership;
+    /* The name of the new variable to contain the cluster of each case.  */
+    char *var_membership;
 
-  /* The name of the new variable to contain the distance of each case
-     from its cluster centre.  */
-  char *var_distance;
+    /* The name of the new variable to contain the distance of each case
+       from its cluster centre.  */
+    char *var_distance;
 
-  struct save_trans_data *save_trans_data;
-};
+    struct save_trans_data *save_trans_data;
+  };
 
 /* Holds all of the information for the functions.  int n, holds the number of
    observation and its default value is -1.  We set it in
    kmeans_recalculate_centers in first invocation. */
 struct Kmeans
-{
-  gsl_matrix *centers;         /* Centers for groups. */
-  gsl_matrix *updated_centers;
-  casenumber n;
+  {
+    gsl_matrix *centers;               /* Centers for groups. */
+    gsl_matrix *updated_centers;
+    casenumber n;
 
-  gsl_vector_long *num_elements_groups;
+    gsl_vector_long *num_elements_groups;
 
-  gsl_matrix *initial_centers; /* Initial random centers. */
-  double convergence_criteria;
-  gsl_permutation *group_order;        /* Group order for reporting. */
-};
+    gsl_matrix *initial_centers;       /* Initial random centers. */
+    double convergence_criteria;
+    gsl_permutation *group_order;      /* Group order for reporting. */
+  };
 
-static struct Kmeans *kmeans_create (const struct qc *qc);
+static struct Kmeans *kmeans_create (const struct qc *);
 
-static void kmeans_get_nearest_group (const struct Kmeans *kmeans,
-                                     struct ccase *c, const struct qc *,
+static void kmeans_get_nearest_group (const struct Kmeans *,
+                                     struct ccase *, const struct qc *,
                                      int *, double *, int *, double *);
 
-static void kmeans_order_groups (struct Kmeans *kmeans, const struct qc *);
+static void kmeans_order_groups (struct Kmeans *, const struct qc *);
 
-static void kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
+static void kmeans_cluster (struct Kmeans *, struct casereader *,
                            const struct qc *);
 
-static void quick_cluster_show_centers (struct Kmeans *kmeans, bool initial,
+static void quick_cluster_show_centers (struct Kmeans *, bool initial,
                                        const struct qc *);
 
-static void quick_cluster_show_membership (struct Kmeans *kmeans,
-                                          const struct casereader *reader,
+static void quick_cluster_show_membership (struct Kmeans *,
+                                          const struct casereader *,
                                           struct qc *);
 
-static void quick_cluster_show_number_cases (struct Kmeans *kmeans,
+static void quick_cluster_show_number_cases (struct Kmeans *,
                                             const struct qc *);
 
-static void quick_cluster_show_results (struct Kmeans *kmeans,
-                                       const struct casereader *reader,
+static void quick_cluster_show_results (struct Kmeans *,
+                                       const struct casereader *,
                                        struct qc *);
 
-int cmd_quick_cluster (struct lexer *lexer, struct dataset *ds);
+int cmd_quick_cluster (struct lexer *, struct dataset *);
 
-static void kmeans_destroy (struct Kmeans *kmeans);
+static void kmeans_destroy (struct Kmeans *);
 
 /* Creates and returns a struct of Kmeans with given casereader 'cs', parsed
    variables 'variables', number of cases 'n', number of variables 'm', number
@@ -162,14 +160,14 @@ static void kmeans_destroy (struct Kmeans *kmeans);
 static struct Kmeans *
 kmeans_create (const struct qc *qc)
 {
-  struct Kmeans *kmeans = xmalloc (sizeof (struct Kmeans));
-  kmeans->centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
-  kmeans->updated_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars);
-  kmeans->num_elements_groups = gsl_vector_long_alloc (qc->ngroups);
-  kmeans->group_order = gsl_permutation_alloc (kmeans->centers->size1);
-  kmeans->initial_centers = NULL;
-
-  return (kmeans);
+  struct Kmeans *kmeans = xmalloc (sizeof *kmeans);
+  *kmeans = (struct Kmeans) {
+    .centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars),
+    .updated_centers = gsl_matrix_alloc (qc->ngroups, qc->n_vars),
+    .num_elements_groups = gsl_vector_long_alloc (qc->ngroups),
+    .group_order = gsl_permutation_alloc (qc->ngroups),
+  };
+  return kmeans;
 }
 
 static void
@@ -189,15 +187,12 @@ kmeans_destroy (struct Kmeans *kmeans)
 static double
 diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2)
 {
-  int i,j;
   double max_diff = -INFINITY;
-  for (i = 0; i < m1->size1; ++i)
+  for (size_t i = 0; i < m1->size1; ++i)
     {
       double diff = 0;
-      for (j = 0; j < m1->size2; ++j)
-       {
-         diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j));
-       }
+      for (size_t j = 0; j < m1->size2; ++j)
+        diff += pow2 (gsl_matrix_get (m1,i,j) - gsl_matrix_get (m2,i,j));
       if (diff > max_diff)
        max_diff = diff;
     }
@@ -210,46 +205,35 @@ diff_matrix (const gsl_matrix *m1, const gsl_matrix *m2)
 static double
 matrix_mindist (const gsl_matrix *m, int *mn, int *mm)
 {
-  int i, j;
   double mindist = INFINITY;
-  for (i = 0; i < m->size1 - 1; ++i)
-    {
-      for (j = i + 1; j < m->size1; ++j)
-       {
-         int k;
-         double diff_sq = 0;
-         for (k = 0; k < m->size2; ++k)
-           {
-             diff_sq += pow2 (gsl_matrix_get (m, j, k) - gsl_matrix_get (m, i, k));
-           }
-         if (diff_sq < mindist)
-           {
-             mindist = diff_sq;
-             if (mn)
-               *mn = i;
-             if (mm)
-               *mm = j;
-           }
-       }
-    }
-
+  for (size_t i = 0; i + 1 < m->size1; ++i)
+    for (size_t j = i + 1; j < m->size1; ++j)
+      {
+        double diff_sq = 0;
+        for (size_t k = 0; k < m->size2; ++k)
+          diff_sq += pow2 (gsl_matrix_get (m, j, k) - gsl_matrix_get (m, i, k));
+        if (diff_sq < mindist)
+          {
+            mindist = diff_sq;
+            if (mn)
+              *mn = i;
+            if (mm)
+              *mm = j;
+          }
+      }
   return mindist;
 }
 
-
 /* Return the distance of C from the group whose index is WHICH */
 static double
 dist_from_case (const struct Kmeans *kmeans, const struct ccase *c,
                const struct qc *qc, int which)
 {
-  int j;
   double dist = 0;
-  for (j = 0; j < qc->n_vars; j++)
+  for (size_t j = 0; j < qc->n_vars; j++)
     {
       const union value *val = case_data (c, qc->vars[j]);
-      if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
-       NOT_REACHED ();
-
+      assert (!(var_is_value_missing (qc->vars[j], val) & qc->exclude));
       dist += pow2 (gsl_matrix_get (kmeans->centers, which, j) - val->f);
     }
 
@@ -260,46 +244,38 @@ dist_from_case (const struct Kmeans *kmeans, const struct ccase *c,
 static double
 min_dist_from (const struct Kmeans *kmeans, const struct qc *qc, int which)
 {
-  int j, i;
-
-  double mindist = INFINITY;
-  for (i = 0; i < qc->ngroups; i++)
+   double mindist = INFINITY;
+  for (size_t i = 0; i < qc->ngroups; i++)
     {
       if (i == which)
        continue;
 
       double dist = 0;
-      for (j = 0; j < qc->n_vars; j++)
-       {
-         dist += pow2 (gsl_matrix_get (kmeans->centers, i, j)
-                       - gsl_matrix_get (kmeans->centers, which, j));
-       }
+      for (size_t j = 0; j < qc->n_vars; j++)
+        dist += pow2 (gsl_matrix_get (kmeans->centers, i, j)
+                      - gsl_matrix_get (kmeans->centers, which, j));
 
       if (dist < mindist)
-       {
-         mindist = dist;
-       }
+        mindist = dist;
     }
 
   return mindist;
 }
 
-
-
 /* Calculate the initial cluster centers. */
 static void
 kmeans_initial_centers (struct Kmeans *kmeans,
                        const struct casereader *reader,
                        const struct qc *qc)
 {
-  struct ccase *c;
-  int nc = 0, j;
+  int nc = 0;
 
   struct casereader *cs = casereader_clone (reader);
+  struct ccase *c;
   for (; (c = casereader_read (cs)) != NULL; case_unref (c))
     {
       bool missing = false;
-      for (j = 0; j < qc->n_vars; ++j)
+      for (size_t j = 0; j < qc->n_vars; ++j)
        {
          const union value *val = case_data (c, qc->vars[j]);
          if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
@@ -311,21 +287,19 @@ kmeans_initial_centers (struct Kmeans *kmeans,
          if (nc < qc->ngroups)
            gsl_matrix_set (kmeans->centers, nc, j, val->f);
        }
-
       if (missing)
        continue;
 
       if (nc++ < qc->ngroups)
        continue;
 
-      if (!qc->no_initial)
+      if (qc->initial)
        {
-         int mq, mp;
-         double delta;
-
          int mn, mm;
          double m = matrix_mindist (kmeans->centers, &mn, &mm);
 
+         int mq, mp;
+         double delta;
          kmeans_get_nearest_group (kmeans, c, qc, &mq, &delta, &mp, NULL);
          if (delta > m)
            /* If the distance between C and the nearest group, is greater than the distance
@@ -336,7 +310,7 @@ kmeans_initial_centers (struct Kmeans *kmeans,
              int which = (dist_from_case (kmeans, c, qc, mn)
                           > dist_from_case (kmeans, c, qc, mm)) ? mm : mn;
 
-             for (j = 0; j < qc->n_vars; ++j)
+             for (size_t j = 0; j < qc->n_vars; ++j)
                {
                  const union value *val = case_data (c, qc->vars[j]);
                  gsl_matrix_set (kmeans->centers, which, j, val->f);
@@ -348,7 +322,7 @@ kmeans_initial_centers (struct Kmeans *kmeans,
               nearest group (MQ) and any other group, then replace
               MQ with C.  */
            {
-             for (j = 0; j < qc->n_vars; ++j)
+             for (size_t j = 0; j < qc->n_vars; ++j)
                {
                  const union value *val = case_data (c, qc->vars[j]);
                  gsl_matrix_set (kmeans->centers, mq, j, val->f);
@@ -367,7 +341,6 @@ kmeans_initial_centers (struct Kmeans *kmeans,
   gsl_matrix_memcpy (kmeans->initial_centers, kmeans->centers);
 }
 
-
 /* Return the index of the group which is nearest to the case C */
 static void
 kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c,
@@ -376,13 +349,12 @@ kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c,
 {
   int result0 = -1;
   int result1 = -1;
-  int i, j;
   double mindist0 = INFINITY;
   double mindist1 = INFINITY;
-  for (i = 0; i < qc->ngroups; i++)
+  for (size_t i = 0; i < qc->ngroups; i++)
     {
       double dist = 0;
-      for (j = 0; j < qc->n_vars; j++)
+      for (size_t j = 0; j < qc->n_vars; j++)
        {
          const union value *val = case_data (c, qc->vars[j]);
          if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
@@ -412,7 +384,6 @@ kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c,
   if (g_q)
     *g_q = result0;
 
-
   if (delta_p)
     *delta_p = mindist1;
 
@@ -420,8 +391,6 @@ kmeans_get_nearest_group (const struct Kmeans *kmeans, struct ccase *c,
     *g_p = result1;
 }
 
-
-
 static void
 kmeans_order_groups (struct Kmeans *kmeans, const struct qc *qc)
 {
@@ -437,40 +406,33 @@ static void
 kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
                const struct qc *qc)
 {
-  int j;
-
   kmeans_initial_centers (kmeans, reader, qc);
 
   gsl_matrix_memcpy (kmeans->updated_centers, kmeans->centers);
-
-
-  for (int xx = 0 ; xx < qc->maxiter ; ++xx)
+  for (int xx = 0; xx < qc->maxiter; ++xx)
     {
       gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
 
       kmeans->n = 0;
-      if (!qc->no_update)
+      if (qc->update)
        {
          struct casereader *r = casereader_clone (reader);
          struct ccase *c;
          for (; (c = casereader_read (r)) != NULL; case_unref (c))
            {
-             int group = -1;
-             int g;
              bool missing = false;
-
-             for (j = 0; j < qc->n_vars; j++)
+             for (size_t j = 0; j < qc->n_vars; j++)
                {
                  const union value *val = case_data (c, qc->vars[j]);
                  if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
                    missing = true;
                }
-
              if (missing)
                continue;
 
              double mindist = INFINITY;
-             for (g = 0; g < qc->ngroups; ++g)
+             int group = -1;
+             for (size_t g = 0; g < qc->ngroups; ++g)
                {
                  double d = dist_from_case (kmeans, c, qc, g);
 
@@ -485,7 +447,7 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
              *n += qc->wv ? case_num (c, qc->wv) : 1.0;
              kmeans->n++;
 
-             for (j = 0; j < qc->n_vars; ++j)
+             for (size_t j = 0; j < qc->n_vars; ++j)
                {
                  const union value *val = case_data (c, qc->vars[j]);
                  if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
@@ -498,68 +460,57 @@ kmeans_cluster (struct Kmeans *kmeans, struct casereader *reader,
          casereader_destroy (r);
        }
 
-      int g;
-
       /* Divide the cluster sums by the number of items in each cluster */
-      for (g = 0; g < qc->ngroups; ++g)
-       {
-         for (j = 0; j < qc->n_vars; ++j)
-           {
-             long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
-             double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
-             *x /= n + 1;  // Plus 1 for the initial centers
-           }
-       }
-
-
+      for (size_t g = 0; g < qc->ngroups; ++g)
+        for (size_t j = 0; j < qc->n_vars; ++j)
+          {
+            long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
+            double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
+            *x /= n + 1;  // Plus 1 for the initial centers
+          }
       gsl_matrix_memcpy (kmeans->centers, kmeans->updated_centers);
 
-      {
-       kmeans->n = 0;
-       /* Step 3 */
-       gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
-       gsl_matrix_set_all (kmeans->updated_centers, 0.0);
-       struct ccase *c;
-       struct casereader *cs = casereader_clone (reader);
-       for (; (c = casereader_read (cs)) != NULL; case_unref (c))
-         {
-           int group = -1;
-           kmeans_get_nearest_group (kmeans, c, qc, &group, NULL, NULL, NULL);
-
-           for (j = 0; j < qc->n_vars; ++j)
-             {
-               const union value *val = case_data (c, qc->vars[j]);
-               if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
-                 continue;
-
-               double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
-               *x += val->f;
-             }
-
-           long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
-           *n += qc->wv ? case_num (c, qc->wv) : 1.0;
-           kmeans->n++;
-         }
-       casereader_destroy (cs);
-
-
-       /* Divide the cluster sums by the number of items in each cluster */
-       for (g = 0; g < qc->ngroups; ++g)
-         {
-           for (j = 0; j < qc->n_vars; ++j)
-             {
-               long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
-               double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
-               *x /= n ;
-             }
-         }
-
-       double d = diff_matrix (kmeans->updated_centers, kmeans->centers);
-       if (d < kmeans->convergence_criteria)
-         break;
-      }
+      kmeans->n = 0;
+      /* Step 3 */
+      gsl_vector_long_set_all (kmeans->num_elements_groups, 0.0);
+      gsl_matrix_set_all (kmeans->updated_centers, 0.0);
+      struct ccase *c;
+      struct casereader *cs = casereader_clone (reader);
+      for (; (c = casereader_read (cs)) != NULL; case_unref (c))
+        {
+          int group = -1;
+          kmeans_get_nearest_group (kmeans, c, qc, &group, NULL, NULL, NULL);
+
+          for (size_t j = 0; j < qc->n_vars; ++j)
+            {
+              const union value *val = case_data (c, qc->vars[j]);
+              if (var_is_value_missing (qc->vars[j], val) & qc->exclude)
+                continue;
+
+              double *x = gsl_matrix_ptr (kmeans->updated_centers, group, j);
+              *x += val->f;
+            }
+
+          long *n = gsl_vector_long_ptr (kmeans->num_elements_groups, group);
+          *n += qc->wv ? case_num (c, qc->wv) : 1.0;
+          kmeans->n++;
+        }
+      casereader_destroy (cs);
 
-      if (qc->no_update)
+      /* Divide the cluster sums by the number of items in each cluster */
+      for (size_t g = 0; g < qc->ngroups; ++g)
+        for (size_t j = 0; j < qc->n_vars; ++j)
+          {
+            long n = gsl_vector_long_get (kmeans->num_elements_groups, g);
+            double *x = gsl_matrix_ptr (kmeans->updated_centers, g, j);
+            *x /= n;
+          }
+
+      double d = diff_matrix (kmeans->updated_centers, kmeans->centers);
+      if (d < kmeans->convergence_criteria)
+        break;
+
+      if (!qc->update)
        break;
     }
 }
@@ -620,18 +571,17 @@ save_trans_func (void *aux, struct ccase **c, casenumber x UNUSED)
 
   *c = case_unshare (*c);
 
-  if (std->CASE_IDX_MEMBERSHIP >= 0)
-    *case_num_rw (*c, std->membership) = case_num_idx (ca, std->CASE_IDX_MEMBERSHIP);
+  if (std->membership_case_idx >= 0)
+    *case_num_rw (*c, std->membership) = case_num_idx (ca, std->membership_case_idx);
 
-  if (std->CASE_IDX_DISTANCE >= 0)
-    *case_num_rw (*c, std->distance) = case_num_idx (ca, std->CASE_IDX_DISTANCE);
+  if (std->distance_case_idx >= 0)
+    *case_num_rw (*c, std->distance) = case_num_idx (ca, std->distance_case_idx);
 
   case_unref (ca);
 
   return TRNS_CONTINUE;
 }
 
-
 /* Free the resources of the transformation.  */
 static bool
 save_trans_destroy (void *aux)
@@ -642,10 +592,8 @@ save_trans_destroy (void *aux)
   return true;
 }
 
-
-/* Reports cluster membership for each case, and is requested
-saves the membership and the distance of the case from the cluster
-centre.  */
+/* Reports cluster membership for each case, and is requested saves the
+   membership and the distance of the case from the cluster centre.  */
 static void
 quick_cluster_show_membership (struct Kmeans *kmeans,
                               const struct casereader *reader,
@@ -670,28 +618,32 @@ quick_cluster_show_membership (struct Kmeans *kmeans,
   gsl_permutation_inverse (ip, kmeans->group_order);
 
   struct caseproto *proto = caseproto_create ();
-  if (qc->save_values)
+  if (qc->save_membership || qc->save_distance)
     {
       /* Prepare data which may potentially be used in a
         transformation appending new variables to the active
         dataset.  */
-      qc->save_trans_data = xzalloc (sizeof *qc->save_trans_data);
-      qc->save_trans_data->CASE_IDX_MEMBERSHIP = -1;
-      qc->save_trans_data->CASE_IDX_DISTANCE = -1;
-      qc->save_trans_data->writer = autopaging_writer_create (proto);
-
       int idx = 0;
-      if (qc->save_values & SAVE_MEMBERSHIP)
+      int membership_case_idx = -1;
+      if (qc->save_membership)
        {
          proto = caseproto_add_width (proto, 0);
-         qc->save_trans_data->CASE_IDX_MEMBERSHIP = idx++;
+         membership_case_idx = idx++;
        }
 
-      if (qc->save_values & SAVE_DISTANCE)
+      int distance_case_idx = -1;
+      if (qc->save_distance)
        {
          proto = caseproto_add_width (proto, 0);
-         qc->save_trans_data->CASE_IDX_DISTANCE = idx++;
+         distance_case_idx = idx++;
        }
+
+      qc->save_trans_data = xmalloc (sizeof *qc->save_trans_data);
+      *qc->save_trans_data = (struct save_trans_data) {
+        .membership_case_idx = membership_case_idx,
+        .distance_case_idx = distance_case_idx,
+        .writer = autopaging_writer_create (proto),
+      };
     }
 
   struct casereader *cs = casereader_clone (reader);
@@ -704,18 +656,18 @@ quick_cluster_show_membership (struct Kmeans *kmeans,
       int cluster = ip->data[clust];
 
       if (qc->save_trans_data)
-      {
-       /* Calculate the membership and distance values.  */
-       struct ccase *outc = case_create (proto);
-       if (qc->save_values & SAVE_MEMBERSHIP)
-         *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_MEMBERSHIP) = cluster + 1;
+        {
+          /* Calculate the membership and distance values.  */
+          struct ccase *outc = case_create (proto);
+          if (qc->save_membership)
+            *case_num_rw_idx (outc, qc->save_trans_data->membership_case_idx) = cluster + 1;
 
-       if (qc->save_values & SAVE_DISTANCE)
-         *case_num_rw_idx (outc, qc->save_trans_data->CASE_IDX_DISTANCE)
-           = sqrt (dist_from_case (kmeans, c, qc, clust));
+          if (qc->save_distance)
+            *case_num_rw_idx (outc, qc->save_trans_data->distance_case_idx)
+              = sqrt (dist_from_case (kmeans, c, qc, clust));
 
-       casewriter_write (qc->save_trans_data->writer, outc);
-      }
+          casewriter_write (qc->save_trans_data->writer, outc);
+        }
 
       if (qc->print_cluster_membership)
        {
@@ -790,9 +742,7 @@ quick_cluster_parse (struct lexer *lexer, struct qc *qc)
 {
   if (!parse_variables_const (lexer, qc->dict, &qc->vars, &qc->n_vars,
                              PV_NO_DUPLICATE | PV_NUMERIC))
-    {
-      return (CMD_FAILURE);
-    }
+    return false;
 
   while (lex_token (lexer) != T_ENDCMD)
     {
@@ -806,24 +756,17 @@ quick_cluster_parse (struct lexer *lexer, struct qc *qc)
            {
              if (lex_match_id (lexer, "LISTWISE")
                  || lex_match_id (lexer, "DEFAULT"))
-               {
-                 qc->missing_type = MISS_LISTWISE;
-               }
+                qc->missing_type = MISS_LISTWISE;
              else if (lex_match_id (lexer, "PAIRWISE"))
-               {
-                 qc->missing_type = MISS_PAIRWISE;
-               }
+                qc->missing_type = MISS_PAIRWISE;
              else if (lex_match_id (lexer, "INCLUDE"))
-               {
-                 qc->exclude = MV_SYSTEM;
-               }
+                qc->exclude = MV_SYSTEM;
              else if (lex_match_id (lexer, "EXCLUDE"))
-               {
-                 qc->exclude = MV_ANY;
-               }
+                qc->exclude = MV_ANY;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "LISTWISE", "DEFAULT",
+                                       "PAIRWISE", "INCLUDE", "EXCLUDE");
                  return false;
                }
            }
@@ -840,7 +783,7 @@ quick_cluster_parse (struct lexer *lexer, struct qc *qc)
                qc->print_initial_clusters = true;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "CLUSTER", "INITIAL");
                  return false;
                }
            }
@@ -853,7 +796,7 @@ quick_cluster_parse (struct lexer *lexer, struct qc *qc)
            {
              if (lex_match_id (lexer, "CLUSTER"))
                {
-                 qc->save_values |= SAVE_MEMBERSHIP;
+                 qc->save_membership = true;
                  if (lex_match (lexer, T_LPAREN))
                    {
                      if (!lex_force_id (lexer))
@@ -879,7 +822,7 @@ quick_cluster_parse (struct lexer *lexer, struct qc *qc)
                }
              else if (lex_match_id (lexer, "DISTANCE"))
                {
-                 qc->save_values |= SAVE_DISTANCE;
+                 qc->save_distance = true;
                  if (lex_match (lexer, T_LPAREN))
                    {
                      if (!lex_force_id (lexer))
@@ -918,55 +861,50 @@ quick_cluster_parse (struct lexer *lexer, struct qc *qc)
            {
              if (lex_match_id (lexer, "CLUSTERS"))
                {
-                 if (lex_force_match (lexer, T_LPAREN) &&
-                     lex_force_int_range (lexer, "CLUSTERS", 1, INT_MAX))
-                   {
-                     qc->ngroups = lex_integer (lexer);
-                     lex_get (lexer);
-                     if (!lex_force_match (lexer, T_RPAREN))
-                       return false;
-                   }
+                 if (!lex_force_match (lexer, T_LPAREN)
+                     || !lex_force_int_range (lexer, "CLUSTERS", 1, INT_MAX))
+                    return false;
+                  qc->ngroups = lex_integer (lexer);
+                  lex_get (lexer);
+                  if (!lex_force_match (lexer, T_RPAREN))
+                    return false;
                }
              else if (lex_match_id (lexer, "CONVERGE"))
                {
-                 if (lex_force_match (lexer, T_LPAREN) &&
-                     lex_force_num_range_open (lexer, "CONVERGE", 0, DBL_MAX))
-                   {
-                     qc->epsilon = lex_number (lexer);
-                     lex_get (lexer);
-                     if (!lex_force_match (lexer, T_RPAREN))
-                       return false;
-                   }
+                 if (!lex_force_match (lexer, T_LPAREN)
+                     || !lex_force_num_range_open (lexer, "CONVERGE",
+                                                    0, DBL_MAX))
+                    return false;
+                  qc->epsilon = lex_number (lexer);
+                  lex_get (lexer);
+                  if (!lex_force_match (lexer, T_RPAREN))
+                    return false;
                }
              else if (lex_match_id (lexer, "MXITER"))
                {
-                 if (lex_force_match (lexer, T_LPAREN) &&
-                     lex_force_int_range (lexer, "MXITER", 1, INT_MAX))
-                   {
-                     qc->maxiter = lex_integer (lexer);
-                     lex_get (lexer);
-                     if (!lex_force_match (lexer, T_RPAREN))
-                       return false;
-                   }
+                 if (!lex_force_match (lexer, T_LPAREN)
+                     || !lex_force_int_range (lexer, "MXITER", 1, INT_MAX))
+                    return false;
+                  qc->maxiter = lex_integer (lexer);
+                  lex_get (lexer);
+                  if (!lex_force_match (lexer, T_RPAREN))
+                    return false;
                }
              else if (lex_match_id (lexer, "NOINITIAL"))
-               {
-                 qc->no_initial = true;
-               }
+                qc->initial = false;
              else if (lex_match_id (lexer, "NOUPDATE"))
-               {
-                 qc->no_update = true;
-               }
+                qc->update = false;
              else
                {
-                 lex_error (lexer, NULL);
+                 lex_error_expecting (lexer, "CLUSTERS", "CONVERGE", "MXITER",
+                                       "NOINITIAL", "NOUPDATE");
                  return false;
                }
            }
        }
       else
         {
-          lex_error (lexer, NULL);
+          lex_error_expecting (lexer, "MISSING", "PRINT", "SAVE", "CRITERIA");
           return false;
         }
     }
@@ -976,57 +914,49 @@ quick_cluster_parse (struct lexer *lexer, struct qc *qc)
 int
 cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
 {
-  struct qc qc;
-  struct Kmeans *kmeans;
-  bool ok;
-  memset (&qc, 0, sizeof qc);
-  qc.dataset = ds;
-  qc.dict =  dataset_dict (ds);
-  qc.ngroups = 2;
-  qc.maxiter = 10;
-  qc.epsilon = DBL_EPSILON;
-  qc.missing_type = MISS_LISTWISE;
-  qc.exclude = MV_ANY;
-
+  struct qc qc = {
+    .dataset = ds,
+    .dict = dataset_dict (ds),
+    .ngroups = 2,
+    .maxiter = 10,
+    .epsilon = DBL_EPSILON,
+    .missing_type = MISS_LISTWISE,
+    .exclude = MV_ANY,
+    .initial = true,
+    .update = true,
+  };
 
   if (!quick_cluster_parse (lexer, &qc))
     goto error;
 
   qc.wv = dict_get_weight (qc.dict);
 
-  {
-    struct casereader *group;
-    struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), qc.dict);
-
-    while (casegrouper_get_next_group (grouper, &group))
-      {
-       if (qc.missing_type == MISS_LISTWISE)
-         {
-           group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
-                                                      qc.exclude,
-                                                      NULL,  NULL);
-         }
-
-       kmeans = kmeans_create (&qc);
-       kmeans_cluster (kmeans, group, &qc);
-       quick_cluster_show_results (kmeans, group, &qc);
-       kmeans_destroy (kmeans);
-       casereader_destroy (group);
-      }
-    ok = casegrouper_destroy (grouper);
-  }
+  struct casegrouper *grouper = casegrouper_create_splits (proc_open (ds), qc.dict);
+  struct casereader *group;
+  while (casegrouper_get_next_group (grouper, &group))
+    {
+      if (qc.missing_type == MISS_LISTWISE)
+        group = casereader_create_filter_missing (group, qc.vars, qc.n_vars,
+                                                  qc.exclude, NULL, NULL);
+
+      struct Kmeans *kmeans = kmeans_create (&qc);
+      kmeans_cluster (kmeans, group, &qc);
+      quick_cluster_show_results (kmeans, group, &qc);
+      kmeans_destroy (kmeans);
+      casereader_destroy (group);
+    }
+  bool ok = casegrouper_destroy (grouper);
   ok = proc_commit (ds) && ok;
 
-
   /* If requested, set a transformation to append the cluster and
      distance values to the current dataset.  */
   if (qc.save_trans_data)
     {
       struct save_trans_data *std = qc.save_trans_data;
+
       std->appending_reader = casewriter_make_reader (std->writer);
-      std->writer = NULL;
 
-      if (qc.save_values & SAVE_MEMBERSHIP)
+      if (qc.save_membership)
        {
          /* Invent a variable name if necessary.  */
          int idx = 0;
@@ -1048,7 +978,7 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
          std->membership = dict_create_var_assert (qc.dict, qc.var_membership, 0);
        }
 
-      if (qc.save_values & SAVE_DISTANCE)
+      if (qc.save_distance)
        {
          /* Invent a variable name if necessary.  */
          int idx = 0;
@@ -1081,7 +1011,7 @@ cmd_quick_cluster (struct lexer *lexer, struct dataset *ds)
   free (qc.var_distance);
   free (qc.var_membership);
   free (qc.vars);
-  return (ok);
+  return ok;
 
  error:
   free (qc.var_distance);
index 62d9facf32dc705a1001505d785f2265733b2ac4..12f4b2ac43b6dc0a6689e38580a4066f8156c5f2 100644 (file)
@@ -270,7 +270,7 @@ end data.
 QUICK CLUSTER x y /UNSUPPORTED.
 ])
 AT_CHECK([pspp -O format=csv quick-cluster.sps], [1], [dnl
-"quick-cluster.sps:7.20-7.30: error: QUICK CLUSTER: Syntax error.
+"quick-cluster.sps:7.20-7.30: error: QUICK CLUSTER: Syntax error expecting MISSING, PRINT, SAVE, or CRITERIA.
     7 | QUICK CLUSTER x y /UNSUPPORTED.
       |                    ^~~~~~~~~~~"
 ])
@@ -582,8 +582,105 @@ quick cluster x
   .
 
 list.
-])
+]) dnl "
 
 AT_CHECK([pspp -O format=csv badn.sps], [1], [ignore])
 
 AT_CLEANUP
+
+AT_SETUP([QUICK CLUSTER syntax errors])
+AT_DATA([quick-cluster.sps], [dnl
+DATA LIST LIST NOTABLE /x y.
+QUICK CLUSTER **.
+QUICK CLUSTER x/MISSING=**.
+QUICK CLUSTER x/PRINT=**.
+QUICK CLUSTER x/SAVE=CLUSTER(**).
+QUICK CLUSTER x/SAVE=CLUSTER(x).
+QUICK CLUSTER x/SAVE=CLUSTER(c **).
+QUICK CLUSTER x/SAVE=DISTANCE(**).
+QUICK CLUSTER x/SAVE=DISTANCE(x).
+QUICK CLUSTER x/SAVE=DISTANCE(d **).
+QUICK CLUSTER x/SAVE=**.
+QUICK CLUSTER x/CRITERIA=CLUSTERS **.
+QUICK CLUSTER x/CRITERIA=CLUSTERS(**).
+QUICK CLUSTER x/CRITERIA=CLUSTERS(5 **).
+QUICK CLUSTER x/CRITERIA=CONVERGE **.
+QUICK CLUSTER x/CRITERIA=CONVERGE(**).
+QUICK CLUSTER x/CRITERIA=CONVERGE(5 **).
+QUICK CLUSTER x/CRITERIA=**.
+QUICK CLUSTER x/ **.
+])
+AT_CHECK([pspp -O format=csv quick-cluster.sps], [1], [dnl
+"quick-cluster.sps:2.15-2.16: error: QUICK CLUSTER: Syntax error expecting variable name.
+    2 | QUICK CLUSTER **.
+      |               ^~"
+
+"quick-cluster.sps:3.25-3.26: error: QUICK CLUSTER: Syntax error expecting LISTWISE, DEFAULT, PAIRWISE, INCLUDE, or EXCLUDE.
+    3 | QUICK CLUSTER x/MISSING=**.
+      |                         ^~"
+
+"quick-cluster.sps:4.23-4.24: error: QUICK CLUSTER: Syntax error expecting CLUSTER or INITIAL.
+    4 | QUICK CLUSTER x/PRINT=**.
+      |                       ^~"
+
+"quick-cluster.sps:5.30-5.31: error: QUICK CLUSTER: Syntax error expecting identifier.
+    5 | QUICK CLUSTER x/SAVE=CLUSTER(**).
+      |                              ^~"
+
+"quick-cluster.sps:6.30: error: QUICK CLUSTER: A variable called `x' already exists.
+    6 | QUICK CLUSTER x/SAVE=CLUSTER(x).
+      |                              ^"
+
+"quick-cluster.sps:7.32-7.33: error: QUICK CLUSTER: Syntax error expecting `@:}@'.
+    7 | QUICK CLUSTER x/SAVE=CLUSTER(c **).
+      |                                ^~"
+
+"quick-cluster.sps:8.31-8.32: error: QUICK CLUSTER: Syntax error expecting identifier.
+    8 | QUICK CLUSTER x/SAVE=DISTANCE(**).
+      |                               ^~"
+
+"quick-cluster.sps:9.31: error: QUICK CLUSTER: A variable called `x' already exists.
+    9 | QUICK CLUSTER x/SAVE=DISTANCE(x).
+      |                               ^"
+
+"quick-cluster.sps:10.33-10.34: error: QUICK CLUSTER: Syntax error expecting `@:}@'.
+   10 | QUICK CLUSTER x/SAVE=DISTANCE(d **).
+      |                                 ^~"
+
+"quick-cluster.sps:11.22-11.23: error: QUICK CLUSTER: Syntax error expecting CLUSTER or DISTANCE.
+   11 | QUICK CLUSTER x/SAVE=**.
+      |                      ^~"
+
+"quick-cluster.sps:12.35-12.36: error: QUICK CLUSTER: Syntax error expecting `('.
+   12 | QUICK CLUSTER x/CRITERIA=CLUSTERS **.
+      |                                   ^~"
+
+"quick-cluster.sps:13.35-13.36: error: QUICK CLUSTER: Syntax error expecting positive integer for CLUSTERS.
+   13 | QUICK CLUSTER x/CRITERIA=CLUSTERS(**).
+      |                                   ^~"
+
+"quick-cluster.sps:14.37-14.38: error: QUICK CLUSTER: Syntax error expecting `)'.
+   14 | QUICK CLUSTER x/CRITERIA=CLUSTERS(5 **).
+      |                                     ^~"
+
+"quick-cluster.sps:15.35-15.36: error: QUICK CLUSTER: Syntax error expecting `('.
+   15 | QUICK CLUSTER x/CRITERIA=CONVERGE **.
+      |                                   ^~"
+
+"quick-cluster.sps:16.35-16.36: error: QUICK CLUSTER: Syntax error expecting positive number for CONVERGE.
+   16 | QUICK CLUSTER x/CRITERIA=CONVERGE(**).
+      |                                   ^~"
+
+"quick-cluster.sps:17.37-17.38: error: QUICK CLUSTER: Syntax error expecting `)'.
+   17 | QUICK CLUSTER x/CRITERIA=CONVERGE(5 **).
+      |                                     ^~"
+
+"quick-cluster.sps:18.26-18.27: error: QUICK CLUSTER: Syntax error expecting CLUSTERS, CONVERGE, MXITER, NOINITIAL, or NOUPDATE.
+   18 | QUICK CLUSTER x/CRITERIA=**.
+      |                          ^~"
+
+"quick-cluster.sps:19.18-19.19: error: QUICK CLUSTER: Syntax error expecting MISSING, PRINT, SAVE, or CRITERIA.
+   19 | QUICK CLUSTER x/ **.
+      |                  ^~"
+])
+AT_CLEANUP
\ No newline at end of file