CTABLES: Support setting the title, caption, ...
[pspp] / src / language / stats / ctables.c
index bef413a879d65d19aba3e3e546ae19ad42f00299..53b99519dcadb8bdfb853e44edb0158c30b95df4 100644 (file)
@@ -224,9 +224,13 @@ enum ctables_label_position
 
 struct var_array
   {
-    const struct ctables_axis *summary;
     struct variable **vars;
     size_t n;
+    size_t scale_idx;
+
+    struct ctables_summary_spec *summaries;
+    size_t n_summaries;
+    struct variable *summary_var;
   };
 
 struct var_array2
@@ -239,6 +243,7 @@ struct ctables_table
   {
     struct ctables_axis *axes[PIVOT_N_AXES];
     struct var_array2 vaas[PIVOT_N_AXES];
+    enum pivot_axis_type summary_axis;
     struct hmap ft;
 
     enum pivot_axis_type slabels_position;
@@ -616,6 +621,21 @@ ctables_summary_default_format (enum ctables_summary_function function,
     }
 }
 
+static char *
+ctables_summary_default_label (enum ctables_summary_function function,
+                               double percentile)
+{
+  static const char *default_labels[] = {
+#define S(ENUM, NAME, LABEL, FORMAT, AVAILABILITY) [ENUM] = LABEL,
+    SUMMARIES
+#undef S
+  };
+
+  return (function == CTSF_PTILE
+          ? xasprintf (_("Percentile %.2f"), percentile)
+          : xstrdup (gettext (default_labels[function])));
+}
+
 static const char *
 ctables_summary_function_name (enum ctables_summary_function function)
 {
@@ -792,17 +812,8 @@ ctables_axis_parse_postfix (struct ctables_axis_parse_ctx *ctx)
           label = ss_xstrdup (lex_tokss (ctx->lexer));
           lex_get (ctx->lexer);
         }
-      else if (function == CTSF_PTILE)
-        label = xasprintf (_("Percentile %.2f"), percentile);
       else
-        {
-          static const char *default_labels[] = {
-#define S(ENUM, NAME, LABEL, FORMAT, AVAILABILITY) [ENUM] = LABEL,
-            SUMMARIES
-#undef S
-          };
-          label = xstrdup (gettext (default_labels[function]));
-        }
+        label = ctables_summary_default_label (function, percentile);
 
       /* Parse format. */
       struct fmt_spec format;
@@ -1299,11 +1310,22 @@ nest_fts (struct var_array2 va0, struct var_array2 va1)
           vars[n++] = b->vars[k];
         assert (n == allocate);
 
-        assert (!(a->summary && b->summary));
+        const struct var_array *summary_src;
+        if (!a->summary_var)
+          summary_src = b;
+        else if (!b->summary_var)
+          summary_src = a;
+        else
+          NOT_REACHED ();
         vaa.vas[vaa.n++] = (struct var_array) {
-          .summary = a->summary ? a->summary : b->summary,
           .vars = vars,
-          .n = n
+          .scale_idx = (a->scale_idx != SIZE_MAX ? a->scale_idx
+                        : b->scale_idx != SIZE_MAX ? a->n + b->scale_idx
+                        : SIZE_MAX),
+          .n = n,
+          .summaries = summary_src->summaries,
+          .n_summaries = summary_src->n_summaries,
+          .summary_var = summary_src->summary_var,
         };
       }
   var_array2_uninit (&va0);
