postcomputes sort-of work!
authorBen Pfaff <blp@cs.stanford.edu>
Mon, 14 Feb 2022 00:04:25 +0000 (16:04 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 2 Apr 2022 01:48:55 +0000 (18:48 -0700)
src/language/stats/ctables.c

index 00275303d5952f93dcda6923cf4e0cd98281c802..bf3425b49e67c369dc0847fabae2e56ed706afef 100644 (file)
@@ -17,6 +17,7 @@
 #include <config.h>
 
 #include <math.h>
+#include <errno.h>
 
 #include "data/casereader.h"
 #include "data/casewriter.h"
@@ -198,6 +199,7 @@ struct ctables_cell
     struct ctables_domain *domains[N_CTDTS];
 
     bool hide;
+    bool postcompute;
     enum ctables_summary_variant sv;
 
     struct ctables_cell_axis
@@ -2727,8 +2729,6 @@ static struct ctables_cell *
 ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
                        const struct ctables_category *cats[PIVOT_N_AXES][10])
 {
-  const struct ctables_nest *ss = s->nests[s->table->summary_axis];
-
   size_t hash = 0;
   enum ctables_summary_variant sv = CSV_CELL;
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
@@ -2775,6 +2775,7 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
   cell->hide = false;
   cell->sv = sv;
   cell->contributes_to_domains = true;
+  cell->postcompute = false;
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
     {
       const struct ctables_nest *nest = s->nests[a];
@@ -2794,6 +2795,8 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
                   || cat->type == CCT_SUBTOTAL
                   || cat->type == CCT_POSTCOMPUTE)
                 cell->contributes_to_domains = false;
+              if (cat->type == CCT_POSTCOMPUTE)
+                cell->postcompute = true;
             }
 
           cell->axes[a].cvs[i].category = cat;
@@ -2802,6 +2805,7 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
         }
     }
 
+  const struct ctables_nest *ss = s->nests[s->table->summary_axis];
   const struct ctables_summary_spec_set *specs = &ss->specs[cell->sv];
   cell->summaries = xmalloc (specs->n * sizeof *cell->summaries);
   for (size_t i = 0; i < specs->n; i++)
@@ -3037,6 +3041,233 @@ ctables_table_add_section (struct ctables_table *t, enum pivot_axis_type a,
     }
 }
 
