subtotal and hsubtotal work
authorBen Pfaff <blp@cs.stanford.edu>
Thu, 6 Jan 2022 04:55:45 +0000 (20:55 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Thu, 13 Jan 2022 05:52:27 +0000 (21:52 -0800)
src/language/stats/ctables.c

index 18fd30b683b3a438dd41036f2a596773a18b53cc..74afa0136da77b7dd10570f38c8be122f247d827 100644 (file)
@@ -177,9 +177,11 @@ struct ctables_cell
        axes (except the scalar variable, if any). */
     struct hmap_node node;
 
-    /* The domains that contains this cell. */
+    /* The domains that contain this cell. */
     struct ctables_domain *domains[N_CTDTS];
 
+    bool hide;
+
     struct
       {
         size_t vaa_idx;
@@ -373,6 +375,8 @@ struct ctables_category
       }
     type;
 
+    struct ctables_category *subtotal;
+
     union
       {
         double number;          /* CCT_NUMBER. */
@@ -1190,14 +1194,18 @@ ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
               return false;
             }
 
-          if ((cat->type == CCT_SUBTOTAL || cat->type == CCT_HSUBTOTAL)
-              && lex_match (lexer, T_EQUALS))
+          if (cat->type == CCT_SUBTOTAL || cat->type == CCT_HSUBTOTAL)
             {
-              if (!lex_force_string (lexer))
-                return false;
+              if (lex_match (lexer, T_EQUALS))
+                {
+                  if (!lex_force_string (lexer))
+                    return false;
 
-              cat->total_label = ss_xstrdup (lex_tokss (lexer));
-              lex_get (lexer);
+                  cat->total_label = ss_xstrdup (lex_tokss (lexer));
+                  lex_get (lexer);
+                }
+              else
+                cat->total_label = xstrdup (_("Subtotal"));
             }
 
           c->n_cats++;
@@ -1362,6 +1370,35 @@ ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
       };
     }
 
+  struct ctables_category *subtotal = NULL;
+  for (size_t i = totals_before ? 0 : c->n_cats;
+       totals_before ? i < c->n_cats : i-- > 0;
+       totals_before ? i++ : 0)
+    {
+      struct ctables_category *cat = &c->cats[i];
+      switch (cat->type)
+        {
+        case CCT_NUMBER:
+        case CCT_STRING:
+        case CCT_RANGE:
+        case CCT_MISSING:
+        case CCT_OTHERNM:
+          cat->subtotal = subtotal;
+          break;
+
+        case CCT_SUBTOTAL:
+        case CCT_HSUBTOTAL:
+          subtotal = cat;
+          break;
+
+        case CCT_TOTAL:
+        case CCT_VALUE:
+        case CCT_LABEL:
+        case CCT_FUNCTION:
+          break;
+        }
+    }
+
   return true;
 }
 
@@ -1785,7 +1822,7 @@ ctables_summary_add (union ctables_summary *s,
 }
 
 static double
