give CTABLES its own freq structure
[pspp] / src / language / stats / ctables.c
index 563abdf7f1d99951db13b7522e1d65d9cb0f7b50..7a18ac5f134e4768f1cc6fdb939d5c16b4f95ee6 100644 (file)
@@ -24,7 +24,7 @@
 #include "language/lexer/format-parser.h"
 #include "language/lexer/lexer.h"
 #include "language/lexer/variable-parser.h"
-#include "language/stats/freq.h"
+#include "libpspp/array.h"
 #include "libpspp/assertion.h"
 #include "libpspp/hmap.h"
 #include "libpspp/message.h"
 
 enum ctables_vlabel
   {
-    CTVL_DEFAULT = SETTINGS_VALUE_SHOW_DEFAULT,
+    CTVL_NONE = SETTINGS_VALUE_SHOW_DEFAULT,
     CTVL_NAME = SETTINGS_VALUE_SHOW_VALUE,
     CTVL_LABEL = SETTINGS_VALUE_SHOW_LABEL,
     CTVL_BOTH = SETTINGS_VALUE_SHOW_BOTH,
-    CTVL_NONE,
   };
-static void UNUSED
-ctables_vlabel_unique (enum ctables_vlabel vlabel)
-{
-  /* This ensures that all of the values are unique. */
-  switch (vlabel)
-    {
-    case CTVL_DEFAULT:
-    case CTVL_NAME:
-    case CTVL_LABEL:
-    case CTVL_BOTH:
-    case CTVL_NONE:
-      abort ();
-    }
-}
 
 /* XXX:
    - unweighted summaries (U*)
@@ -255,6 +240,9 @@ struct ctables_table
 
     struct ctables_chisq *chisq;
     struct ctables_pairwise *pairwise;
+
+    struct ctables_freqtab **fts;
+    size_t n_fts;
   };
 
 struct ctables_var
@@ -1350,20 +1338,51 @@ enumerate_fts (const struct ctables_axis *a)
   NOT_REACHED ();
 }
 
+struct ctables_freq
+  {
+    struct hmap_node node;      /* Element in hash table. */
+    double count;
+    union value values[];      /* The value. */
+  };
+
+static struct ctables_freq *
+ctables_freq_allocate (size_t n_values)
+{
+  struct ctables_freq *f;
+  return xmalloc (sizeof *f + n_values * sizeof *f->values);
+}
+
 struct ctables_freqtab
   {
     struct var_array vars;
-    struct hmap data;           /* Contains "struct freq"s. */
+    struct hmap data;           /* Contains "struct ctables_freq"s. */
+    struct ctables_freq **sorted;
   };
 
