Allow totals to have different statistics
author Ben Pfaff <blp@cs.stanford.edu>
Sat, 15 Jan 2022 02:44:01 +0000 (18:44 -0800)
committer Ben Pfaff <blp@cs.stanford.edu>
Sat, 2 Apr 2022 01:48:55 +0000 (18:48 -0700)
src/language/stats/ctables.c

index e5c1328f039dd3f9bcb3639c2cd497d967f3bc67..d5982a8c1e404d90f71821a189fbd25639d60434 100644 (file)
@@ -296,8 +296,8 @@ struct var_array
     size_t *domains[N_CTDTS];
     size_t n_domains[N_CTDTS];
 
-    struct ctables_summary_spec_set cell_summaries;
-    struct ctables_summary_spec_set total_summaries;
+    struct ctables_summary_spec_set cell_sss;
+    struct ctables_summary_spec_set total_sss;
   };
 
 struct var_array2
@@ -498,8 +498,8 @@ struct ctables_axis
           {
             struct ctables_var var;
             bool scale;
-            struct ctables_summary_spec_set cell_summaries;
-            struct ctables_summary_spec_set total_summaries;
+            struct ctables_summary_spec_set cell_sss;
+            struct ctables_summary_spec_set total_sss;
           };
 
         /* Nonterminals. */
@@ -641,8 +641,8 @@ ctables_axis_destroy (struct ctables_axis *axis)
   switch (axis->op)
     {
     case CTAO_VAR:
-      ctables_summary_spec_set_uninit (&axis->cell_summaries);
-      ctables_summary_spec_set_uninit (&axis->total_summaries);
+      ctables_summary_spec_set_uninit (&axis->cell_sss);
+      ctables_summary_spec_set_uninit (&axis->total_sss);
       break;
 
     case CTAO_STACK:
@@ -768,8 +768,8 @@ add_summary_spec (struct ctables_axis *axis,
           break;
         }
 
-      struct ctables_summary_spec_set *set = (totals ? &axis->total_summaries
-                                              : &axis->cell_summaries);
+      struct ctables_summary_spec_set *set = (totals ? &axis->total_sss
+                                              : &axis->cell_sss);
       if (set->n >= set->allocated)
         set->summaries = x2nrealloc (set->summaries, &set->allocated,
                                      sizeof *set->summaries);
@@ -929,6 +929,7 @@ ctables_axis_parse_postfix (struct ctables_axis_parse_ctx *ctx)
             {
               if (!lex_force_match (ctx->lexer, T_LBRACK))
                 goto error;
+              totals = true;
             }
         }
       else if (lex_force_match (ctx->lexer, T_RBRACK))
@@ -979,7 +980,7 @@ find_categorical_summary_spec (const struct ctables_axis *axis)
   if (!axis)
     return NULL;
   else if (axis->op == CTAO_VAR)
-    return !axis->scale && axis->cell_summaries.n ? axis : NULL;
+    return !axis->scale && axis->cell_sss.n ? axis : NULL;
   else
     {
       for (size_t i = 0; i < 2; i++)
@@ -1478,9 +1479,9 @@ nest_fts (struct var_array2 va0, struct var_array2 va1)
         assert (n == allocate);
 
         const struct var_array *summary_src;
-        if (!a->cell_summaries.var)
+        if (!a->cell_sss.var)
           summary_src = b;
-        else if (!b->cell_summaries.var)
+        else if (!b->cell_sss.var)
           summary_src = a;
         else
           NOT_REACHED ();
@@ -1490,8 +1491,8 @@ nest_fts (struct var_array2 va0, struct var_array2 va1)
                         : b->scale_idx != SIZE_MAX ? a->n + b->scale_idx
                         : SIZE_MAX),
           .n = n,
-          .cell_summaries = summary_src->cell_summaries,
-          .total_summaries = summary_src->total_summaries,
+          .cell_sss = summary_src->cell_sss,
+          .total_sss = summary_src->total_sss,
         };
       }
   var_array2_uninit (&va0);