-ctables_summary_value (const struct ctables_cell *f,
+ctables_summary_value (const struct ctables_cell *cell,
                        union ctables_summary *s,
                        const struct ctables_summary_spec *ss)
 {
@@ -1796,25 +1833,25 @@ ctables_summary_value (const struct ctables_cell *f,
       return s->valid;
 
     case CTSF_SUBTABLEPCT_COUNT:
-      return f->domains[CTDT_SUBTABLE]->valid ? s->valid / f->domains[CTDT_SUBTABLE]->valid * 100 : SYSMIS;
+      return cell->domains[CTDT_SUBTABLE]->valid ? s->valid / cell->domains[CTDT_SUBTABLE]->valid * 100 : SYSMIS;
 
     case CTSF_ROWPCT_COUNT:
-      return f->domains[CTDT_ROW]->valid ? s->valid / f->domains[CTDT_ROW]->valid * 100 : SYSMIS;
+      return cell->domains[CTDT_ROW]->valid ? s->valid / cell->domains[CTDT_ROW]->valid * 100 : SYSMIS;
 
     case CTSF_COLPCT_COUNT:
-      return f->domains[CTDT_COL]->valid ? s->valid / f->domains[CTDT_COL]->valid * 100 : SYSMIS;
+      return cell->domains[CTDT_COL]->valid ? s->valid / cell->domains[CTDT_COL]->valid * 100 : SYSMIS;
 
     case CTSF_TABLEPCT_COUNT:
-      return f->domains[CTDT_TABLE]->valid ? s->valid / f->domains[CTDT_TABLE]->valid * 100 : SYSMIS;
+      return cell->domains[CTDT_TABLE]->valid ? s->valid / cell->domains[CTDT_TABLE]->valid * 100 : SYSMIS;
 
     case CTSF_LAYERPCT_COUNT:
-      return f->domains[CTDT_LAYER]->valid ? s->valid / f->domains[CTDT_LAYER]->valid * 100 : SYSMIS;
+      return cell->domains[CTDT_LAYER]->valid ? s->valid / cell->domains[CTDT_LAYER]->valid * 100 : SYSMIS;
 
     case CTSF_LAYERROWPCT_COUNT:
-      return f->domains[CTDT_LAYERROW]->valid ? s->valid / f->domains[CTDT_LAYERROW]->valid * 100 : SYSMIS;
+      return cell->domains[CTDT_LAYERROW]->valid ? s->valid / cell->domains[CTDT_LAYERROW]->valid * 100 : SYSMIS;
 
     case CTSF_LAYERCOLPCT_COUNT:
-      return f->domains[CTDT_LAYERCOL]->valid ? s->valid / f->domains[CTDT_LAYERCOL]->valid * 100 : SYSMIS;
+      return cell->domains[CTDT_LAYERCOL]->valid ? s->valid / cell->domains[CTDT_LAYERCOL]->valid * 100 : SYSMIS;
 
     case CTSF_ROWPCT_VALIDN:
     case CTSF_COLPCT_VALIDN:
@@ -2026,19 +2063,19 @@ ctables_cell_compare_3way (const void *a_, const void *b_, const void *aux_)
  */
 
 static struct ctables_domain *
-ctables_domain_insert (struct ctables_table *t, struct ctables_cell *f,
+ctables_domain_insert (struct ctables_table *t, struct ctables_cell *cell,
                        enum ctables_domain_type domain)
 {
   size_t hash = 0;
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
     {
-      size_t idx = f->axes[a].vaa_idx;
+      size_t idx = cell->axes[a].vaa_idx;
       const struct var_array *va = &t->vaas[a].vas[idx];
       hash = hash_int (idx, hash);
       for (size_t i = 0; i < va->n_domains[domain]; i++)
         {
           size_t v_idx = va->domains[domain][i];
-          hash = value_hash (&f->axes[a].cvs[v_idx].value,
+          hash = value_hash (&cell->axes[a].cvs[v_idx].value,
                              var_get_width (va->vars[v_idx]), hash);
         }
     }
@@ -2049,7 +2086,7 @@ ctables_domain_insert (struct ctables_table *t, struct ctables_cell *f,
       const struct ctables_cell *df = d->example;
       for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
         {
-          size_t idx = f->axes[a].vaa_idx;
+          size_t idx = cell->axes[a].vaa_idx;
           if (idx != df->axes[a].vaa_idx)
             goto not_equal;
 
@@ -2058,7 +2095,7 @@ ctables_domain_insert (struct ctables_table *t, struct ctables_cell *f,
             {
               size_t v_idx = va->domains[domain][i];
               if (!value_equal (&df->axes[a].cvs[v_idx].value,
-                                &f->axes[a].cvs[v_idx].value,
+                                &cell->axes[a].cvs[v_idx].value,
                                 var_get_width (va->vars[v_idx])))
                 goto not_equal;
             }
@@ -2069,7 +2106,7 @@ ctables_domain_insert (struct ctables_table *t, struct ctables_cell *f,
     }
 
   d = xmalloc (sizeof *d);
-  *d = (struct ctables_domain) { .example = f };
+  *d = (struct ctables_domain) { .example = cell };
   hmap_insert (&t->domains[domain], &d->node, hash);
   return d;
 }
@@ -2151,26 +2188,30 @@ ctables_cell_insert__ (struct ctables_table *t, const struct ccase *c,
         if (i != va->scale_idx)
           {
             hash = hash_pointer (cats[a][i], hash);
-            if (cats[a][i]->type != CCT_TOTAL)
+            if (cats[a][i]->type != CCT_TOTAL
+                && cats[a][i]->type != CCT_SUBTOTAL
+                && cats[a][i]->type != CCT_HSUBTOTAL)
               hash = value_hash (case_data (c, va->vars[i]),
                                  var_get_width (va->vars[i]), hash);
           }
     }
 
-  struct ctables_cell *f;
-  HMAP_FOR_EACH_WITH_HASH (f, struct ctables_cell, node, hash, &t->cells)
+  struct ctables_cell *cell;
+  HMAP_FOR_EACH_WITH_HASH (cell, struct ctables_cell, node, hash, &t->cells)
     {
       for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
         {
           const struct var_array *va = &t->vaas[a].vas[ix[a]];
-          if (f->axes[a].vaa_idx != ix[a])
+          if (cell->axes[a].vaa_idx != ix[a])
             goto not_equal;
           for (size_t i = 0; i < va->n; i++)
             if (i != va->scale_idx
-                && (cats[a][i] != f->axes[a].cvs[i].category
+                && (cats[a][i] != cell->axes[a].cvs[i].category
                     || (cats[a][i]->type != CCT_TOTAL
+                        && cats[a][i]->type != CCT_SUBTOTAL
+                        && cats[a][i]->type != CCT_HSUBTOTAL
                         && !value_equal (case_data (c, va->vars[i]),
-                                         &f->axes[a].cvs[i].value,
+                                         &cell->axes[a].cvs[i].value,
                                          var_get_width (va->vars[i])))))
                 goto not_equal;
         }
@@ -2180,34 +2221,42 @@ ctables_cell_insert__ (struct ctables_table *t, const struct ccase *c,
     not_equal: ;
     }
 
-  f = xmalloc (sizeof *f);
+  cell = xmalloc (sizeof *cell);
+  cell->hide = false;
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
     {
       const struct var_array *va = &t->vaas[a].vas[ix[a]];
-      f->axes[a].vaa_idx = ix[a];
-      f->axes[a].cvs = (va->n
-                        ? xnmalloc (va->n, sizeof *f->axes[a].cvs)
+      cell->axes[a].vaa_idx = ix[a];
+      cell->axes[a].cvs = (va->n
+                        ? xnmalloc (va->n, sizeof *cell->axes[a].cvs)
                         : NULL);
       for (size_t i = 0; i < va->n; i++)
         {
-          f->axes[a].cvs[i].category = cats[a][i];
-          value_clone (&f->axes[a].cvs[i].value, case_data (c, va->vars[i]),
+          if (i != va->scale_idx)
+            {
+              const struct ctables_category *subtotal = cats[a][i]->subtotal;
+              if (subtotal && subtotal->type == CCT_HSUBTOTAL)
+                cell->hide = true;
+            }
+
+          cell->axes[a].cvs[i].category = cats[a][i];
+          value_clone (&cell->axes[a].cvs[i].value, case_data (c, va->vars[i]),
                        var_get_width (va->vars[i]));
         }
     }
-  f->summaries = xmalloc (ss->n_summaries * sizeof *f->summaries);
+  cell->summaries = xmalloc (ss->n_summaries * sizeof *cell->summaries);
   for (size_t i = 0; i < ss->n_summaries; i++)
-    ctables_summary_init (&f->summaries[i], &ss->summaries[i]);
+    ctables_summary_init (&cell->summaries[i], &ss->summaries[i]);
   for (enum ctables_domain_type dt = 0; dt < N_CTDTS; dt++)
-    f->domains[dt] = ctables_domain_insert (t, f, dt);
-  hmap_insert (&t->cells, &f->node, hash);
+    cell->domains[dt] = ctables_domain_insert (t, cell, dt);
+  hmap_insert (&t->cells, &cell->node, hash);
 
 summarize:
   for (size_t i = 0; i < ss->n_summaries; i++)
-    ctables_summary_add (&f->summaries[i], &ss->summaries[i], ss->summary_var,
+    ctables_summary_add (&cell->summaries[i], &ss->summaries[i], ss->summary_var,
                          case_data (c, ss->summary_var), weight);
   for (enum ctables_domain_type dt = 0; dt < N_CTDTS; dt++)
-    f->domains[dt]->valid += weight;
+    cell->domains[dt]->valid += weight;
 }
 
 static void
@@ -2279,6 +2328,24 @@ ctables_cell_insert (struct ctables_table *t,
   ctables_cell_insert__ (t, c, ix, cats, weight);
 
   recurse_totals (t, c, ix, cats, weight, 0, 0);
+
+  for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+    {
+      const struct var_array *va = &t->vaas[a].vas[ix[a]];
+      for (size_t i = 0; i < va->n; i++)
+        {
+          if (i == va->scale_idx)
+            continue;
+
+          const struct ctables_category *save = cats[a][i];
+          if (save->subtotal)
+            {
+              cats[a][i] = save->subtotal;
+              ctables_cell_insert__ (t, c, ix, cats, weight);
+              cats[a][i] = save;
+            }
+        }
+    }
 }
 
 static bool
@@ -2434,11 +2501,12 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
 
           struct ctables_cell **sorted = xnmalloc (t->cells.count, sizeof *sorted);
 
-          struct ctables_cell *f;
+          struct ctables_cell *cell;
           size_t n = 0;
-          HMAP_FOR_EACH (f, struct ctables_cell, node, &t->cells)
-            sorted[n++] = f;
-          assert (n == t->cells.count);
+          HMAP_FOR_EACH (cell, struct ctables_cell, node, &t->cells)
+            if (!cell->hide)
+              sorted[n++] = cell;
+          assert (n <= t->cells.count);
 
           struct ctables_cell_sort_aux aux = { .t = t, .a = a };
           sort (sorted, n, sizeof *sorted, ctables_cell_compare_3way, &aux);
@@ -2453,22 +2521,22 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
           int prev_leaf = 0;
           for (size_t j = 0; j < n; j++)
             {
-              struct ctables_cell *f = sorted[j];
-              const struct var_array *va = &t->vaas[a].vas[f->axes[a].vaa_idx];
+              struct ctables_cell *cell = sorted[j];
+              const struct var_array *va = &t->vaas[a].vas[cell->axes[a].vaa_idx];
 
               size_t n_common = 0;
               bool new_subtable = false;
               if (j > 0)
                 {
                   struct ctables_cell *prev = sorted[j - 1];
-                  if (prev->axes[a].vaa_idx == f->axes[a].vaa_idx)
+                  if (prev->axes[a].vaa_idx == cell->axes[a].vaa_idx)
                     {
                       for (; n_common < va->n; n_common++)
                         if (n_common != va->scale_idx
                             && (prev->axes[a].cvs[n_common].category
-                                != f->axes[a].cvs[n_common].category
+                                != cell->axes[a].cvs[n_common].category
                                 || !value_equal (&prev->axes[a].cvs[n_common].value,
-                                                 &f->axes[a].cvs[n_common].value,
+                                                 &cell->axes[a].cvs[n_common].value,
                                                  var_get_type (va->vars[n_common]))))
                           break;
                     }
@@ -2488,7 +2556,7 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                 }
               if (n_common == va->n)
                 {
-                  f->axes[a].leaf = prev_leaf;
+                  cell->axes[a].leaf = prev_leaf;
                   continue;
                 }
 
@@ -2498,11 +2566,13 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
 
                   struct pivot_value *label
                     = (k == va->scale_idx ? NULL
-                       : f->axes[a].cvs[k].category->type == CCT_TOTAL
-                       ? pivot_value_new_user_text (f->axes[a].cvs[k].category->total_label,
+                       : (cell->axes[a].cvs[k].category->type == CCT_TOTAL
+                          || cell->axes[a].cvs[k].category->type == CCT_SUBTOTAL
+                          || cell->axes[a].cvs[k].category->type == CCT_HSUBTOTAL)
+                       ? pivot_value_new_user_text (cell->axes[a].cvs[k].category->total_label,
                                                     SIZE_MAX)
                        : pivot_value_new_var_value (va->vars[k],
-                                                    &f->axes[a].cvs[k].value));
+                                                    &cell->axes[a].cvs[k].value));
                   if (k == va->n - 1)
                     {
                       if (a == t->summary_axis)
@@ -2538,15 +2608,18 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                   groups[k] = parent;
                 }
 
-              f->axes[a].leaf = prev_leaf;
+              cell->axes[a].leaf = prev_leaf;
             }
           free (sorted);
           free (groups);
         }
-      struct ctables_cell *f;
-      HMAP_FOR_EACH (f, struct ctables_cell, node, &t->cells)
+      struct ctables_cell *cell;
+      HMAP_FOR_EACH (cell, struct ctables_cell, node, &t->cells)
         {
-          const struct var_array *ss = &t->vaas[t->summary_axis].vas[f->axes[t->summary_axis].vaa_idx];
+          if (cell->hide)
+            continue;
+
+          const struct var_array *ss = &t->vaas[t->summary_axis].vas[cell->axes[t->summary_axis].vaa_idx];
           for (size_t j = 0; j < ss->n_summaries; j++)
             {
               size_t dindexes[3];
@@ -2555,13 +2628,13 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
               for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
                 if (d[a])
                   {
-                    int leaf = f->axes[a].leaf;
+                    int leaf = cell->axes[a].leaf;
                     if (a == t->summary_axis)
                       leaf += j;
                     dindexes[n_dindexes++] = leaf;
                   }
 
-              double d = ctables_summary_value (f, &f->summaries[j], &ss->summaries[j]);
+              double d = ctables_summary_value (cell, &cell->summaries[j], &ss->summaries[j]);
               struct pivot_value *value = pivot_value_new_number (d);
               value->numeric.format = ss->summaries[j].format;
               pivot_table_put (pt, dindexes, n_dindexes, value);
@@ -3233,6 +3306,9 @@ cmd_ctables (struct lexer *lexer, struct dataset *ds)
                                    "SIGTEST", "COMPARETEST");
               goto error;
             }
+
+          if (!lex_match (lexer, T_SLASH))
+            break;
         }
 
       if (t->row_labels != CTLP_NORMAL && t->col_labels != CTLP_NORMAL)