freqtabs are per-table
[pspp] / src / language / stats / ctables.c
index 1b8ba198f83cca7c882e2b3140e729a698a7d270..dc761ade728e85235f89953dcd6d093303cdffec 100644 (file)
@@ -16,6 +16,7 @@
 
 #include <config.h>
 
+#include "data/casereader.h"
 #include "data/dataset.h"
 #include "data/dictionary.h"
 #include "data/mrset.h"
 #include "language/lexer/format-parser.h"
 #include "language/lexer/lexer.h"
 #include "language/lexer/variable-parser.h"
+#include "language/stats/freq.h"
+#include "libpspp/array.h"
 #include "libpspp/assertion.h"
 #include "libpspp/hmap.h"
 #include "libpspp/message.h"
+#include "libpspp/string-array.h"
 #include "output/pivot-table.h"
 
 #include "gl/minmax.h"
@@ -216,12 +220,7 @@ struct ctables_postcompute_expr
         /* CTPO_CAT_RANGE.
 
            XXX what about string ranges? */
-        struct
-          {
-            double low;         /* -DBL_MAX for LO. */
-            double high;        /* DBL_MAX for HIGH. */
-          }
-        range;
+        double range[2];
 
         /* CTPO_ADD, CTPO_SUB, CTPO_MUL, CTPO_DIV, CTPO_POW. */
         struct ctables_postcompute_expr *subs[2];
@@ -257,6 +256,9 @@ struct ctables_table
 
     struct ctables_chisq *chisq;
     struct ctables_pairwise *pairwise;
+
+    struct ctables_freqtab **fts;
+    size_t n_fts;
   };
 
 struct ctables_var
@@ -1252,6 +1254,275 @@ ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
   return true;
 }
 