@@ -1533,12 +1534,12 @@ enumerate_fts (enum pivot_axis_type axis_type, const struct ctables_axis *a)
         .n = 1,
         .scale_idx = a->scale ? 0 : SIZE_MAX,
       };
-      if (a->cell_summaries.n || a->scale)
+      if (a->cell_sss.n || a->scale)
         {
-          va->cell_summaries = a->cell_summaries;
-          va->total_summaries = a->total_summaries;
-          va->cell_summaries.var = a->var.var;
-          va->total_summaries.var = a->var.var;
+          va->cell_sss = a->cell_sss;
+          va->total_sss = a->total_sss;
+          va->cell_sss.var = a->var.var;
+          va->total_sss.var = a->var.var;
         }
       return (struct var_array2) { .vas = va, .n = 1 };
 
@@ -2282,7 +2283,8 @@ ctables_cell_insert__ (struct ctables_table *t, const struct ccase *c,
     }
 
   {
-    const struct ctables_summary_spec_set *sss = &ss->cell_summaries;
+    const struct ctables_summary_spec_set *sss
+      = (cell->total ? &ss->total_sss : &ss->cell_sss);
     cell->summaries = xmalloc (sss->n * sizeof *cell->summaries);
     for (size_t i = 0; i < sss->n; i++)
       ctables_summary_init (&cell->summaries[i], &sss->summaries[i]);
@@ -2292,7 +2294,8 @@ ctables_cell_insert__ (struct ctables_table *t, const struct ccase *c,
   hmap_insert (&t->cells, &cell->node, hash);
 
 summarize: ;
-  const struct ctables_summary_spec_set *sss = &ss->cell_summaries;
+  const struct ctables_summary_spec_set *sss
+    = (cell->total ? &ss->total_sss : &ss->cell_sss);
   for (size_t i = 0; i < sss->n; i++)
     ctables_summary_add (&cell->summaries[i], &sss->summaries[i], sss->var,
                          case_data (c, sss->var), weight);
@@ -2463,9 +2466,9 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
       for (size_t i = 0; i < t->vaas[t->summary_axis].n; i++)
         {
           struct var_array *va = &t->vaas[t->summary_axis].vas[i];
-          if (!va->cell_summaries.n)
+          if (!va->cell_sss.n)
             {
-              struct ctables_summary_spec_set *css = &va->cell_summaries;
+              struct ctables_summary_spec_set *css = &va->cell_sss;
               css->summaries = xmalloc (sizeof *css->summaries);
               css->n = 1;
 
@@ -2481,8 +2484,10 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
               if (!css->var)
                 css->var = va->vars[0];
 
-              va->total_summaries = va->cell_summaries;
+              va->total_sss = va->cell_sss;
             }
+          else if (!va->total_sss.n)
+            va->total_sss = va->cell_sss;
         }
     }
 
@@ -2623,10 +2628,12 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                         {
                           if (label)
                             parent = pivot_category_create_group__ (parent, label);
-                          for (size_t m = 0; m < va->cell_summaries.n; m++)
+                          const struct ctables_summary_spec_set *sss
+                            = cell->total ? &va->total_sss : &va->cell_sss;
+                          for (size_t m = 0; m < sss->n; m++)
                             {
                               int leaf = pivot_category_create_leaf (
-                                parent, pivot_value_new_text (va->cell_summaries.summaries[m].label));
+                                parent, pivot_value_new_text (sss->summaries[m].label));
                               if (m == 0)
                                 prev_leaf = leaf;
                             }
@@ -2663,7 +2670,8 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
           if (cell->hide)
             continue;
 
-          const struct ctables_summary_spec_set *sss = &t->vaas[t->summary_axis].vas[cell->axes[t->summary_axis].vaa_idx].cell_summaries;
+          const struct var_array *va = &t->vaas[t->summary_axis].vas[cell->axes[t->summary_axis].vaa_idx];
+          const struct ctables_summary_spec_set *sss = cell->total ? &va->total_sss : &va->cell_sss;
           for (size_t j = 0; j < sss->n; j++)
             {
               size_t dindexes[3];