@@ -1335,18 +1357,22 @@ enumerate_fts (enum pivot_axis_type axis_type, const struct ctables_axis *a)
     {
     case CTAO_VAR:
       assert (!a->var.is_mrset);
+
+      struct variable **vars = xmalloc (sizeof *vars);
+      *vars = a->var.var;
+
       struct var_array *va = xmalloc (sizeof *va);
-      if (a->scale)
-        *va = (struct var_array) { .n = 0 };
-      else
+      *va = (struct var_array) {
+        .vars = vars,
+        .n = 1,
+        .scale_idx = a->scale ? 0 : SIZE_MAX,
+      };
+      if (a->n_summaries || a->scale)
         {
-          struct variable **vars = xmalloc (sizeof *vars);
-          *vars = a->var.var;
-          enum pivot_axis_type *axes = xmalloc (sizeof *axes);
-          *axes = axis_type;
-          *va = (struct var_array) { .vars = vars, .n = 1 };
+          va->summaries = a->summaries;
+          va->n_summaries = a->n_summaries;
+          va->summary_var = a->var.var;
         }
-      va->summary = a->scale || a->n_summaries ? a : NULL;
       return (struct var_array2) { .vas = va, .n = 1 };
 
     case CTAO_STACK:
@@ -1383,7 +1409,6 @@ union ctables_summary
     /* XXX percentiles, median, mode, multiple response */
   };
 
-#if 0
 static void
 ctables_summary_init (union ctables_summary *s,
                       const struct ctables_summary_spec *ss)
@@ -1473,7 +1498,7 @@ ctables_summary_init (union ctables_summary *s,
     }
 }
 
-static void
+static void UNUSED
 ctables_summary_uninit (union ctables_summary *s,
                         const struct ctables_summary_spec *ss)
 {
@@ -1788,7 +1813,6 @@ ctables_summary_value (union ctables_summary *s,
 
   NOT_REACHED ();
 }
-#endif
 
 struct ctables_freq
   {
@@ -1798,11 +1822,11 @@ struct ctables_freq
       {
         size_t vaa_idx;
         union value *values;
+        int leaf;
       }
     axes[PIVOT_N_AXES];
 
-    //union ctables_summary *summaries;
-    double count;
+    union ctables_summary *summaries;
   };
 
 #if 0
@@ -1849,13 +1873,14 @@ ctables_freq_compare_3way (const void *a_, const void *b_, const void *aux_)
 
   const struct var_array *va = &aux->t->vaas[aux->a].vas[a_idx];
   for (size_t i = 0; i < va->n; i++)
-    {
-      int cmp = value_compare_3way (&a->axes[aux->a].values[i],
-                                    &b->axes[aux->a].values[i],
-                                    var_get_width (va->vars[i]));
-      if (cmp)
-        return cmp;
-    }
+    if (i != va->scale_idx)
+      {
+        int cmp = value_compare_3way (&a->axes[aux->a].values[i],
+                                      &b->axes[aux->a].values[i],
+                                      var_get_width (va->vars[i]));
+        if (cmp)
+          return cmp;
+      }
   return 0;
 }
 
@@ -1888,6 +1913,7 @@ ctables_freqtab_insert (struct ctables_table *t,
     [PIVOT_AXIS_COLUMN] = ic,
     [PIVOT_AXIS_LAYER] = il,
   };
+  const struct var_array *ss = &t->vaas[t->summary_axis].vas[ix[t->summary_axis]];
 
   size_t hash = 0;
   for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
@@ -1895,8 +1921,9 @@ ctables_freqtab_insert (struct ctables_table *t,
       const struct var_array *va = &t->vaas[a].vas[ix[a]];
       hash = hash_int (ix[a], hash);
       for (size_t i = 0; i < va->n; i++)
-        hash = value_hash (case_data (c, va->vars[i]),
-                           var_get_width (va->vars[i]), hash);
+        if (i != va->scale_idx)
+          hash = value_hash (case_data (c, va->vars[i]),
+                             var_get_width (va->vars[i]), hash);
     }
 
   struct ctables_freq *f;
@@ -1908,14 +1935,14 @@ ctables_freqtab_insert (struct ctables_table *t,
           if (f->axes[a].vaa_idx != ix[a])
             goto not_equal;
           for (size_t i = 0; i < va->n; i++)
-            if (!value_equal (case_data (c, va->vars[i]),
-                              &f->axes[a].values[i],
-                              var_get_width (va->vars[i])))
-              goto not_equal;
+            if (i != va->scale_idx
+                && !value_equal (case_data (c, va->vars[i]),
+                                 &f->axes[a].values[i],
+                                 var_get_width (va->vars[i])))
+                goto not_equal;
         }
 
-      f->count += weight;
-      return;
+      goto summarize;
 
     not_equal: ;
     }