+struct var_array
+  {
+    struct variable **vars;
+    size_t n;
+  };
+
+static void
+var_array_uninit (struct var_array *va)
+{
+  if (va)
+    free (va->vars);
+}
+
+struct var_array2
+  {
+    struct var_array *vas;
+    size_t n;
+  };
+
+static void
+var_array2_uninit (struct var_array2 *vaa)
+{
+  if (vaa)
+    {
+      for (size_t i = 0; i < vaa->n; i++)
+        var_array_uninit (&vaa->vas[i]);
+      free (vaa->vas);
+    }
+}
+
+static struct var_array2
+nest_fts (struct var_array2 va0, struct var_array2 va1)
+{
+  if (!va0.n)
+    return va1;
+  else if (!va1.n)
+    return va0;
+
+  struct var_array2 vaa = { .vas = xnmalloc (va0.n, va1.n * sizeof *vaa.vas) };
+  for (size_t i = 0; i < va0.n; i++)
+    for (size_t j = 0; j < va1.n; j++)
+      {
+        size_t allocate = va0.vas[i].n + va1.vas[j].n;
+        struct variable **vars = xnmalloc (allocate, sizeof *vars);
+        size_t n = 0;
+        for (size_t k = 0; k < va0.vas[i].n; k++)
+          vars[n++] = va0.vas[i].vars[k];
+        for (size_t k = 0; k < va1.vas[j].n; k++)
+          vars[n++] = va1.vas[j].vars[k];
+        assert (n == allocate);
+
+        vaa.vas[vaa.n++] = (struct var_array) { .vars = vars, n = n };
+      }
+  var_array2_uninit (&va0);
+  var_array2_uninit (&va1);
+  return vaa;
+}
+
+static struct var_array2
+stack_fts (struct var_array2 va0, struct var_array2 va1)
+{
+  struct var_array2 vaa = { .vas = xnmalloc (va0.n + va1.n, sizeof *vaa.vas) };
+  for (size_t i = 0; i < va0.n; i++)
+    vaa.vas[vaa.n++] = va0.vas[i];
+  for (size_t i = 0; i < va1.n; i++)
+    vaa.vas[vaa.n++] = va1.vas[i];
+  assert (vaa.n == va0.n + va1.n);
+  free (va0.vas);
+  free (va1.vas);
+  return vaa;
+}
+
+static struct var_array2
+enumerate_fts (const struct ctables_axis *a)
+{
+  if (!a)
+    return (struct var_array2) { .n = 0 };
+
+  switch (a->op)
+    {
+    case CTAO_VAR:
+      assert (!a->var.is_mrset);
+      struct variable **v = xmalloc (sizeof *v);
+      *v = a->var.var;
+      struct var_array *va = xmalloc (sizeof *va);
+      *va = (struct var_array) { .vars = v, .n = 1 };
+      return (struct var_array2) { .vas = va, .n = 1 };
+
+    case CTAO_STACK:
+      return stack_fts (enumerate_fts (a->subs[0]),
+                        enumerate_fts (a->subs[1]));
+
+    case CTAO_NEST:
+      return nest_fts (enumerate_fts (a->subs[0]),
+                       enumerate_fts (a->subs[1]));
+    }
+
+  NOT_REACHED ();
+}
+
+struct ctables_freqtab
+  {
+    struct var_array vars;
+    struct hmap data;           /* Contains "struct freq"s. */
+    struct freq **sorted;
+  };
+
+static int
+compare_freq_3way (const void *a_, const void *b_, const void *vars_)
+{
+  const struct var_array *vars = vars_;
+  struct freq *const *a = a_;
+  struct freq *const *b = b_;
+
+  for (size_t i = 0; i < vars->n; i++)
+    {
+      int cmp = value_compare_3way (&(*a)->values[i], &(*b)->values[i],
+                                    var_get_width (vars->vars[i]));
+      if (cmp)
+        return cmp;
+    }
+  
+  return 0;
+}
+
+static bool
+ctables_execute (struct dataset *ds, struct ctables *ct)
+{
+  for (size_t i = 0; i < ct->n_tables; i++)
+    {
+      size_t allocated_fts = 0;
+
+      struct ctables_table *t = &ct->tables[i];
+      struct var_array2 vaa = enumerate_fts (t->axes[PIVOT_AXIS_ROW]);
+      vaa = nest_fts (vaa, enumerate_fts (t->axes[PIVOT_AXIS_COLUMN]));
+      vaa = nest_fts (vaa, enumerate_fts (t->axes[PIVOT_AXIS_LAYER]));
+      for (size_t i = 0; i < vaa.n; i++)
+        {
+          for (size_t j = 0; j < vaa.vas[i].n; j++)
+            {
+              if (j)
+                fputs (", ", stdout);
+              fputs (var_get_name (vaa.vas[i].vars[j]), stdout);
+            }
+          putchar ('\n');
+        }
+
+      for (size_t j = 0; j < vaa.n; j++)
+        {
+          struct ctables_freqtab *ft = xmalloc (sizeof *ft);
+          *ft = (struct ctables_freqtab) {
+            .vars = vaa.vas[j],
+            .data = HMAP_INITIALIZER (ft->data),
+          };
+
+          if (t->n_fts >= allocated_fts)
+            t->fts = x2nrealloc (t->fts, &allocated_fts, sizeof *t->fts);
+          t->fts[t->n_fts++] = ft;
+        }
+
+      free (vaa.vas);
+    }
+
+  struct casereader *input = casereader_create_filter_weight (proc_open (ds),
+                                                              dataset_dict (ds),
+                                                              NULL, NULL);
+  bool warn_on_invalid = true;
+  for (struct ccase *c = casereader_read (input); c;
+       case_unref (c), c = casereader_read (input))
+    {
+      double weight = dict_get_case_weight (dataset_dict (ds), c,
+                                            &warn_on_invalid);
+
+      for (size_t i = 0; i < ct->n_tables; i++)
+        {
+          struct ctables_table *t = &ct->tables[i];
+
+          for (size_t j = 0; j < t->n_fts; j++)
+            {
+              struct ctables_freqtab *ft = t->fts[j];
+
+              size_t hash = 0;
+
+              for (size_t k = 0; k < ft->vars.n; k++)
+                {
+                  const struct variable *var = ft->vars.vars[k];
+                  hash = value_hash (case_data (c, var), var_get_width (var), hash);
+                }
+
+              struct freq *f;
+              HMAP_FOR_EACH_WITH_HASH (f, struct freq, node, hash, &ft->data)
+                {
+                  for (size_t k = 0; k < ft->vars.n; k++)
+                    {
+                      const struct variable *var = ft->vars.vars[k];
+                      if (!value_equal (case_data (c, var), &f->values[k],
+                                        var_get_width (var)))
+                        goto next_hash_node;
+                    }
+
+                  f->count += weight;
+                  goto next_ft;
+
+                next_hash_node: ;
+                }
+
+              f = xmalloc (table_entry_size (ft->vars.n));
+              f->count = weight;
+              for (size_t k = 0; k < ft->vars.n; k++)
+                {
+                  const struct variable *var = ft->vars.vars[k];
+                  value_clone (&f->values[k], case_data (c, var),
+                               var_get_width (var));
+                }
+              hmap_insert (&ft->data, &f->node, hash);
+
+            next_ft: ;
+            }
+        }
+    }
+  casereader_destroy (input);
+
+  for (size_t i = 0; i < ct->n_tables; i++)
+    {
+      struct ctables_table *t = &ct->tables[i];
+
+      for (size_t j = 0; j < t->n_fts; j++)
+        {
+          struct ctables_freqtab *ft = t->fts[j];
+          ft->sorted = xnmalloc (ft->data.count, sizeof *ft->sorted);
+
+          struct freq *f;
+          size_t n = 0;
+          HMAP_FOR_EACH (f, struct freq, node, &ft->data)
+            ft->sorted[n++] = f;
+          sort (ft->sorted, ft->data.count, sizeof *ft->sorted,
+                compare_freq_3way, &ft->vars);
+        }
+    }
+
+  for (size_t i = 0; i < ct->n_tables; i++)
+    {
+      struct ctables_table *t = &ct->tables[i];
+
+      for (size_t j = 0; j < t->n_fts; j++)
+        {
+          struct ctables_freqtab *ft = t->fts[j];
+          struct freq *f, *next;
+          HMAP_FOR_EACH_SAFE (f, next, struct freq, node, &ft->data)
+            {
+              hmap_delete (&ft->data, &f->node);
+              for (size_t k = 0; k < ft->vars.n; k++)
+                {
+                  const struct variable *var = ft->vars.vars[k];
+                  value_destroy (&f->values[k], var_get_width (var));
+                }
+              free (f);
+            }
+          hmap_destroy (&ft->data);
+          free (ft->sorted);
+          var_array_uninit (&ft->vars);
+          free (ft);
+        }
+      free (t->fts);
+    }
+
+  return proc_commit (ds);
+}
+
 int
 cmd_ctables (struct lexer *lexer, struct dataset *ds)
 {
@@ -1526,7 +1797,6 @@ cmd_ctables (struct lexer *lexer, struct dataset *ds)
       if (!lex_force_match (lexer, T_SLASH))
         break;
 
-      /* XXX Validate axes. */
       while (!lex_match_id (lexer, "TABLE") && lex_token (lexer) != T_ENDCMD)
         {
           if (lex_match_id (lexer, "SLABELS"))
@@ -1857,10 +2127,19 @@ cmd_ctables (struct lexer *lexer, struct dataset *ds)
               goto error;
             }
         }
+
+      if (t->row_labels != CTLP_NORMAL && t->col_labels != CTLP_NORMAL)
+        {
+          msg (SE, _("ROWLABELS and COLLABELS may not both be specified."));
+          goto error;
+        }
+
     }
   while (lex_token (lexer) != T_ENDCMD);
+
+  bool ok = ctables_execute (ds, ct);
   ctables_destroy (ct);
-  return CMD_SUCCESS;
+  return ok ? CMD_SUCCESS : CMD_FAILURE;
 
 error:
   ctables_destroy (ct);