Match postcompute expressions against categories.
authorBen Pfaff <blp@cs.stanford.edu>
Sat, 12 Feb 2022 23:30:32 +0000 (15:30 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 12 Feb 2022 23:30:32 +0000 (15:30 -0800)
src/language/stats/ctables.c
src/libpspp/message.c
src/libpspp/message.h

index 7f324349bfd233d46ce5f65826baadfd948b1e13..fdcf748391af4b1f9194d06ad59fd753e572675a 100644 (file)
@@ -309,7 +309,6 @@ struct ctables_pcexpr
       };
 
     /* Source location. */
-    int ofs[2];
     struct msg_location *location;
   };
 
@@ -494,6 +493,10 @@ struct ctables_category
             double percentile;
           };
       };
+
+    /* Source location.  This is null for CCT_TOTAL, CCT_VALUE, CCT_LABEL,
+       CCT_FUNCTION. */
+    struct msg_location *location;
   };
 
 static void
@@ -1309,6 +1312,246 @@ cct_range (double low, double high)
   };
 }
 
+static bool
+ctables_table_parse_subtotal (struct lexer *lexer,
+                              enum ctables_category_type cct,
+                              struct ctables_category *cat)
+{
+  char *total_label;
+  if (lex_match (lexer, T_EQUALS))
+    {
+      if (!lex_force_string (lexer))
+        return false;
+
+      total_label = ss_xstrdup (lex_tokss (lexer));
+      lex_get (lexer);
+    }
+  else
+    total_label = xstrdup (_("Subtotal"));
+
+  *cat = (struct ctables_category) { .type = cct, .total_label = total_label };
+  return true;
+}
+
+static bool
+ctables_table_parse_explicit_category (struct lexer *lexer, struct ctables *ct,
+                                       struct ctables_category *cat)
+{
+  if (lex_match_id (lexer, "OTHERNM"))
+    *cat = (struct ctables_category) { .type = CCT_OTHERNM };
+  else if (lex_match_id (lexer, "MISSING"))
+    *cat = (struct ctables_category) { .type = CCT_MISSING };
+  else if (lex_match_id (lexer, "SUBTOTAL"))
+    return ctables_table_parse_subtotal (lexer, CCT_SUBTOTAL, cat);
+  else if (lex_match_id (lexer, "HSUBTOTAL"))
+    return ctables_table_parse_subtotal (lexer, CCT_HSUBTOTAL, cat);
+  else if (lex_match_id (lexer, "LO"))
+    {
+      if (!lex_force_match_id (lexer, "THRU") || lex_force_num (lexer))
+        return false;
+      *cat = cct_range (-DBL_MAX, lex_number (lexer));
+      lex_get (lexer);
+    }
+  else if (lex_is_number (lexer))
+    {
+      double number = lex_number (lexer);
+      lex_get (lexer);
+      if (lex_match_id (lexer, "THRU"))
+        {
+          if (lex_match_id (lexer, "HI"))
+            *cat = cct_range (number, DBL_MAX);
+          else
+            {
+              if (!lex_force_num (lexer))
+                return false;
+              *cat = cct_range (number, lex_number (lexer));
+              lex_get (lexer);
+            }
+        }
+      else
+        *cat = (struct ctables_category) {
+          .type = CCT_NUMBER,
+          .number = number
+        };
+    }
+  else if (lex_is_string (lexer))
+    {
+      *cat = (struct ctables_category) {
+        .type = CCT_STRING,
+        .string = ss_xstrdup (lex_tokss (lexer)),
+      };
+      lex_get (lexer);
+    }
+  else if (lex_match (lexer, T_AND))
+    {
+      if (!lex_force_id (lexer))
+        return false;
+      struct ctables_postcompute *pc = ctables_find_postcompute (
+        ct, lex_tokcstr (lexer));
+      if (!pc)
+        {
+          struct msg_location *loc = lex_get_location (lexer, -1, 0);
+          msg_at (SE, loc, _("Unknown postcompute &%s."),
+                  lex_tokcstr (lexer));
+          msg_location_destroy (loc);
+          return false;
+        }
+      lex_get (lexer);
+
+      *cat = (struct ctables_category) { .type = CCT_POSTCOMPUTE, .pc = pc };
+    }
+  else
+    {
+      lex_error (lexer, NULL);
+      return false;
+    }
+
+  return true;
+}
+
+static const struct ctables_category *
+ctables_find_category_for_postcompute (const struct ctables_categories *cats,
+                                       const struct ctables_pcexpr *e)
+{
+  const struct ctables_category *best = NULL;
+  size_t n_subtotals = 0;
+  for (size_t i = 0; i < cats->n_cats; i++)
+    {
+      const struct ctables_category *cat = &cats->cats[i];
+      switch (e->op)
+        {
+        case CTPO_CAT_NUMBER:
+          if (cat->type == CCT_NUMBER && cat->number == e->number)
+            best = cat;
+          break;
+
+        case CTPO_CAT_STRING:
+          if (cat->type == CCT_STRING && !strcmp (cat->string, e->string))
+            best = cat;
+          break;
+
+        case CTPO_CAT_RANGE:
+          if (cat->type == CCT_RANGE
+              && cat->range[0] == e->range[0]
+              && cat->range[1] == e->range[1])
+            best = cat;
+          break;
+
+        case CTPO_CAT_MISSING:
+          if (cat->type == CCT_MISSING)
+            best = cat;
+          break;
+
+        case CTPO_CAT_OTHERNM:
+          if (cat->type == CCT_OTHERNM)
+            best = cat;
+          break;
+
+        case CTPO_CAT_SUBTOTAL:
+          if (cat->type == CCT_SUBTOTAL || cat->type == CCT_HSUBTOTAL)
+            {
+              n_subtotals++;
+              if (e->subtotal_index == n_subtotals)
+                return cat;
+              else if (e->subtotal_index == 0)
+                best = cat;
+            }
+          break;
+
+        case CTPO_CAT_TOTAL:
+          if (cat->type == CCT_TOTAL)
+            return cat;
+          break;
+
+        case CTPO_CONSTANT:
+        case CTPO_ADD:
+        case CTPO_SUB:
+        case CTPO_MUL:
+        case CTPO_DIV:
+        case CTPO_POW:
+        case CTPO_NEG:
+          NOT_REACHED ();
+        }
+    }
+  if (e->op == CTPO_CAT_SUBTOTAL && e->subtotal_index == 0 && n_subtotals > 1)
+    return NULL;
+  return best;
+}
+
+static bool
+ctables_recursive_check_postcompute (const struct ctables_pcexpr *e,
+                                     const struct ctables_category *cat,
+                                     const struct ctables_categories *cats,
+                                     const struct msg_location *cats_location)
+{
+  switch (e->op)
+    {
+    case CTPO_CAT_NUMBER:
+    case CTPO_CAT_STRING:
+    case CTPO_CAT_RANGE:
+    case CTPO_CAT_MISSING:
+    case CTPO_CAT_OTHERNM:
+    case CTPO_CAT_SUBTOTAL:
+    case CTPO_CAT_TOTAL:
+      if (!ctables_find_category_for_postcompute (cats, e))
+        {
+          if (e->op == CTPO_CAT_SUBTOTAL && e->subtotal_index == 0)
+            {
+              size_t n_subtotals = 0;
+              for (size_t i = 0; i < cats->n_cats; i++)
+                n_subtotals += (cats->cats[i].type == CCT_SUBTOTAL
+                                || cats->cats[i].type == CCT_HSUBTOTAL);
+              if (n_subtotals > 1)
+                {
+                  msg_at (SE, cats_location,
+                          ngettext ("These categories include %zu instance of "
+                                    "SUBTOTAL or HSUBTOTAL, so references from "
+                                    "computed categories must refer to "
+                                    "subtotals by position.",
+                                    "These categories include %zu instances of "
+                                    "SUBTOTAL or HSUBTOTAL, so references from "
+                                    "computed categories must refer to "
+                                    "subtotals by position.",
+                                    n_subtotals),
+                          n_subtotals);
+                  msg_at (SN, e->location,
+                          _("This is the reference that lacks a position."));
+                  return NULL;
+                }
+            }
+
+          msg_at (SE, cat->location,
+                  _("Computed category &%s references a category not included "
+                    "in the category list."),
+                  cat->pc->name);
+          msg_at (SN, e->location, _("This is the missing category."));
+          msg_at (SN, cats_location,
+                  _("To fix the problem, add the missing category to the "
+                    "list of categories here."));
+          return false;
+        }
+      return true;
+
+    case CTPO_CONSTANT:
+      return true;
+
+    case CTPO_ADD:
+    case CTPO_SUB:
+    case CTPO_MUL:
+    case CTPO_DIV:
+    case CTPO_POW:
+    case CTPO_NEG:
+      for (size_t i = 0; i < 2; i++)
+        if (e->subs[i] && !ctables_recursive_check_postcompute (
+              e->subs[i], cat, cats, cats_location))
+          return false;
+      return true;
+
+    default:
+      NOT_REACHED ();
+    }
+}
+
 static bool
 ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
                                 struct ctables *ct, struct ctables_table *t)