@@ -1932,8 +1959,15 @@ ctables_freqtab_insert (struct ctables_table *t,
         value_clone (&f->axes[a].values[i], case_data (c, va->vars[i]),
                      var_get_width (va->vars[i]));
     }
-  f->count = weight;
+  f->summaries = xmalloc (ss->n_summaries * sizeof *f->summaries);
+  for (size_t i = 0; i < ss->n_summaries; i++)
+    ctables_summary_init (&f->summaries[i], &ss->summaries[i]);
   hmap_insert (&t->ft, &f->node, hash);
+
+summarize:
+  for (size_t i = 0; i < ss->n_summaries; i++)
+    ctables_summary_add (&f->summaries[i], &ss->summaries[i], ss->summary_var,
+                         case_data (c, ss->summary_var), weight);
 }
 
 static bool
@@ -1951,17 +1985,41 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
             *va = (struct var_array) { .n = 0 };
             t->vaas[a] = (struct var_array2) { .vas = va, .n = 1 };
           }
+
+      for (size_t i = 0; i < t->vaas[t->summary_axis].n; i++)
+        {
+          struct var_array *va = &t->vaas[t->summary_axis].vas[i];
+          if (!va->n_summaries)
+            {
+              va->summaries = xmalloc (sizeof *va->summaries);
+              va->n_summaries = 1;
+
+              enum ctables_summary_function function
+                = va->summary_var ? CTSF_MEAN : CTSF_COUNT;
+              struct ctables_var var = { .is_mrset = false, .var = va->summary_var };
+
+              *va->summaries = (struct ctables_summary_spec) {
+                .function = function,
+                .format = ctables_summary_default_format (function, &var),
+                .label = ctables_summary_default_label (function, 0),
+              };
+              if (!va->summary_var)
+                va->summary_var = va->vars[0];
+            }
+        }
     }
 
   struct casereader *input = casereader_create_filter_weight (proc_open (ds),
                                                               dataset_dict (ds),
                                                               NULL, NULL);
   bool warn_on_invalid = true;
