postcomputes show up in results (all zeros so far)
[pspp] / src / language / stats / ctables.c
index fdcf748391af4b1f9194d06ad59fd753e572675a..00275303d5952f93dcda6923cf4e0cd98281c802 100644 (file)
@@ -461,7 +461,6 @@ struct ctables_category
 
         /* Totals and subtotals. */
         CCT_SUBTOTAL,
-        CCT_HSUBTOTAL,
         CCT_TOTAL,
 
         /* Implicit category lists. */
@@ -473,12 +472,20 @@ struct ctables_category
 
     struct ctables_category *subtotal;
 
+    bool hide;
+
     union
       {
         double number;          /* CCT_NUMBER. */
         char *string;           /* CCT_STRING. */
         double range[2];        /* CCT_RANGE. */
-        char *total_label;      /* CCT_SUBTOTAL, CCT_HSUBTOTAL, CCT_TOTAL. */
+
+        struct
+          {
+            char *total_label;      /* CCT_SUBTOTAL, CCT_TOTAL. */
+            bool hide_subcategories; /* CCT_SUBTOTAL. */
+          };
+
         const struct ctables_postcompute *pc; /* CCT_POSTCOMPUTE. */
 
         /* CCT_VALUE, CCT_LABEL, CCT_FUNCTION. */
@@ -519,7 +526,6 @@ ctables_category_uninit (struct ctables_category *cat)
       break;
 
     case CCT_SUBTOTAL:
-    case CCT_HSUBTOTAL:
     case CCT_TOTAL:
       free (cat->total_label);
       break;
@@ -557,7 +563,6 @@ ctables_category_equal (const struct ctables_category *a,
       return a->pc == b->pc;
 
     case CCT_SUBTOTAL:
-    case CCT_HSUBTOTAL:
     case CCT_TOTAL:
       return !strcmp (a->total_label, b->total_label);
 
@@ -1313,8 +1318,7 @@ cct_range (double low, double high)
 }
 
 static bool
-ctables_table_parse_subtotal (struct lexer *lexer,
-                              enum ctables_category_type cct,
+ctables_table_parse_subtotal (struct lexer *lexer, bool hide_subcategories,
                               struct ctables_category *cat)
 {
   char *total_label;
@@ -1329,7 +1333,11 @@ ctables_table_parse_subtotal (struct lexer *lexer,
   else
     total_label = xstrdup (_("Subtotal"));
 
-  *cat = (struct ctables_category) { .type = cct, .total_label = total_label };
+  *cat = (struct ctables_category) {
+    .type = CCT_SUBTOTAL,
+    .hide_subcategories = hide_subcategories,
+    .total_label = total_label
+  };
   return true;
 }
 
@@ -1342,9 +1350,9 @@ ctables_table_parse_explicit_category (struct lexer *lexer, struct ctables *ct,
   else if (lex_match_id (lexer, "MISSING"))
     *cat = (struct ctables_category) { .type = CCT_MISSING };
   else if (lex_match_id (lexer, "SUBTOTAL"))
-    return ctables_table_parse_subtotal (lexer, CCT_SUBTOTAL, cat);
+    return ctables_table_parse_subtotal (lexer, false, cat);
   else if (lex_match_id (lexer, "HSUBTOTAL"))
-    return ctables_table_parse_subtotal (lexer, CCT_HSUBTOTAL, cat);
+    return ctables_table_parse_subtotal (lexer, true, cat);
   else if (lex_match_id (lexer, "LO"))
     {
       if (!lex_force_match_id (lexer, "THRU") || lex_force_num (lexer))
@@ -1409,15 +1417,15 @@ ctables_table_parse_explicit_category (struct lexer *lexer, struct ctables *ct,
   return true;
 }
 
-static const struct ctables_category *
+static struct ctables_category *
 ctables_find_category_for_postcompute (const struct ctables_categories *cats,
                                        const struct ctables_pcexpr *e)
 {
-  const struct ctables_category *best = NULL;
+  struct ctables_category *best = NULL;
   size_t n_subtotals = 0;
   for (size_t i = 0; i < cats->n_cats; i++)
     {
-      const struct ctables_category *cat = &cats->cats[i];
+      struct ctables_category *cat = &cats->cats[i];
       switch (e->op)
         {
         case CTPO_CAT_NUMBER:
@@ -1448,7 +1456,7 @@ ctables_find_category_for_postcompute (const struct ctables_categories *cats,
           break;
 
         case CTPO_CAT_SUBTOTAL:
-          if (cat->type == CCT_SUBTOTAL || cat->type == CCT_HSUBTOTAL)
+          if (cat->type == CCT_SUBTOTAL)
             {
               n_subtotals++;
               if (e->subtotal_index == n_subtotals)
@@ -1480,7 +1488,7 @@ ctables_find_category_for_postcompute (const struct ctables_categories *cats,
 
 static bool
 ctables_recursive_check_postcompute (const struct ctables_pcexpr *e,
-                                     const struct ctables_category *cat,
+                                     struct ctables_category *pc_cat,
                                      const struct ctables_categories *cats,
                                      const struct msg_location *cats_location)
 {
@@ -1493,44 +1501,49 @@ ctables_recursive_check_postcompute (const struct ctables_pcexpr *e,
     case CTPO_CAT_OTHERNM:
     case CTPO_CAT_SUBTOTAL:
     case CTPO_CAT_TOTAL:
-      if (!ctables_find_category_for_postcompute (cats, e))
-        {
-          if (e->op == CTPO_CAT_SUBTOTAL && e->subtotal_index == 0)
-            {
-              size_t n_subtotals = 0;
-              for (size_t i = 0; i < cats->n_cats; i++)
-                n_subtotals += (cats->cats[i].type == CCT_SUBTOTAL
-                                || cats->cats[i].type == CCT_HSUBTOTAL);
-              if (n_subtotals > 1)
-                {
-                  msg_at (SE, cats_location,
-                          ngettext ("These categories include %zu instance of "
-                                    "SUBTOTAL or HSUBTOTAL, so references from "
-                                    "computed categories must refer to "
-                                    "subtotals by position.",
-                                    "These categories include %zu instances of "
-                                    "SUBTOTAL or HSUBTOTAL, so references from "
-                                    "computed categories must refer to "
-                                    "subtotals by position.",
-                                    n_subtotals),
-                          n_subtotals);
-                  msg_at (SN, e->location,
-                          _("This is the reference that lacks a position."));
-                  return NULL;
-                }
-            }
+      {
+        struct ctables_category *cat = ctables_find_category_for_postcompute (
+          cats, e);
+        if (!cat)
+          {
+            if (e->op == CTPO_CAT_SUBTOTAL && e->subtotal_index == 0)
+              {
+                size_t n_subtotals = 0;
+                for (size_t i = 0; i < cats->n_cats; i++)
+                  n_subtotals += cats->cats[i].type == CCT_SUBTOTAL;
+                if (n_subtotals > 1)
+                  {
+                    msg_at (SE, cats_location,
+                            ngettext ("These categories include %zu instance "
+                                      "of SUBTOTAL or HSUBTOTAL, so references "
+                                      "from computed categories must refer to "
+                                      "subtotals by position.",
+                                      "These categories include %zu instances "
+                                      "of SUBTOTAL or HSUBTOTAL, so references "
+                                      "from computed categories must refer to "
+                                      "subtotals by position.",
+                                      n_subtotals),
+                            n_subtotals);
+                    msg_at (SN, e->location,
+                            _("This is the reference that lacks a position."));
+                    return NULL;
+                  }
+              }
 
-          msg_at (SE, cat->location,
-                  _("Computed category &%s references a category not included "
-                    "in the category list."),
-                  cat->pc->name);
-          msg_at (SN, e->location, _("This is the missing category."));
-          msg_at (SN, cats_location,
-                  _("To fix the problem, add the missing category to the "
-                    "list of categories here."));
-          return false;
-        }
-      return true;
+            msg_at (SE, pc_cat->location,
+                    _("Computed category &%s references a category not included "
+                      "in the category list."),
+                    cat->pc->name);
+            msg_at (SN, e->location, _("This is the missing category."));
+            msg_at (SN, cats_location,
+                    _("To fix the problem, add the missing category to the "
+                      "list of categories here."));
+            return false;
+          }
+        if (pc_cat->pc->hide_source_cats)
+          cat->hide = true;
+        return true;
+      }
 
     case CTPO_CONSTANT:
       return true;
@@ -1543,7 +1556,7 @@ ctables_recursive_check_postcompute (const struct ctables_pcexpr *e,
     case CTPO_NEG:
       for (size_t i = 0; i < 2; i++)
         if (e->subs[i] && !ctables_recursive_check_postcompute (
-              e->subs[i], cat, cats, cats_location))
+              e->subs[i], pc_cat, cats, cats_location))
           return false;
       return true;
 
@@ -1600,7 +1613,7 @@ ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
         = lex_ofs_location (lexer, cats_start_ofs, lex_ofs (lexer) - 1);
       for (size_t i = 0; i < c->n_cats; i++)
         {
-          const struct ctables_category *cat = &c->cats[i];
+          struct ctables_category *cat = &c->cats[i];
           if (cat->type == CCT_POSTCOMPUTE
               && !ctables_recursive_check_postcompute (cat->pc->expr, cat,
                                                        c, cats_location))
@@ -1783,7 +1796,6 @@ ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
           break;
 
         case CCT_SUBTOTAL:
-        case CCT_HSUBTOTAL:
           subtotal = cat;
           break;
 
@@ -2545,7 +2557,6 @@ ctables_cell_compare_3way (const void *a_, const void *b_, const void *aux_)
           case CCT_NUMBER:
           case CCT_STRING:
           case CCT_SUBTOTAL:
-          case CCT_HSUBTOTAL:
           case CCT_TOTAL:
           case CCT_POSTCOMPUTE:
             /* Must be equal. */
@@ -2688,7 +2699,6 @@ ctables_categories_match (const struct ctables_categories *c,
           break;
 
         case CCT_SUBTOTAL:
-        case CCT_HSUBTOTAL:
         case CCT_TOTAL:
           break;
 
@@ -2730,7 +2740,7 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
             hash = hash_pointer (cats[a][i], hash);
             if (cats[a][i]->type != CCT_TOTAL
                 && cats[a][i]->type != CCT_SUBTOTAL
-                && cats[a][i]->type != CCT_HSUBTOTAL)
+                && cats[a][i]->type != CCT_POSTCOMPUTE)
               hash = value_hash (case_data (c, nest->vars[i]),
                                  var_get_width (nest->vars[i]), hash);
             else
@@ -2749,7 +2759,7 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
                 && (cats[a][i] != cell->axes[a].cvs[i].category
                     || (cats[a][i]->type != CCT_TOTAL
                         && cats[a][i]->type != CCT_SUBTOTAL
-                        && cats[a][i]->type != CCT_HSUBTOTAL
+                        && cats[a][i]->type != CCT_POSTCOMPUTE
                         && !value_equal (case_data (c, nest->vars[i]),
                                          &cell->axes[a].cvs[i].value,
                                          var_get_width (nest->vars[i])))))
@@ -2774,14 +2784,15 @@ ctables_cell_insert__ (struct ctables_section *s, const struct ccase *c,
       for (size_t i = 0; i < nest->n; i++)
         {
           const struct ctables_category *cat = cats[a][i];
-
           if (i != nest->scale_idx)
             {
               const struct ctables_category *subtotal = cat->subtotal;
-              if (subtotal && subtotal->type == CCT_HSUBTOTAL)
+              if (cat->hide || (subtotal && subtotal->hide_subcategories))
                 cell->hide = true;
 
-              if (cat->type == CCT_TOTAL || cat->type == CCT_SUBTOTAL || cat->type == CCT_HSUBTOTAL)
+              if (cat->type == CCT_TOTAL
+                  || cat->type == CCT_SUBTOTAL
+                  || cat->type == CCT_POSTCOMPUTE)
                 cell->contributes_to_domains = false;
             }
 
@@ -2968,8 +2979,10 @@ ctables_category_create_label (const struct ctables_category *cat,
                                const struct variable *var,
                                const union value *value)
 {
-  return (cat->type == CCT_TOTAL || cat->type == CCT_SUBTOTAL || cat->type == CCT_HSUBTOTAL
+  return (cat->type == CCT_TOTAL || cat->type == CCT_SUBTOTAL
           ? pivot_value_new_user_text (cat->total_label, SIZE_MAX)
+          : cat->type == CCT_POSTCOMPUTE && cat->pc->label
+          ? pivot_value_new_user_text (cat->pc->label, SIZE_MAX)
           : pivot_value_new_var_value (var, value));
 }
 
@@ -3213,8 +3226,8 @@ ctables_table_output (struct ctables *ct, struct ctables_table *t)
                           if (prev->axes[a].cvs[var_idx].category != c)
                             break;
                           else if (c->type != CCT_SUBTOTAL
-                                   && c->type != CCT_HSUBTOTAL
                                    && c->type != CCT_TOTAL
+                                   && c->type != CCT_POSTCOMPUTE
                                    && !value_equal (&prev->axes[a].cvs[var_idx].value,
                                                     &cell->axes[a].cvs[var_idx].value,
                                                     var_get_type (nest->vars[var_idx])))
@@ -3674,7 +3687,6 @@ ctables_add_category_occurrences (const struct variable *var,
           break;
 
         case CCT_SUBTOTAL:
-        case CCT_HSUBTOTAL:
         case CCT_TOTAL:
           break;
 
@@ -3703,6 +3715,8 @@ ctables_section_recurse_add_empty_categories (
   else
     {
       const struct variable *var = s->nests[a]->vars[a_idx];
+      const struct ctables_categories *categories = s->table->categories[
+        var_get_dict_index (var)];
       int width = var_get_width (var);
       const struct hmap *occurrences = &s->occurrences[a][a_idx];
       const struct ctables_section_value *sv;
@@ -3711,11 +3725,21 @@ ctables_section_recurse_add_empty_categories (
           union value *value = case_data_rw (c, var);
           value_destroy (value, width);
           value_clone (value, &sv->value, width);
-          cats[a][a_idx] = ctables_categories_match (
-            s->table->categories[var_get_dict_index (var)], value, var);
+          cats[a][a_idx] = ctables_categories_match (categories, value, var);
           assert (cats[a][a_idx] != NULL);
           ctables_section_recurse_add_empty_categories (s, cats, c, a, a_idx + 1);
         }
+
+      for (size_t i = 0; i < categories->n_cats; i++)
+        {
+          const struct ctables_category *cat = &categories->cats[i];
+          if (cat->type == CCT_POSTCOMPUTE)
+            {
+              printf ("%s:%d\n", __FILE__, __LINE__);
+              cats[a][a_idx] = cat;
+              ctables_section_recurse_add_empty_categories (s, cats, c, a, a_idx + 1);
+            }
+        }
     }
 }
 
@@ -4141,7 +4165,7 @@ ctables_find_postcompute (struct ctables *ct, const char *name)
 static bool
 ctables_parse_pcompute (struct lexer *lexer, struct ctables *ct)
 {
-  int start_ofs = lex_ofs (lexer) - 1;
+  int pcompute_start = lex_ofs (lexer) - 1;
 
   if (!lex_force_match (lexer, T_AND) || !lex_force_id (lexer))
     return false;
@@ -4157,15 +4181,18 @@ ctables_parse_pcompute (struct lexer *lexer, struct ctables *ct)
       return false;
     }
 
+  int expr_start = lex_ofs (lexer);
   struct ctables_pcexpr *expr = parse_add (lexer);
+  int expr_end = lex_ofs (lexer) - 1;
   if (!expr || !lex_force_match (lexer, T_RPAREN))
     {
       free (name);
       return false;
     }
+  int pcompute_end = lex_ofs (lexer) - 1;
 
-  struct msg_location *location = lex_ofs_location (lexer, start_ofs,
-                                                    lex_ofs (lexer) - 1);
+  struct msg_location *location = lex_ofs_location (lexer, pcompute_start,
+                                                    pcompute_end);
 
   struct ctables_postcompute *pc = ctables_find_postcompute (ct, name);
   if (pc)
@@ -4188,6 +4215,8 @@ ctables_parse_pcompute (struct lexer *lexer, struct ctables *ct)
     }
   pc->expr = expr;
   pc->location = location;
+  if (!pc->label)
+    pc->label = lex_ofs_representation (lexer, expr_start, expr_end);
   return true;
 }