@@ -1336,104 +1579,33 @@ ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
   size_t allocated_cats = 0;
   if (lex_match (lexer, T_LBRACK))
     {
+      int cats_start_ofs = lex_ofs (lexer);
       do
         {
           if (c->n_cats >= allocated_cats)
             c->cats = x2nrealloc (c->cats, &allocated_cats, sizeof *c->cats);
 
+          int start_ofs = lex_ofs (lexer);
           struct ctables_category *cat = &c->cats[c->n_cats];
-          if (lex_match_id (lexer, "OTHERNM"))
-            cat->type = CCT_OTHERNM;
-          else if (lex_match_id (lexer, "MISSING"))
-            cat->type = CCT_MISSING;
-          else if (lex_match_id (lexer, "SUBTOTAL"))
-            *cat = (struct ctables_category)
-              { .type = CCT_SUBTOTAL, .total_label = NULL };
-          else if (lex_match_id (lexer, "HSUBTOTAL"))
-            *cat = (struct ctables_category)
-              { .type = CCT_HSUBTOTAL, .total_label = NULL };
-          else if (lex_match_id (lexer, "LO"))
-            {
-              if (!lex_force_match_id (lexer, "THRU") || lex_force_num (lexer))
-                return false;
-              *cat = cct_range (-DBL_MAX, lex_number (lexer));
-              lex_get (lexer);
-            }
-          else if (lex_is_number (lexer))
-            {
-              double number = lex_number (lexer);
-              lex_get (lexer);
-              if (lex_match_id (lexer, "THRU"))
-                {
-                  if (lex_match_id (lexer, "HI"))
-                    *cat = cct_range (number, DBL_MAX);
-                  else
-                    {
-                      if (!lex_force_num (lexer))
-                        return false;
-                      *cat = cct_range (number, lex_number (lexer));
-                      lex_get (lexer);
-                    }
-                }
-              else
-                *cat = (struct ctables_category) {
-                  .type = CCT_NUMBER,
-                  .number = number
-                };
-            }
-          else if (lex_is_string (lexer))
-            {
-              *cat = (struct ctables_category) {
-                .type = CCT_STRING,
-                .string = ss_xstrdup (lex_tokss (lexer)),
-              };
-              lex_get (lexer);
-            }
-          else if (lex_match (lexer, T_AND))
-            {
-              if (!lex_force_id (lexer))
-                return false;
-              struct ctables_postcompute *pc = ctables_find_postcompute (
-                ct, lex_tokcstr (lexer));
-              if (!pc)
-                {
-                  struct msg_location *loc = lex_get_location (lexer, -1, 0);
-                  msg_at (SE, loc, _("Unknown postcompute &%s."),
-                          lex_tokcstr (lexer));
-                  msg_location_destroy (loc);
-                  return false;
-                }
-              lex_get (lexer);
-
-              *cat = (struct ctables_category) {
-                .type = CCT_POSTCOMPUTE,
-                .pc = pc,
-              };
-            }
-          else
-            {
-              lex_error (lexer, NULL);
-              return false;
-            }
-
-          if (cat->type == CCT_SUBTOTAL || cat->type == CCT_HSUBTOTAL)
-            {
-              if (lex_match (lexer, T_EQUALS))
-                {
-                  if (!lex_force_string (lexer))
-                    return false;
-
-                  cat->total_label = ss_xstrdup (lex_tokss (lexer));
-                  lex_get (lexer);
-                }
-              else
-                cat->total_label = xstrdup (_("Subtotal"));
-            }
-
+          if (!ctables_table_parse_explicit_category (lexer, ct, cat))
+            return false;
+          cat->location = lex_ofs_location (lexer, start_ofs, lex_ofs (lexer) - 1);
           c->n_cats++;
+
           lex_match (lexer, T_COMMA);
         }
       while (!lex_match (lexer, T_RBRACK));
+
+      struct msg_location *cats_location
+        = lex_ofs_location (lexer, cats_start_ofs, lex_ofs (lexer) - 1);
+      for (size_t i = 0; i < c->n_cats; i++)
+        {
+          const struct ctables_category *cat = &c->cats[i];
+          if (cat->type == CCT_POSTCOMPUTE
+              && !ctables_recursive_check_postcompute (cat->pc->expr, cat,
+                                                       c, cats_location))
+            return false;
+        }
     }
 
   struct ctables_category cat = {
@@ -1566,8 +1738,7 @@ ctables_table_parse_categories (struct lexer *lexer, struct dictionary *dict,
   if (!c->n_cats)
     {
       if (c->n_cats >= allocated_cats)
-        c->cats = x2nrealloc (c->cats, &allocated_cats,
-                                sizeof *c->cats);
+        c->cats = x2nrealloc (c->cats, &allocated_cats, sizeof *c->cats);
       c->cats[c->n_cats++] = cat;
     }
 
@@ -3680,22 +3851,11 @@ ctables_pcexpr_allocate_binary (enum ctables_postcompute_op op,
   *e = (struct ctables_pcexpr) {
     .op = op,
     .subs = { sub0, sub1 },
-    .ofs = { sub0->ofs[0], sub1->ofs[1] }
+    .location = msg_location_merged (sub0->location, sub1->location),
   };
   return e;
 }
 
-static struct msg_location *
-ctables_pcexpr_location (struct lexer *lexer, const struct ctables_pcexpr *e_)
-{
-  if (!e_->location)
-    {
-      struct ctables_pcexpr *e = CONST_CAST (struct ctables_pcexpr *, e_);
-      e->location = lex_ofs_location (lexer, e->ofs[0], e->ofs[1]);
-    }
-  return e_->location;
-}
-
 /* How to parse an operator. */
 struct operator
   {
@@ -3731,8 +3891,7 @@ parse_binary_operators__ (struct lexer *lexer,
       if (!op)
         {
           if (op_count > 1 && chain_warning)
-            msg_at (SW, ctables_pcexpr_location (lexer, lhs),
-                    "%s", chain_warning);
+            msg_at (SW, lhs->location, "%s", chain_warning);
 
           return lhs;
         }
@@ -3873,8 +4032,7 @@ parse_primary (struct lexer *lexer)
       return NULL;
     }
 
-  e.ofs[0] = start_ofs;
-  e.ofs[1] = lex_ofs (lexer) - 1;
+  e.location = lex_ofs_location (lexer, start_ofs, lex_ofs (lexer) - 1);
   return xmemdup (&e, sizeof e);
 }
 
@@ -3886,7 +4044,7 @@ ctables_pcexpr_allocate_neg (struct ctables_pcexpr *sub,
   *e = (struct ctables_pcexpr) {
     .op = CTPO_NEG,
     .subs = { sub },
-    .ofs = { start_ofs, lex_ofs (lexer) - 1 },
+    .location = lex_ofs_location (lexer, start_ofs, lex_ofs (lexer) - 1),
   };
   return e;
 }
@@ -3913,7 +4071,7 @@ parse_exp (struct lexer *lexer)
   *lhs = (struct ctables_pcexpr) {
     .op = CTPO_CONSTANT,
     .number = -lex_tokval (lexer),
-    .ofs = { start_ofs, lex_ofs (lexer) },
+    .location = lex_ofs_location (lexer, start_ofs, lex_ofs (lexer)),
   };
   lex_get (lexer);
 
index 83c7320168eef5a8ca4cf68baaa6c129bd98dd0d..38726d9f5b4827661940d8500e9294a85b785d79 100644 (file)
@@ -174,6 +174,16 @@ msg_location_merge (struct msg_location **dstp, const struct msg_location *src)
     dst->end = src->end;
 }
 
+struct msg_location *
+msg_location_merged (const struct msg_location *a,
+                     const struct msg_location *b)
+{
+  struct msg_location *new = msg_location_dup (a);
+  if (b)
+    msg_location_merge (&new, b);
+  return new;
+}
+
 struct msg_location *
 msg_location_dup (const struct msg_location *src)
 {
index 813febe82a07a3906315937999a8886cc05d6fd0..11e5b9d98eeda0694180836d8149c71ee2d24055 100644 (file)
@@ -118,6 +118,8 @@ struct msg_location *msg_location_dup (const struct msg_location *);
 void msg_location_remove_columns (struct msg_location *);
 
 void msg_location_merge (struct msg_location **, const struct msg_location *);
+struct msg_location *msg_location_merged (const struct msg_location *,
+                                          const struct msg_location *);
 
 bool msg_location_is_empty (const struct msg_location *);
 void msg_location_format (const struct msg_location *, struct string *);