CTABLES fixes for totals.
[pspp] / src / language / stats / ctables.c
index b64ea2b565de6412c5f84d41eee4c6d0839e83fe..d3a00312a93a670f8d0159363f567ca73e0b39d9 100644 (file)
@@ -201,17 +201,11 @@ struct ctables_cell
     struct hmap_node node;
 
     /* The domains that contain this cell. */
-    bool contributes_to_domains;
+    uint32_t omit_domains;
     struct ctables_domain *domains[N_CTDTS];
 
     bool hide;
 
-    /* Is at least one value missing, whether included or excluded? */
-    bool is_missing;
-
-    /* Is at least one value missing and excluded? */
-    bool excluded_missing;
-
     bool postcompute;
     enum ctables_summary_variant sv;
 
@@ -228,6 +222,8 @@ struct ctables_cell
     axes[PIVOT_N_AXES];
 
     union ctables_summary *summaries;
+
+    //char *name;
   };
 
 struct ctables
@@ -2253,10 +2249,10 @@ ctables_summary_uninit (union ctables_summary *s,
 }
 
 static void
-ctables_summary_add (const struct ctables_cell *cell, union ctables_summary *s,
+ctables_summary_add (union ctables_summary *s,
                      const struct ctables_summary_spec *ss,
                      const struct variable *var, const union value *value,
-                     bool is_scale, bool is_missing,
+                     bool is_scale, bool is_missing, bool excluded_missing,
                      double d_weight, double e_weight)
 {
   /* To determine whether a case is included in a given table for a particular
@@ -2280,15 +2276,36 @@ ctables_summary_add (const struct ctables_cell *cell, union ctables_summary *s,
   switch (ss->function)
     {
     case CSTF_TOTALN:
+    case CTSF_ROWPCT_TOTALN:
+    case CTSF_COLPCT_TOTALN:
+    case CTSF_TABLEPCT_TOTALN:
+    case CTSF_SUBTABLEPCT_TOTALN:
+    case CTSF_LAYERPCT_TOTALN:
+    case CTSF_LAYERROWPCT_TOTALN:
+    case CTSF_LAYERCOLPCT_TOTALN:
       s->count += d_weight;
       break;
 
     case CTSF_COUNT:
-      if (is_scale || !cell->excluded_missing)
+    case CTSF_ROWPCT_COUNT:
+    case CTSF_COLPCT_COUNT:
+    case CTSF_TABLEPCT_COUNT:
+    case CTSF_SUBTABLEPCT_COUNT:
+    case CTSF_LAYERPCT_COUNT:
+    case CTSF_LAYERROWPCT_COUNT:
+    case CTSF_LAYERCOLPCT_COUNT:
+      if (is_scale || !excluded_missing)
         s->count += d_weight;
       break;
 
     case CTSF_VALIDN:
+    case CTSF_ROWPCT_VALIDN:
+    case CTSF_COLPCT_VALIDN:
+    case CTSF_TABLEPCT_VALIDN:
+    case CTSF_SUBTABLEPCT_VALIDN:
+    case CTSF_LAYERPCT_VALIDN:
+    case CTSF_LAYERROWPCT_VALIDN:
+    case CTSF_LAYERCOLPCT_VALIDN:
       if (is_scale
           ? !var_is_value_missing (var, value)
           : !is_missing)
@@ -2301,28 +2318,8 @@ ctables_summary_add (const struct ctables_cell *cell, union ctables_summary *s,
       break;
 
     case CTSF_ECOUNT:
-    case CTSF_ROWPCT_COUNT:
-    case CTSF_COLPCT_COUNT:
-    case CTSF_TABLEPCT_COUNT:
-    case CTSF_SUBTABLEPCT_COUNT:
-    case CTSF_LAYERPCT_COUNT:
-    case CTSF_LAYERROWPCT_COUNT:
-    case CTSF_LAYERCOLPCT_COUNT:
-    case CTSF_ROWPCT_VALIDN:
-    case CTSF_COLPCT_VALIDN:
-    case CTSF_TABLEPCT_VALIDN:
-    case CTSF_SUBTABLEPCT_VALIDN:
-    case CTSF_LAYERPCT_VALIDN:
-    case CTSF_LAYERROWPCT_VALIDN:
-    case CTSF_LAYERCOLPCT_VALIDN:
-    case CTSF_ROWPCT_TOTALN:
-    case CTSF_COLPCT_TOTALN:
-    case CTSF_TABLEPCT_TOTALN:
-    case CTSF_SUBTABLEPCT_TOTALN:
-    case CTSF_LAYERPCT_TOTALN:
-    case CTSF_LAYERROWPCT_TOTALN:
-    case CTSF_LAYERCOLPCT_TOTALN:
-      s->count += d_weight;
+      if (is_scale || !excluded_missing)
+        s->count += e_weight;
       break;
 
     case CTSF_EVALIDN:
@@ -2472,8 +2469,8 @@ ctables_summary_value (const struct ctables_cell *cell,
     case CTSF_LAYERCOLPCT_COUNT:
       {
         enum ctables_domain_type d = ctables_function_domain (ss->function);
-        return (cell->domains[d]->e_valid
-                ? s->count / cell->domains[d]->e_valid * 100
+        return (cell->domains[d]->e_count
+                ? s->count / cell->domains[d]->e_count * 100
                 : SYSMIS);
       }
 
@@ -2484,6 +2481,13 @@ ctables_summary_value (const struct ctables_cell *cell,
     case CTSF_LAYERPCT_VALIDN:
     case CTSF_LAYERROWPCT_VALIDN:
     case CTSF_LAYERCOLPCT_VALIDN:
+      {
+        enum ctables_domain_type d = ctables_function_domain (ss->function);
+        return (cell->domains[d]->e_valid
+                ? s->count / cell->domains[d]->e_valid * 100
+                : SYSMIS);
+      }
+
     case CTSF_ROWPCT_TOTALN:
     case CTSF_COLPCT_TOTALN:
     case CTSF_TABLEPCT_TOTALN:
@@ -2491,7 +2495,12 @@ ctables_summary_value (const struct ctables_cell *cell,
     case CTSF_LAYERPCT_TOTALN:
     case CTSF_LAYERROWPCT_TOTALN:
     case CTSF_LAYERCOLPCT_TOTALN:
-      NOT_REACHED ();
+      {
+        enum ctables_domain_type d = ctables_function_domain (ss->function);
+        return (cell->domains[d]->e_total
+                ? s->count / cell->domains[d]->e_total * 100
+                : SYSMIS);
+      }
 
     case CTSF_MISSING:
       return s->count;
@@ -2698,8 +2707,13 @@ ctables_domain_insert (struct ctables_section *s, struct ctables_cell *cell,
       for (size_t i = 0; i < nest->n_domains[domain]; i++)
         {
           size_t v_idx = nest->domains[domain][i];
-          hash = value_hash (&cell->axes[a].cvs[v_idx].value,
-                             var_get_width (nest->vars[v_idx]), hash);
+          struct ctables_cell_value *cv = &cell->axes[a].cvs[v_idx];
+          hash = hash_pointer (cv->category, hash);
+          if (cv->category->type != CCT_TOTAL
+              && cv->category->type != CCT_SUBTOTAL
+              && cv->category->type != CCT_POSTCOMPUTE)
+            hash = value_hash (&cv->value,
+                               var_get_width (nest->vars[v_idx]), hash);
         }
     }
 
@@ -2713,9 +2727,14 @@ ctables_domain_insert (struct ctables_section *s, struct ctables_cell *cell,
           for (size_t i = 0; i < nest->n_domains[domain]; i++)
             {
               size_t v_idx = nest->domains[domain][i];
-              if (!value_equal (&df->axes[a].cvs[v_idx].value,
-                                &cell->axes[a].cvs[v_idx].value,
-                                var_get_width (nest->vars[v_idx])))
+              struct ctables_cell_value *cv1 = &df->axes[a].cvs[v_idx];
+              struct ctables_cell_value *cv2 = &cell->axes[a].cvs[v_idx];
+              if (cv1->category != cv2->category
+                  || (cv1->category->type != CCT_TOTAL
+                      && cv1->category->type != CCT_SUBTOTAL
+                      && cv1->category->type != CCT_POSTCOMPUTE
+                      && !value_equal (&cv1->value, &cv2->value,
+                                       var_get_width (nest->vars[v_idx]))))
                 goto not_equal;
             }
         }
@@ -2846,17 +2865,16 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
 
   cell = xmalloc (sizeof *cell);
   cell->hide = false;
-  cell->is_missing = false;
-  cell->excluded_missing = false;
   cell->sv = sv;
-  cell->contributes_to_domains = true;
+  cell->omit_domains = 0;
   cell->postcompute = false;
+  //struct string name = DS_EMPTY_INITIALIZER;
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
     {
       const struct ctables_nest *nest = s->nests[a];
       cell->axes[a].cvs = (nest->n
-                        ? xnmalloc (nest->n, sizeof *cell->axes[a].cvs)
-                        : NULL);
+                           ? xnmalloc (nest->n, sizeof *cell->axes[a].cvs)
+                           : NULL);
       for (size_t i = 0; i < nest->n; i++)
         {
           const struct ctables_category *cat = cats[a][i];
@@ -2871,19 +2889,59 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
               if (cat->type == CCT_TOTAL
                   || cat->type == CCT_SUBTOTAL
                   || cat->type == CCT_POSTCOMPUTE)
-                cell->contributes_to_domains = false;
-              else if (var_is_value_missing (var, value))
-                cell->is_missing = true;
-              if (cat->type == CCT_EXCLUDED_MISSING)
-                cell->excluded_missing = true;
+                {
+                  /* XXX these should be more encompassing I think.*/
+
+                  switch (a)
+                    {
+                    case PIVOT_AXIS_COLUMN:
+                      cell->omit_domains |= ((1u << CTDT_TABLE) |
+                                             (1u << CTDT_LAYER) |
+                                             (1u << CTDT_LAYERCOL) |
+                                             (1u << CTDT_SUBTABLE) |
+                                             (1u << CTDT_COL));
+                      break;
+                    case PIVOT_AXIS_ROW:
+                      cell->omit_domains |= ((1u << CTDT_TABLE) |
+                                             (1u << CTDT_LAYER) |
+                                             (1u << CTDT_LAYERROW) |
+                                             (1u << CTDT_SUBTABLE) |
+                                             (1u << CTDT_ROW));
+                      break;
+                    case PIVOT_AXIS_LAYER:
+                      cell->omit_domains |= ((1u << CTDT_TABLE) |
+                                             (1u << CTDT_LAYER));
+                      break;
+                    }
+                }
               if (cat->type == CCT_POSTCOMPUTE)
                 cell->postcompute = true;
             }
 
           cell->axes[a].cvs[i].category = cat;
           value_clone (&cell->axes[a].cvs[i].value, value, var_get_width (var));