+static int
+ctables_freq_compare_3way (const void *a_, const void *b_, const void *vars_)
+{
+  const struct var_array *vars = vars_;
+  struct ctables_freq *const *a = a_;
+  struct ctables_freq *const *b = b_;
+
+  for (size_t i = 0; i < vars->n; i++)
+    {
+      int cmp = value_compare_3way (&(*a)->values[i], &(*b)->values[i],
+                                    var_get_width (vars->vars[i]));
+      if (cmp)
+        return cmp;
+    }
+  return 0;
+}
+
 static bool
 ctables_execute (struct dataset *ds, struct ctables *ct)
 {
-  struct ctables_freqtab **fts = NULL;
-  size_t n_fts = 0;
-  size_t allocated_fts = 0;
   for (size_t i = 0; i < ct->n_tables; i++)
     {
+      size_t allocated_fts = 0;
+
       struct ctables_table *t = &ct->tables[i];
       struct var_array2 vaa = enumerate_fts (t->axes[PIVOT_AXIS_ROW]);
       vaa = nest_fts (vaa, enumerate_fts (t->axes[PIVOT_AXIS_COLUMN]));
@@ -1381,15 +1400,15 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
 
       for (size_t j = 0; j < vaa.n; j++)
         {
-          if (n_fts >= allocated_fts)
-            fts = x2nrealloc (fts, &allocated_fts, sizeof *fts);
-
           struct ctables_freqtab *ft = xmalloc (sizeof *ft);
           *ft = (struct ctables_freqtab) {
             .vars = vaa.vas[j],
             .data = HMAP_INITIALIZER (ft->data),
           };
-          fts[n_fts++] = ft;
+
+          if (t->n_fts >= allocated_fts)
+            t->fts = x2nrealloc (t->fts, &allocated_fts, sizeof *t->fts);
+          t->fts[t->n_fts++] = ft;
         }
 
       free (vaa.vas);
@@ -1404,69 +1423,161 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
     {
       double weight = dict_get_case_weight (dataset_dict (ds), c,
                                             &warn_on_invalid);
-      for (size_t i = 0; i < n_fts; i++)
-        {
-          struct ctables_freqtab *ft = fts[i];
 
-          size_t hash = 0;
+      for (size_t i = 0; i < ct->n_tables; i++)
+        {
+          struct ctables_table *t = &ct->tables[i];
 
-          for (size_t j = 0; j < ft->vars.n; j++)
+          for (size_t j = 0; j < t->n_fts; j++)
             {
-              const struct variable *var = ft->vars.vars[j];
-              hash = value_hash (case_data (c, var), var_get_width (var), hash);
-            }
+              struct ctables_freqtab *ft = t->fts[j];
 
-          struct freq *f;
-          HMAP_FOR_EACH_WITH_HASH (f, struct freq, node, hash, &ft->data)
-            {
-              for (size_t j = 0; j < ft->vars.n; j++)
+              for (size_t k = 0; k < ft->vars.n; k++)
+                {
+                  const struct variable *var = ft->vars.vars[k];
+                  switch (var_is_value_missing (var, case_data (c, var)))
+                    {
+                    case MV_SYSTEM:
+                      goto next_ft;
+
+                    case MV_USER:
+                      if (!t->categories[var_get_dict_index (var)]
+                          || !t->categories[var_get_dict_index (var)]->include_missing)
+                        goto next_ft;
+                      break;
+                    }
+                }
+              size_t hash = 0;
+              for (size_t k = 0; k < ft->vars.n; k++)
+                {
+                  const struct variable *var = ft->vars.vars[k];
+                  hash = value_hash (case_data (c, var), var_get_width (var), hash);
+                }
+
+              struct ctables_freq *f;
+              HMAP_FOR_EACH_WITH_HASH (f, struct ctables_freq, node, hash, &ft->data)
                 {
-                  const struct variable *var = ft->vars.vars[j];
-                  if (!value_equal (case_data (c, var), &f->values[j],
-                                    var_get_width (var)))
-                    goto next_hash_node;
+                  for (size_t k = 0; k < ft->vars.n; k++)
+                    {
+                      const struct variable *var = ft->vars.vars[k];
+                      if (!value_equal (case_data (c, var), &f->values[k],
+                                        var_get_width (var)))
+                        goto next_hash_node;
+                    }
+
+                  f->count += weight;
+                  goto next_ft;
+
+                next_hash_node: ;
                 }
 
-              f->count += weight;
-              goto next_ft;
+              f = ctables_freq_allocate (ft->vars.n);
+              f->count = weight;
+              for (size_t k = 0; k < ft->vars.n; k++)
+                {
+                  const struct variable *var = ft->vars.vars[k];
+                  value_clone (&f->values[k], case_data (c, var),
+                               var_get_width (var));
+                }
+              hmap_insert (&ft->data, &f->node, hash);
 
-            next_hash_node: ;
+            next_ft: ;
             }
+        }
+    }
+  casereader_destroy (input);
 
-        f = xmalloc (table_entry_size (ft->vars.n));
-        f->count = weight;
-        for (size_t j = 0; j < ft->vars.n; j++)
-          {
-            const struct variable *var = ft->vars.vars[j];
-            value_clone (&f->values[j], case_data (c, var),
-                         var_get_width (var));
-          }
-        hmap_insert (&ft->data, &f->node, hash);
+  for (size_t i = 0; i < ct->n_tables; i++)
+    {
+      struct ctables_table *t = &ct->tables[i];
+
+      struct pivot_table *pt = pivot_table_create (N_("Custom Tables"));
+      struct pivot_dimension *d = pivot_dimension_create (
+        pt, PIVOT_AXIS_ROW, N_("Rows"));
+      for (size_t j = 0; j < t->n_fts; j++)
+        {
+          struct ctables_freqtab *ft = t->fts[j];
+          ft->sorted = xnmalloc (ft->data.count, sizeof *ft->sorted);
+
+          struct ctables_freq *f;
+          size_t n = 0;
+          HMAP_FOR_EACH (f, struct ctables_freq, node, &ft->data)
+            ft->sorted[n++] = f;
+          assert (n == ft->data.count);
+          sort (ft->sorted, n, sizeof *ft->sorted,
+                ctables_freq_compare_3way, &ft->vars);
+
+          struct pivot_category **groups = xnmalloc (ft->vars.n,
+                                                     sizeof *groups);
+          for (size_t k = 0; k < n; k++)
+            {
+              struct ctables_freq *prev = k > 0 ? ft->sorted[k - 1] : NULL;
+              struct ctables_freq *f = ft->sorted[k];
+
+              size_t n_common = 0;
+              if (prev)
+                for (; n_common + 1 < ft->vars.n; n_common++)
+                  if (!value_equal (&prev->values[n_common],
+                                    &f->values[n_common],
+                                    var_get_type (ft->vars.vars[n_common])))
+                    break;
+
+              for (size_t m = n_common; m < ft->vars.n; m++)
+                {
+                  struct pivot_category *parent = m > 0 ? groups[m - 1] : d->root;
+                  const struct variable *var = ft->vars.vars[m];
+                  enum ctables_vlabel vlabel = ct->vlabels[var_get_dict_index (var)];
+
+                  if (vlabel != CTVL_NONE)
+                    parent = pivot_category_create_group__ (
+                      parent, pivot_value_new_variable (ft->vars.vars[m]));
 
-        next_ft: ;
+                  if (m + 1 < ft->vars.n)
+                    parent = pivot_category_create_group__ (
+                      parent,
+                      pivot_value_new_var_value (ft->vars.vars[m], &f->values[m]));
+                  groups[m] = parent;
+
+                  if (m == ft->vars.n - 1)
+                    {
+                      int leaf = pivot_category_create_leaf (
+                        parent,
+                        pivot_value_new_var_value (ft->vars.vars[ft->vars.n - 1],
+                                                   &f->values[ft->vars.n - 1]));
+                      pivot_table_put1 (pt, leaf, pivot_value_new_number (f->count));
+                    }
+                }
+            }
+          free (groups);
         }
+      pivot_table_submit (pt);
     }
-  casereader_destroy (input);
 
-  for (size_t i = 0; i < n_fts; i++)
+  for (size_t i = 0; i < ct->n_tables; i++)
     {
-      struct ctables_freqtab *ft = fts[i];
-      struct freq *f, *next;
-      HMAP_FOR_EACH_SAFE (f, next, struct freq, node, &ft->data)
+      struct ctables_table *t = &ct->tables[i];
+
+      for (size_t j = 0; j < t->n_fts; j++)
         {
-          hmap_delete (&ft->data, &f->node);
-          for (size_t j = 0; j < ft->vars.n; j++)
+          struct ctables_freqtab *ft = t->fts[j];
+          struct ctables_freq *f, *next;
+          HMAP_FOR_EACH_SAFE (f, next, struct ctables_freq, node, &ft->data)
             {
-              const struct variable *var = ft->vars.vars[j];
-              value_destroy (&f->values[j], var_get_width (var));
+              hmap_delete (&ft->data, &f->node);
+              for (size_t k = 0; k < ft->vars.n; k++)
+                {
+                  const struct variable *var = ft->vars.vars[k];
+                  value_destroy (&f->values[k], var_get_width (var));
+                }
+              free (f);
             }
-          free (f);
+          hmap_destroy (&ft->data);
+          free (ft->sorted);
+          var_array_uninit (&ft->vars);
+          free (ft);
         }
-      hmap_destroy (&ft->data);
-      var_array_uninit (&ft->vars);
-      free (ft);
+      free (t->fts);
     }
-  free (fts);
 
   return proc_commit (ds);
 }
@@ -1476,8 +1587,9 @@ cmd_ctables (struct lexer *lexer, struct dataset *ds)
 {
   size_t n_vars = dict_get_n_vars (dataset_dict (ds));
   enum ctables_vlabel *vlabels = xnmalloc (n_vars, sizeof *vlabels);
+  enum settings_value_show tvars = settings_get_show_variables ();
   for (size_t i = 0; i < n_vars; i++)
-    vlabels[i] = CTVL_DEFAULT;
+    vlabels[i] = (enum ctables_vlabel) tvars;
 
   struct ctables *ct = xmalloc (sizeof *ct);
   *ct = (struct ctables) {
@@ -1601,7 +1713,7 @@ cmd_ctables (struct lexer *lexer, struct dataset *ds)
 
           enum ctables_vlabel vlabel;
           if (lex_match_id (lexer, "DEFAULT"))
-            vlabel = CTVL_DEFAULT;
+            vlabel = (enum ctables_vlabel) settings_get_show_variables ();
           else if (lex_match_id (lexer, "NAME"))
             vlabel = CTVL_NAME;
           else if (lex_match_id (lexer, "LABEL"))