single totals work
authorBen Pfaff <blp@cs.stanford.edu>
Tue, 4 Jan 2022 06:52:57 +0000 (22:52 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 2 Apr 2022 01:48:55 +0000 (18:48 -0700)
src/language/stats/ctables.c

index b3ddee623f2bf4209ff3e2143723bb551e054042..ce635b15b8eaf67864c54c2befc24a357eafa66f 100644 (file)
@@ -394,10 +394,6 @@ struct ctables_category
       };
   };
 
-static const struct ctables_category *ctables_categories_match (
-  const struct ctables_categories *, const union value *,
-  const struct variable *);
-
 static void
 ctables_category_uninit (struct ctables_category *cat)
 {
@@ -2128,41 +2124,21 @@ ctables_categories_match (const struct ctables_categories *c,
   return var_is_value_missing (var, v) ? NULL : othernm;
 }
 
+static const struct ctables_category *
+ctables_categories_total (const struct ctables_categories *c)
+{
+  const struct ctables_category *total = &c->cats[c->n_cats - 1];
+  return total->type == CCT_TOTAL ? total : NULL;
+}
+
 static void
-ctables_cell_insert (struct ctables_table *t,
-                     const struct ccase *c,
-                     size_t ir, size_t ic, size_t il,
-                     double weight)
+ctables_cell_insert__ (struct ctables_table *t, const struct ccase *c,
+                       size_t ix[PIVOT_N_AXES],
+                       const struct ctables_category *cats[PIVOT_N_AXES][10],
+                       double weight)
 {
-  size_t ix[PIVOT_N_AXES] = {
-    [PIVOT_AXIS_ROW] = ir,
-    [PIVOT_AXIS_COLUMN] = ic,
-    [PIVOT_AXIS_LAYER] = il,
-  };
   const struct var_array *ss = &t->vaas[t->summary_axis].vas[ix[t->summary_axis]];
 
-  const struct ctables_category *cats[PIVOT_N_AXES][10];
-  for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
-    {
-      const struct var_array *va = &t->vaas[a].vas[ix[a]];
-      for (size_t i = 0; i < va->n; i++)
-        {
-          if (i == va->scale_idx)
-            continue;
-
-          const struct variable *var = va->vars[i];
-          const union value *value = case_data (c, var);
-
-          if (var_is_numeric (var) && value->f == SYSMIS)
-            return;
-
-          cats[a][i] = ctables_categories_match (
-            t->categories[var_get_dict_index (var)], value, var);
-          if (!cats[a][i])
-            return;
-        }
-    }
-
   size_t hash = 0;
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
     {
@@ -2170,8 +2146,12 @@ ctables_cell_insert (struct ctables_table *t,
       hash = hash_int (ix[a], hash);
       for (size_t i = 0; i < va->n; i++)
         if (i != va->scale_idx)
-          hash = value_hash (case_data (c, va->vars[i]),
-                             var_get_width (va->vars[i]), hash);
+          {
+            hash = hash_pointer (cats[a][i], hash);
+            if (cats[a][i]->type != CCT_TOTAL)
+              hash = value_hash (case_data (c, va->vars[i]),
+                                 var_get_width (va->vars[i]), hash);
+          }
     }
 
   struct ctables_cell *f;
@@ -2184,9 +2164,11 @@ ctables_cell_insert (struct ctables_table *t,
             goto not_equal;
           for (size_t i = 0; i < va->n; i++)
             if (i != va->scale_idx
-                && !value_equal (case_data (c, va->vars[i]),
-                                 &f->axes[a].cvs[i].value,
-                                 var_get_width (va->vars[i])))
+                && (cats[a][i] != f->axes[a].cvs[i].category
+                    || (cats[a][i]->type != CCT_TOTAL
+                        && !value_equal (case_data (c, va->vars[i]),
+                                         &f->axes[a].cvs[i].value,
+                                         var_get_width (va->vars[i])))))
                 goto not_equal;
         }
 
@@ -2225,6 +2207,65 @@ summarize:
     f->domains[dt]->valid += weight;
 }
 
+static void
+ctables_cell_insert (struct ctables_table *t,
+                     const struct ccase *c,
+                     size_t ir, size_t ic, size_t il,
+                     double weight)
+{
+  size_t ix[PIVOT_N_AXES] = {
+    [PIVOT_AXIS_ROW] = ir,
+    [PIVOT_AXIS_COLUMN] = ic,
+    [PIVOT_AXIS_LAYER] = il,
+  };
+
+  const struct ctables_category *cats[PIVOT_N_AXES][10];
+  for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+    {
+      const struct var_array *va = &t->vaas[a].vas[ix[a]];
+      for (size_t i = 0; i < va->n; i++)
+        {
+          if (i == va->scale_idx)
+            continue;
+
+          const struct variable *var = va->vars[i];
+          const union value *value = case_data (c, var);
+
+          if (var_is_numeric (var) && value->f == SYSMIS)
+            return;
+
+          cats[a][i] = ctables_categories_match (
+            t->categories[var_get_dict_index (var)], value, var);
+          if (!cats[a][i])
+            return;
+        }
+    }
+
+  ctables_cell_insert__ (t, c, ix, cats, weight);
+
+  for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+    {
+      const struct var_array *va = &t->vaas[a].vas[ix[a]];
+      for (size_t i = 0; i < va->n; i++)
+        {
+          if (i == va->scale_idx)
+            continue;
+
+          const struct variable *var = va->vars[i];
+
+          const struct ctables_category *total = ctables_categories_total (
+            t->categories[var_get_dict_index (var)]);
+          if (total)
+            {
+              const struct ctables_category *save = cats[a][i];
+              cats[a][i] = total;
+              ctables_cell_insert__ (t, c, ix, cats, weight);
+              cats[a][i] = save;
+            }
+        }
+    }
+}
+
 static bool
 ctables_execute (struct dataset *ds, struct ctables *ct)
 {
@@ -2409,9 +2450,11 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                     {
                       for (; n_common < va->n; n_common++)
                         if (n_common != va->scale_idx
-                            && !value_equal (&prev->axes[a].cvs[n_common].value,
-                                             &f->axes[a].cvs[n_common].value,
-                                             var_get_type (va->vars[n_common])))
+                            && (prev->axes[a].cvs[n_common].category
+                                != f->axes[a].cvs[n_common].category
+                                || !value_equal (&prev->axes[a].cvs[n_common].value,
+                                                 &f->axes[a].cvs[n_common].value,
+                                                 var_get_type (va->vars[n_common]))))
                           break;
                     }
                   else
@@ -2439,10 +2482,12 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                   struct pivot_category *parent = k > 0 ? groups[k - 1] : top;
 
                   struct pivot_value *label
-                    = (k != va->scale_idx
-                       ? pivot_value_new_var_value (va->vars[k],
-                                                    &f->axes[a].cvs[k].value)
-                       : NULL);
+                    = (k == va->scale_idx ? NULL
+                       : f->axes[a].cvs[k].category->type == CCT_TOTAL
+                       ? pivot_value_new_user_text (f->axes[a].cvs[k].category->total_label,
+                                                    SIZE_MAX)
+                       : pivot_value_new_var_value (va->vars[k],
+                                                    &f->axes[a].cvs[k].value));
                   if (k == va->n - 1)
                     {
                       if (a == t->summary_axis)