+static double
+ctpo_add (double a, double b)
+{
+  return a + b;
+}
+
+static double
+ctpo_sub (double a, double b)
+{
+  return a - b;
+}
+
+static double
+ctpo_mul (double a, double b)
+{
+  return a * b;
+}
+
+static double
+ctpo_div (double a, double b)
+{
+  return b ? a / b : SYSMIS;
+}
+
+static double
+ctpo_pow (double a, double b)
+{
+  int save_errno = errno;
+  errno = 0;
+  double result = pow (a, b);
+  if (errno)
+    result = SYSMIS;
+  errno = save_errno;
+  return result;
+}
+
+static double
+ctpo_neg (double a, double b UNUSED)
+{
+  return -a;
+}
+
+struct ctables_pcexpr_evaluate_ctx
+  {
+    const struct ctables_cell *cell;
+    const struct ctables_section *section;
+    const struct ctables_categories *cats;
+    enum pivot_axis_type pc_a;
+    size_t pc_a_idx;
+  };
+
+static double ctables_pcexpr_evaluate (
+  const struct ctables_pcexpr_evaluate_ctx *, const struct ctables_pcexpr *);
+
+static double
+ctables_pcexpr_evaluate_nonterminal (
+  const struct ctables_pcexpr_evaluate_ctx *ctx,
+  const struct ctables_pcexpr *e, size_t n_args,
+  double evaluate (double, double))
+{
+  double args[2] = { 0, 0 };
+  for (size_t i = 0; i < n_args; i++)
+    {
+      args[i] = ctables_pcexpr_evaluate (ctx, e->subs[i]);
+      if (!isfinite (args[i]) || args[i] == SYSMIS)
+        return SYSMIS;
+    }
+  return evaluate (args[0], args[1]);
+}
+
+static double
+ctables_pcexpr_evaluate_category (const struct ctables_pcexpr_evaluate_ctx *ctx,
+                                  const struct ctables_category *cat)
+{
+  const struct ctables_section *s = ctx->section;
+
+  size_t hash = 0;
+  for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+    {
+      const struct ctables_nest *nest = s->nests[a];
+      for (size_t i = 0; i < nest->n; i++)
+        if (a == ctx->pc_a && i == ctx->pc_a_idx)
+          {
+            /* XXX anything other than just a constant.... need a higher level
+               loop to go through occurrences */
+            hash = hash_pointer (cat, hash);
+            hash = hash_double (cat->number, hash);
+          }
+        else if (i != nest->scale_idx)
+          {
+            const struct ctables_cell_value *cv = &ctx->cell->axes[a].cvs[i];
+            hash = hash_pointer (cv->category, hash);
+            if (cv->category->type != CCT_TOTAL
+                && cv->category->type != CCT_SUBTOTAL
+                && cv->category->type != CCT_POSTCOMPUTE)
+              hash = value_hash (&cv->value,
+                                 var_get_width (nest->vars[i]), hash);
+          }
+    }
+
+  struct ctables_cell *tc;
+  HMAP_FOR_EACH_WITH_HASH (tc, struct ctables_cell, node, hash, &s->cells)
+    {
+      for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+        {
+          const struct ctables_nest *nest = s->nests[a];
+          for (size_t i = 0; i < nest->n; i++)
+            {
+              const struct ctables_cell_value *p_cv = &ctx->cell->axes[a].cvs[i];
+              const struct ctables_cell_value *t_cv = &tc->axes[a].cvs[i];
+
+              if (i == nest->scale_idx)
+                {
+                  /* Nothing to do. */
+                }
+              else if (a == ctx->pc_a && i == ctx->pc_a_idx)
+                {
+                  /* XXX anything other than a constant.... */
+                  if (t_cv->category != cat || t_cv->value.f != cat->number)
+                    goto not_equal;
+                }
+              else if (p_cv->category != t_cv->category
+                       || (p_cv->category->type != CCT_TOTAL
+                           && p_cv->category->type != CCT_SUBTOTAL
+                           && p_cv->category->type != CCT_POSTCOMPUTE
+                           && !value_equal (&p_cv->value,
+                                            &t_cv->value,
+                                            var_get_width (nest->vars[i]))))
+                goto not_equal;
+            }
+        }
+
+      goto found;
+
+    not_equal: ;
+    }
+  return 0;
+
+found: ;
+  const struct ctables_table *t = s->table;
+  const struct ctables_nest *specs_nest = s->nests[t->summary_axis];
+  const struct ctables_summary_spec_set *specs = &specs_nest->specs[tc->sv];
+  size_t j = 0 /* XXX */;
+  return ctables_summary_value (tc, &tc->summaries[j], &specs->specs[j]);
+}
+
+static double
+ctables_pcexpr_evaluate (const struct ctables_pcexpr_evaluate_ctx *ctx,
+                         const struct ctables_pcexpr *e)
+{
+  switch (e->op)
+    {
+    case CTPO_CONSTANT:
+      return e->number;
+
+    case CTPO_CAT_NUMBER:
+    case CTPO_CAT_STRING:
+    case CTPO_CAT_RANGE:
+    case CTPO_CAT_MISSING:
+    case CTPO_CAT_OTHERNM:
+    case CTPO_CAT_SUBTOTAL:
+    case CTPO_CAT_TOTAL:
+      {
+        struct ctables_category *cat = ctables_find_category_for_postcompute (
+          ctx->cats, e);
+        assert (cat != NULL);
+
+        return ctables_pcexpr_evaluate_category (ctx, cat);
+      }
+
+    case CTPO_ADD:
+      return ctables_pcexpr_evaluate_nonterminal (ctx, e, 2, ctpo_add);
+
+    case CTPO_SUB:
+      return ctables_pcexpr_evaluate_nonterminal (ctx, e, 2, ctpo_sub);
+
+    case CTPO_MUL:
+      return ctables_pcexpr_evaluate_nonterminal (ctx, e, 2, ctpo_mul);
+
+    case CTPO_DIV:
+      return ctables_pcexpr_evaluate_nonterminal (ctx, e, 2, ctpo_div);
+
+    case CTPO_POW:
+      return ctables_pcexpr_evaluate_nonterminal (ctx, e, 2, ctpo_pow);
+
+    case CTPO_NEG:
+      return ctables_pcexpr_evaluate_nonterminal (ctx, e, 1, ctpo_neg);
+    }
+
+  NOT_REACHED ();
+}
+
+static double
+ctables_cell_calculate_postcompute (const struct ctables_section *s,
+                                    const struct ctables_cell *cell)
+{
+  enum pivot_axis_type pc_a;
+  size_t pc_a_idx;
+  const struct ctables_postcompute *pc;
+  for (pc_a = 0; ; pc_a++)
+    {
+      assert (pc_a < PIVOT_N_AXES);
+      for (pc_a_idx = 0; pc_a_idx < s->nests[pc_a]->n; pc_a_idx++)
+        {
+          const struct ctables_cell_value *cv = &cell->axes[pc_a].cvs[pc_a_idx];
+          if (cv->category->type == CCT_POSTCOMPUTE)
+            {
+              pc = cv->category->pc;
+              goto found;
+            }
+        }
+    }
+found: ;
+
+  const struct variable *var = s->nests[pc_a]->vars[pc_a_idx];
+  const struct ctables_categories *cats = s->table->categories[
+    var_get_dict_index (var)];
+  struct ctables_pcexpr_evaluate_ctx ctx = {
+    .cell = cell,
+    .section = s,
+    .cats = cats,
+    .pc_a = pc_a,
+    .pc_a_idx = pc_a_idx,
+  };
+  return ctables_pcexpr_evaluate (&ctx, pc->expr);
+}
+
 static void
 ctables_table_output (struct ctables *ct, struct ctables_table *t)
 {
@@ -3321,7 +3552,9 @@ ctables_table_output (struct ctables *ct, struct ctables_table *t)
                     dindexes[n_dindexes++] = leaf;
                   }
 
-              double d = ctables_summary_value (cell, &cell->summaries[j], &specs->specs[j]);
+              double d = (cell->postcompute
+                          ? ctables_cell_calculate_postcompute (s, cell)
+                          : ctables_summary_value (cell, &cell->summaries[j], &specs->specs[j]));
               struct pivot_value *value = pivot_value_new_number (d);
               value->numeric.format = specs->specs[j].format;
               pivot_table_put (pt, dindexes, n_dindexes, value);