scale variables work more sanely
author Ben Pfaff <blp@cs.stanford.edu>
Sun, 2 Jan 2022 03:04:45 +0000 (19:04 -0800)
committer Ben Pfaff <blp@cs.stanford.edu>
Sat, 2 Apr 2022 01:48:55 +0000 (18:48 -0700)
src/language/stats/ctables.c

index 26837f1b93c61fae51bd946c5f2c6b4b7238e242..ca02ebbae8d43e7b6a79905040bec85a59f05a99 100644 (file)
@@ -226,6 +226,7 @@ struct var_array
   {
     struct variable **vars;
     size_t n;
+    size_t scale_idx;
 
     struct ctables_summary_spec *summaries;
     size_t n_summaries;
@@ -1318,6 +1319,9 @@ nest_fts (struct var_array2 va0, struct var_array2 va1)
           NOT_REACHED ();
         vaa.vas[vaa.n++] = (struct var_array) {
           .vars = vars,
+          .scale_idx = (a->scale_idx != SIZE_MAX ? a->scale_idx
+                        : b->scale_idx != SIZE_MAX ? a->n + b->scale_idx
+                        : SIZE_MAX),
           .n = n,
           .summaries = summary_src->summaries,
           .n_summaries = summary_src->n_summaries,
@@ -1353,17 +1357,16 @@ enumerate_fts (enum pivot_axis_type axis_type, const struct ctables_axis *a)
     {
     case CTAO_VAR:
       assert (!a->var.is_mrset);
+
+      struct variable **vars = xmalloc (sizeof *vars);
+      *vars = a->var.var;
+
       struct var_array *va = xmalloc (sizeof *va);
-      if (a->scale)
-        *va = (struct var_array) { .n = 0 };
-      else
-        {
-          struct variable **vars = xmalloc (sizeof *vars);
-          *vars = a->var.var;
-          enum pivot_axis_type *axes = xmalloc (sizeof *axes);
-          *axes = axis_type;
-          *va = (struct var_array) { .vars = vars, .n = 1 };
-        }
+      *va = (struct var_array) {
+        .vars = vars,
+        .n = 1,
+        .scale_idx = a->scale ? 0 : SIZE_MAX,
+      };
       if (a->n_summaries || a->scale)
         {
           va->summaries = a->summaries;
@@ -1870,13 +1873,14 @@ ctables_freq_compare_3way (const void *a_, const void *b_, const void *aux_)
 
   const struct var_array *va = &aux->t->vaas[aux->a].vas[a_idx];
   for (size_t i = 0; i < va->n; i++)
-    {
-      int cmp = value_compare_3way (&a->axes[aux->a].values[i],
-                                    &b->axes[aux->a].values[i],
-                                    var_get_width (va->vars[i]));
-      if (cmp)
-        return cmp;
-    }
+    if (i != va->scale_idx)
+      {
+        int cmp = value_compare_3way (&a->axes[aux->a].values[i],
+                                      &b->axes[aux->a].values[i],
+                                      var_get_width (va->vars[i]));
+        if (cmp)
+          return cmp;
+      }
   return 0;
 }
 
@@ -1917,8 +1921,9 @@ ctables_freqtab_insert (struct ctables_table *t,
       const struct var_array *va = &t->vaas[a].vas[ix[a]];
       hash = hash_int (ix[a], hash);
       for (size_t i = 0; i < va->n; i++)
-        hash = value_hash (case_data (c, va->vars[i]),
-                           var_get_width (va->vars[i]), hash);
+        if (i != va->scale_idx)
+          hash = value_hash (case_data (c, va->vars[i]),
+                             var_get_width (va->vars[i]), hash);
     }
 
   struct ctables_freq *f;
@@ -1930,10 +1935,11 @@ ctables_freqtab_insert (struct ctables_table *t,
           if (f->axes[a].vaa_idx != ix[a])
             goto not_equal;
           for (size_t i = 0; i < va->n; i++)
-            if (!value_equal (case_data (c, va->vars[i]),
-                              &f->axes[a].values[i],
-                              var_get_width (va->vars[i])))
-              goto not_equal;
+            if (i != va->scale_idx
+                && !value_equal (case_data (c, va->vars[i]),
+                                 &f->axes[a].values[i],
+                                 var_get_width (va->vars[i])))
+                goto not_equal;
         }
 
       goto summarize;
@@ -2007,11 +2013,13 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                                                               dataset_dict (ds),
                                                               NULL, NULL);
   bool warn_on_invalid = true;
+  double total_weight = 0;
   for (struct ccase *c = casereader_read (input); c;
        case_unref (c), c = casereader_read (input))
     {
       double weight = dict_get_case_weight (dataset_dict (ds), c,
                                             &warn_on_invalid);
+      total_weight += weight;
 
       for (size_t i = 0; i < ct->n_tables; i++)
         {
@@ -2079,9 +2087,10 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                   if (prev->axes[a].vaa_idx == f->axes[a].vaa_idx)
                     {
                       for (; n_common < va->n; n_common++)
-                        if (!value_equal (&prev->axes[a].values[n_common],
-                                          &f->axes[a].values[n_common],
-                                          var_get_type (va->vars[n_common])))
+                        if (n_common != va->scale_idx
+                            && !value_equal (&prev->axes[a].values[n_common],
+                                             &f->axes[a].values[n_common],
+                                             var_get_type (va->vars[n_common])))
                           break;
                     }
                   else
@@ -2108,14 +2117,17 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                 {
                   struct pivot_category *parent = k > 0 ? groups[k - 1] : top;
 
-                  struct pivot_value *label = pivot_value_new_var_value (
-                    va->vars[k], &f->axes[a].values[k]);
-
+                  struct pivot_value *label
+                    = (k != va->scale_idx
+                       ? pivot_value_new_var_value (va->vars[k],
+                                                    &f->axes[a].values[k])
+                       : NULL);
                   if (k == va->n - 1)
                     {
                       if (a == t->summary_axis)
                         {
-                          parent = pivot_category_create_group__ (parent, label);
+                          if (label)
+                            parent = pivot_category_create_group__ (parent, label);
                           for (size_t m = 0; m < va->n_summaries; m++)
                             {
                               int leaf = pivot_category_create_leaf (
@@ -2125,11 +2137,18 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                             }
                         }
                       else
-                        prev_leaf = pivot_category_create_leaf (parent, label);
+                        {
+                          /* This assertion is true as long as the summary axis
+                             is the axis where the summaries are displayed. */
+                          assert (label);
+
+                          prev_leaf = pivot_category_create_leaf (parent, label);
+                        }
                       break;
                     }
 
-                  parent = pivot_category_create_group__ (parent, label);
+                  if (label)
+                    parent = pivot_category_create_group__ (parent, label);
 
                   enum ctables_vlabel vlabel = ct->vlabels[var_get_dict_index (va->vars[k + 1])];
                   if (vlabel != CTVL_NONE)
@@ -2161,9 +2180,10 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                     dindexes[n_dindexes++] = leaf;
                   }
 
-              double value = ctables_summary_value (&f->summaries[j], &ss->summaries[j]);
-              pivot_table_put (pt, dindexes, n_dindexes,
-                               pivot_value_new_number (value));
+              double d = ctables_summary_value (&f->summaries[j], &ss->summaries[j]);
+              struct pivot_value *value = pivot_value_new_number (d);
+              value->numeric.format = ss->summaries[j].format;
+              pivot_table_put (pt, dindexes, n_dindexes, value);
             }
         }