+
+#if 0
+          if (i != nest->scale_idx)
+            {
+              if (!ds_is_empty (&name))
+                ds_put_cstr (&name, ", ");
+              char *value_s = data_out (value, var_get_encoding (var),
+                                        var_get_print_format (var),
+                                        settings_get_fmt_settings ());
+              if (cat->type == CCT_TOTAL
+                  || cat->type == CCT_SUBTOTAL
+                  || cat->type == CCT_POSTCOMPUTE)
+                ds_put_format (&name, "%s=total", var_get_name (var));
+              else
+                ds_put_format (&name, "%s=%s", var_get_name (var),
+                               value_s + strspn (value_s, " "));
+              free (value_s);
+            }
+#endif
         }
     }
+  //cell->name = ds_steal_cstr (&name);
 
   const struct ctables_nest *ss = s->nests[s->table->summary_axis];
   const struct ctables_summary_spec_set *specs = &ss->specs[cell->sv];
@@ -2899,41 +2957,41 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
 static void
 ctables_cell_add__ (struct ctables_section *s, const struct ccase *c,
                     const struct ctables_category *cats[PIVOT_N_AXES][10],
-                    bool is_missing, double d_weight, double e_weight)
+                    bool is_missing, bool excluded_missing,
+                    double d_weight, double e_weight)
 {
   struct ctables_cell *cell = ctables_cell_insert__ (s, c, cats);
   const struct ctables_nest *ss = s->nests[s->table->summary_axis];
 
   const struct ctables_summary_spec_set *specs = &ss->specs[cell->sv];
   for (size_t i = 0; i < specs->n; i++)
-    ctables_summary_add (cell, &cell->summaries[i], &specs->specs[i],
+    ctables_summary_add (&cell->summaries[i], &specs->specs[i],
                          specs->var, case_data (c, specs->var), specs->is_scale,
-                         is_missing, d_weight, e_weight);
-  if (cell->contributes_to_domains)
-    {
-      for (enum ctables_domain_type dt = 0; dt < N_CTDTS; dt++)
-        {
-          struct ctables_domain *d = cell->domains[dt];
-          d->d_total += d_weight;
-          d->e_total += e_weight;
-          if (!cell->excluded_missing)
-            {
-              d->d_count += d_weight;
-              d->e_count += e_weight;
-            }
-          if (!cell->is_missing)
-            {
-              d->d_valid += d_weight;
-              d->e_valid += e_weight;
-            }
-        }
-    }
+                         is_missing, excluded_missing, d_weight, e_weight);
+  for (enum ctables_domain_type dt = 0; dt < N_CTDTS; dt++)
+    if (!(cell->omit_domains && (1u << dt)))
+      {
+        struct ctables_domain *d = cell->domains[dt];
+        d->d_total += d_weight;
+        d->e_total += e_weight;
+        if (!excluded_missing)
+          {
+            d->d_count += d_weight;
+            d->e_count += e_weight;
+          }
+        if (!is_missing)
+          {
+            d->d_valid += d_weight;
+            d->e_valid += e_weight;
+          }
+      }
 }
 
 static void
 recurse_totals (struct ctables_section *s, const struct ccase *c,
                 const struct ctables_category *cats[PIVOT_N_AXES][10],
-                bool is_missing, double d_weight, double e_weight,
+                bool is_missing, bool excluded_missing,
+                double d_weight, double e_weight,
                 enum pivot_axis_type start_axis, size_t start_nest)
 {
   for (enum pivot_axis_type a = start_axis; a < PIVOT_N_AXES; a++)
@@ -2952,8 +3010,9 @@ recurse_totals (struct ctables_section *s, const struct ccase *c,
             {
               const struct ctables_category *save = cats[a][i];
               cats[a][i] = total;
-              ctables_cell_add__ (s, c, cats, is_missing, d_weight, e_weight);
-              recurse_totals (s, c, cats, is_missing,
+              ctables_cell_add__ (s, c, cats, is_missing, excluded_missing,
+                                  d_weight, e_weight);
+              recurse_totals (s, c, cats, is_missing, excluded_missing,
                               d_weight, e_weight, a, i + 1);
               cats[a][i] = save;
             }
@@ -2965,7 +3024,8 @@ recurse_totals (struct ctables_section *s, const struct ccase *c,
 static void
 recurse_subtotals (struct ctables_section *s, const struct ccase *c,
                    const struct ctables_category *cats[PIVOT_N_AXES][10],
-                   bool is_missing, double d_weight, double e_weight,
+                   bool is_missing, bool excluded_missing,
+                   double d_weight, double e_weight,
                    enum pivot_axis_type start_axis, size_t start_nest)
 {
   for (enum pivot_axis_type a = start_axis; a < PIVOT_N_AXES; a++)
@@ -2980,8 +3040,9 @@ recurse_subtotals (struct ctables_section *s, const struct ccase *c,
           if (save->subtotal)
             {
               cats[a][i] = save->subtotal;
-              ctables_cell_add__ (s, c, cats, is_missing, d_weight, e_weight);
-              recurse_subtotals (s, c, cats, is_missing,
+              ctables_cell_add__ (s, c, cats, is_missing, excluded_missing,
+                                  d_weight, e_weight);
+              recurse_subtotals (s, c, cats, is_missing, excluded_missing,
                                  d_weight, e_weight, a, i + 1);
               cats[a][i] = save;
             }
@@ -3015,8 +3076,15 @@ ctables_cell_insert (struct ctables_section *s,
                      double d_weight, double e_weight)
 {
   const struct ctables_category *cats[PIVOT_N_AXES][10]; /* XXX */
+
+  /* Does at least one categorical variable have a missing value in an included
+     or excluded category? */
   bool is_missing = false;
+
+  /* Does at least one categorical variable have a missing value in an excluded
+     category? */
   bool excluded_missing = false;
+
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
     {
       const struct ctables_nest *nest = s->nests[a];
@@ -3036,7 +3104,7 @@ ctables_cell_insert (struct ctables_section *s,
             s->table->categories[var_get_dict_index (var)], value, var);
           if (!cats[a][i])
             {
-              if (!is_missing)
+              if (!var_missing)
                 return;
 
               static const struct ctables_category cct_excluded_missing = {
@@ -3062,12 +3130,15 @@ ctables_cell_insert (struct ctables_section *s,
             }
       }
 
-  ctables_cell_add__ (s, c, cats, is_missing, d_weight, e_weight);
+  ctables_cell_add__ (s, c, cats, is_missing, excluded_missing,
+                      d_weight, e_weight);
 
-  if (!excluded_missing)
+  //if (!excluded_missing)
     {
-      recurse_totals (s, c, cats, is_missing, d_weight, e_weight, 0, 0);
-      recurse_subtotals (s, c, cats, is_missing, d_weight, e_weight, 0, 0);
+      recurse_totals (s, c, cats, is_missing, excluded_missing,
+                      d_weight, e_weight, 0, 0);
+      recurse_subtotals (s, c, cats, is_missing, excluded_missing,
+                         d_weight, e_weight, 0, 0);
     }
 }
 
@@ -3506,6 +3577,14 @@ ctables_table_output (struct ctables *ct, struct ctables_table *t)
           struct ctables_cell_sort_aux aux = { .nest = nest, .a = a };
           sort (sorted, n_sorted, sizeof *sorted, ctables_cell_compare_3way, &aux);
 
+#if 0
+          for (size_t j = 0; j < n_sorted; j++)
+            {
+              printf ("%s (%s): %f/%f = %.1f%%\n", sorted[j]->name, sorted[j]->contributes_to_domains ? "y" : "n", sorted[j]->summaries[0].count, sorted[j]->domains[CTDT_COL]->e_count, sorted[j]->summaries[0].count / sorted[j]->domains[CTDT_COL]->e_count * 100.0);
+            }
+          printf ("\n");
+#endif
+          
           struct ctables_level
             {
               enum ctables_level_type