+  double total_weight = 0;
   for (struct ccase *c = casereader_read (input); c;
        case_unref (c), c = casereader_read (input))
     {
       double weight = dict_get_case_weight (dataset_dict (ds), c,
                                             &warn_on_invalid);
+      total_weight += weight;
 
       for (size_t i = 0; i < ct->n_tables; i++)
         {
@@ -1979,7 +2037,18 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
     {
       struct ctables_table *t = ct->tables[i];
 
-      struct pivot_table *pt = pivot_table_create (N_("Custom Tables"));
+      struct pivot_table *pt = pivot_table_create__ (
+        (t->title
+         ? pivot_value_new_user_text (t->title, SIZE_MAX)
+         : pivot_value_new_text (N_("Custom Tables"))),
+        NULL);
+      if (t->caption)
+        pivot_table_set_caption (
+          pt, pivot_value_new_user_text (t->caption, SIZE_MAX));
+      if (t->corner)
+        pivot_table_set_caption (
+          pt, pivot_value_new_user_text (t->corner, SIZE_MAX));
+
       pivot_table_set_look (pt, ct->look);
       struct pivot_dimension *d[PIVOT_N_AXES];
       for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
@@ -1989,12 +2058,14 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
             [PIVOT_AXIS_COLUMN] = N_("Columns"),
             [PIVOT_AXIS_LAYER] = N_("Layers"),
           };
-          d[a] = (t->axes[a]
+          d[a] = (t->axes[a] || a == t->summary_axis
                   ? pivot_dimension_create (pt, a, names[a])
                   : NULL);
           if (!d[a])
             continue;
 
+          assert (t->axes[a]);
+
           struct ctables_freq **sorted = xnmalloc (t->ft.count, sizeof *sorted);
 
           struct ctables_freq *f;
@@ -2004,7 +2075,7 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
           assert (n == t->ft.count);
 
           struct ctables_freq_sort_aux aux = { .t = t, .a = a };
-          n = sort_unique (sorted, n, sizeof *sorted, ctables_freq_compare_3way, &aux);
+          sort (sorted, n, sizeof *sorted, ctables_freq_compare_3way, &aux);
 
           size_t max_depth = 0;
           for (size_t j = 0; j < t->vaas[a].n; j++)
@@ -2013,6 +2084,7 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
 
           struct pivot_category **groups = xnmalloc (max_depth, sizeof *groups);
           struct pivot_category *top = NULL;
+          int prev_leaf = 0;
           for (size_t j = 0; j < n; j++)
             {
               struct ctables_freq *f = sorted[j];
@@ -2026,9 +2098,10 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                   if (prev->axes[a].vaa_idx == f->axes[a].vaa_idx)
                     {
                       for (; n_common < va->n; n_common++)
-                        if (!value_equal (&prev->axes[a].values[n_common],
-                                          &f->axes[a].values[n_common],
-                                          var_get_type (va->vars[n_common])))
+                        if (n_common != va->scale_idx
+                            && !value_equal (&prev->axes[a].values[n_common],
+                                             &f->axes[a].values[n_common],
+                                             var_get_type (va->vars[n_common])))
                           break;
                     }
                   else
@@ -2036,45 +2109,95 @@ ctables_execute (struct dataset *ds, struct ctables *ct)
                 }
               else
                 new_subtable = true;
+
               if (new_subtable)
-                top = pivot_category_create_group__ (
-                  d[a]->root, pivot_value_new_variable (va->vars[0]));
-              printf ("n_common=%zu\n", n_common);
+                {
+                  enum ctables_vlabel vlabel = ct->vlabels[var_get_dict_index (va->vars[0])];
+                  top = d[a]->root;
+                  if (vlabel != CTVL_NONE)
+                    top = pivot_category_create_group__ (
+                      top, pivot_value_new_variable (va->vars[0]));
+                }
+              if (n_common == va->n)
+                {
+                  f->axes[a].leaf = prev_leaf;
+                  continue;
+                }
 
               for (size_t k = n_common; k < va->n; k++)
                 {
                   struct pivot_category *parent = k > 0 ? groups[k - 1] : top;
 
+                  struct pivot_value *label
+                    = (k != va->scale_idx
+                       ? pivot_value_new_var_value (va->vars[k],
+                                                    &f->axes[a].values[k])
+                       : NULL);
                   if (k == va->n - 1)
                     {
-                      pivot_category_create_leaf (
-                        parent,
-                        pivot_value_new_var_value (va->vars[va->n - 1],
-                                                   &f->axes[a].values[va->n - 1]));
+                      if (a == t->summary_axis)
+                        {
+                          if (label)
+                            parent = pivot_category_create_group__ (parent, label);
+                          for (size_t m = 0; m < va->n_summaries; m++)
+                            {
+                              int leaf = pivot_category_create_leaf (
+                                parent, pivot_value_new_text (va->summaries[m].label));
+                              if (m == 0)
+                                prev_leaf = leaf;
+                            }
+                        }
+                      else
+                        {
+                          /* This assertion is true as long as the summary axis
+                             is the axis where the summaries are displayed. */
+                          assert (label);
+
+                          prev_leaf = pivot_category_create_leaf (parent, label);
+                        }
                       break;
                     }
 
-                  parent = pivot_category_create_group__ (
-                    parent,
-                    pivot_value_new_var_value (va->vars[k], &f->axes[a].values[k]));
+                  if (label)
+                    parent = pivot_category_create_group__ (parent, label);
 
-                  parent = pivot_category_create_group__ (
-                    parent, pivot_value_new_variable (va->vars[k]));
+                  enum ctables_vlabel vlabel = ct->vlabels[var_get_dict_index (va->vars[k + 1])];
+                  if (vlabel != CTVL_NONE)
+                    parent = pivot_category_create_group__ (
+                      parent, pivot_value_new_variable (va->vars[k + 1]));
                   groups[k] = parent;
-
-#if 0
-                      for (size_t p = 0; p < ft->n_summaries; p++)
-                        {
-                          if (a == t->slabels_position)
-                            pivot_category_create_leaf (
-                              c, pivot_value_new_text (ft->summaries[p].label));
-                          //pivot_table_put1 (pt, leaf, pivot_value_new_number (value));
-                        }
-#endif
-                    }
                 }
+
+              f->axes[a].leaf = prev_leaf;
+            }
+          free (sorted);
           free (groups);
         }
+      struct ctables_freq *f;
+      HMAP_FOR_EACH (f, struct ctables_freq, node, &t->ft)
+        {
+          const struct var_array *ss = &t->vaas[t->summary_axis].vas[f->axes[t->summary_axis].vaa_idx];
+          for (size_t j = 0; j < ss->n_summaries; j++)
+            {
+              size_t dindexes[3];
+              size_t n_dindexes = 0;
+
+              for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+                if (d[a])
+                  {
+                    int leaf = f->axes[a].leaf;
+                    if (a == t->summary_axis)
+                      leaf += j;
+                    dindexes[n_dindexes++] = leaf;
+                  }
+
+              double d = ctables_summary_value (&f->summaries[j], &ss->summaries[j]);
+              struct pivot_value *value = pivot_value_new_number (d);
+              value->numeric.format = ss->summaries[j].format;
+              pivot_table_put (pt, dindexes, n_dindexes, value);
+            }
+        }
+
       pivot_table_submit (pt);
     }
 
@@ -2363,26 +2486,57 @@ cmd_ctables (struct lexer *lexer, struct dataset *ds)
 
       const struct ctables_axis *scales[PIVOT_N_AXES];
       size_t n_scales = 0;
-      for (size_t i = 0; i < 3; i++)
+      for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
         {
-          scales[i] = find_scale (t->axes[i]);
-          if (scales[i])
+          scales[a] = find_scale (t->axes[a]);
+          if (scales[a])
             n_scales++;
         }
       if (n_scales > 1)
         {
-          msg (SE, _("Scale variables may appear only on one dimension."));
+          msg (SE, _("Scale variables may appear only on one axis."));
           if (scales[PIVOT_AXIS_ROW])
             msg_at (SN, scales[PIVOT_AXIS_ROW]->loc,
-                    _("This scale variable appears in the rows dimension."));
+                    _("This scale variable appears on the rows axis."));
           if (scales[PIVOT_AXIS_COLUMN])
             msg_at (SN, scales[PIVOT_AXIS_COLUMN]->loc,
-                    _("This scale variable appears in the columns dimension."));
+                    _("This scale variable appears on the columns axis."));
           if (scales[PIVOT_AXIS_LAYER])
             msg_at (SN, scales[PIVOT_AXIS_LAYER]->loc,
-                    _("This scale variable appears in the layer dimension."));
+                    _("This scale variable appears on the layer axis."));
+          goto error;
+        }
+
+      const struct ctables_axis *summaries[PIVOT_N_AXES];
+      size_t n_summaries = 0;
+      for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+        {
+          summaries[a] = (scales[a]
+                          ? scales[a]
+                          : find_categorical_summary_spec (t->axes[a]));
+          if (summaries[a])
+            n_summaries++;
+        }
+      if (n_summaries > 1)
+        {
+          msg (SE, _("Summaries may appear only on one axis."));
+          if (summaries[PIVOT_AXIS_ROW])
+            msg_at (SN, summaries[PIVOT_AXIS_ROW]->loc,
+                    _("This variable on the rows axis has a summary."));
+          if (summaries[PIVOT_AXIS_COLUMN])
+            msg_at (SN, summaries[PIVOT_AXIS_COLUMN]->loc,
+                    _("This variable on the columns axis has a summary."));
+          if (summaries[PIVOT_AXIS_LAYER])
+            msg_at (SN, summaries[PIVOT_AXIS_LAYER]->loc,
+                    _("This variable on the layers axis has a summary."));
           goto error;
         }
+      for (enum pivot_axis_type a = 0; a < PIVOT_N_AXES; a++)
+        if (n_summaries ? summaries[a] : t->axes[a])
+          {
+            t->summary_axis = a;
+            break;
+          }
 
       if (lex_token (lexer) == T_ENDCMD)
         break;