CROSSTABS: make sort-based crosstabs work
[pspp] / src / language / stats / crosstabs.q
index b137e284968372f2c0580774522c994362252ad4..59befd1ebcaa862acdf0bbb9ba886f32b83bb4b5 100644 (file)
 #include <stdlib.h>
 #include <stdio.h>
 
-#include <data/case.h>
-#include <data/casegrouper.h>
-#include <data/casereader.h>
-#include <data/data-out.h>
-#include <data/dictionary.h>
-#include <data/format.h>
-#include <data/procedure.h>
-#include <data/value-labels.h>
-#include <data/variable.h>
-#include <language/command.h>
-#include <language/dictionary/split-file.h>
-#include <language/lexer/lexer.h>
-#include <language/lexer/variable-parser.h>
-#include <libpspp/array.h>
-#include <libpspp/assertion.h>
-#include <libpspp/compiler.h>
-#include <libpspp/hash.h>
-#include <libpspp/hmap.h>
-#include <libpspp/hmapx.h>
-#include <libpspp/message.h>
-#include <libpspp/misc.h>
-#include <libpspp/pool.h>
-#include <libpspp/str.h>
-#include <output/tab.h>
-
-#include "minmax.h"
-#include "xalloc.h"
-#include "xsize.h"
+#include "data/case.h"
+#include "data/casegrouper.h"
+#include "data/casereader.h"
+#include "data/casewriter.h"
+#include "data/data-out.h"
+#include "data/dictionary.h"
+#include "data/format.h"
+#include "data/procedure.h"
+#include "data/subcase.h"
+#include "data/value-labels.h"
+#include "data/variable.h"
+#include "language/command.h"
+#include "language/dictionary/split-file.h"
+#include "language/lexer/lexer.h"
+#include "language/lexer/variable-parser.h"
+#include "libpspp/array.h"
+#include "libpspp/assertion.h"
+#include "libpspp/compiler.h"
+#include "libpspp/hash.h"
+#include "libpspp/hmap.h"
+#include "libpspp/hmapx.h"
+#include "libpspp/message.h"
+#include "libpspp/misc.h"
+#include "libpspp/pool.h"
+#include "libpspp/str.h"
+#include "math/sort.h"
+#include "output/tab.h"
+
+#include "gl/minmax.h"
+#include "gl/xalloc.h"
+#include "gl/xsize.h"
 
 #include "gettext.h"
 #define _(msgid) gettext (msgid)
 /* A single table entry for general mode. */
 struct table_entry
   {
-    struct hmap_node node;      /* Entry in hash table. */
     double freq;                /* Frequency count. */
     union value values[1];     /* Values. */
   };
@@ -137,7 +139,8 @@ struct pivot_table
     union value *const_values;
 
     /* Data. */
-    struct hmap data;
+    struct subcase src_sc, dst_sc;
+    struct casewriter *sorter;
     struct table_entry **entries;
     size_t n_entries;
 
@@ -200,180 +203,256 @@ struct crosstabs_proc
     unsigned int statistics;    /* Bit k is 1 if statistic k is requested. */
   };
 
-static void
-init_proc (struct crosstabs_proc *proc, struct dataset *ds)
-{
-  const struct variable *wv = dict_get_weight (dataset_dict (ds));
-  proc->dict = dataset_dict (ds);
-  proc->bad_warn = true;
-  proc->variables = NULL;
-  proc->n_variables = 0;
-  proc->pivots = NULL;
-  proc->n_pivots = 0;
-  proc->weight_format = wv ? *var_get_print_format (wv) : F_8_0;
-}
-
-static void
-free_proc (struct crosstabs_proc *proc)
-{
-  struct pivot_table *pt;
-
-  free (proc->variables);
-  for (pt = &proc->pivots[0]; pt < &proc->pivots[proc->n_pivots]; pt++)
-    {
-      free (pt->vars);
-      free (pt->const_vars);
-      /* We must not call value_destroy on const_values because
-         it is a wild pointer; it never pointed to anything owned
-         by the pivot_table.
-
-         The rest of the data was allocated and destroyed at a
-         lower level already. */
-    }
-  free (proc->pivots);
-}
-
-static int internal_cmd_crosstabs (struct lexer *lexer, struct dataset *ds,
-                                   struct crosstabs_proc *);
 static bool should_tabulate_case (const struct pivot_table *,
-                                  const struct ccase *, enum mv_class exclude);
-static void tabulate_general_case (struct pivot_table *, const struct ccase *,
-                                   double weight);
-static void tabulate_integer_case (struct pivot_table *, const struct ccase *,
-                                   double weight);
+                                  const struct ccase *, enum mv_class exclude,
+                                  size_t n_splits);
 static void postcalc (struct crosstabs_proc *);
 static void submit (struct pivot_table *, struct tab_table *);
 
-/* Parse and execute CROSSTABS, then clean up. */
-int
-cmd_crosstabs (struct lexer *lexer, struct dataset *ds)
+static struct ccase *
+crs_combine_cases (struct ccase *a, struct ccase *b, void *aux UNUSED)
 {
-  struct crosstabs_proc proc;
-  int result;
+  size_t weight_idx = caseproto_get_n_widths (case_get_proto (a)) - 1;
 
-  init_proc (&proc, ds);
-  result = internal_cmd_crosstabs (lexer, ds, &proc);
-  free_proc (&proc);
+  a = case_unshare (a);
+  case_data_rw_idx (a, weight_idx)->f += case_data_idx (b, weight_idx)->f;
+  case_unref (b);
 
-  return result;
+  return a;
 }
 
 /* Parses and executes the CROSSTABS procedure. */
-static int
-internal_cmd_crosstabs (struct lexer *lexer, struct dataset *ds,
-                        struct crosstabs_proc *proc)
+int
+cmd_crosstabs (struct lexer *lexer, struct dataset *ds)
 {
-  struct casegrouper *grouper;
-  struct casereader *input, *group;
+  const struct variable *wv = dict_get_weight (dataset_dict (ds));
+  const struct dictionary *dict = dataset_dict (ds);
+  size_t n_splits = dict_get_split_cnt (dict);
+  struct crosstabs_proc proc;
+  struct casereader *input;
   struct cmd_crosstabs cmd;
   struct pivot_table *pt;
+  struct ccase *c;
+  int result;
   bool ok;
   int i;
 
-  if (!parse_crosstabs (lexer, ds, &cmd, proc))
-    return CMD_FAILURE;
+  proc.dict = dataset_dict (ds);
+  proc.bad_warn = true;
+  proc.variables = NULL;
+  proc.n_variables = 0;
+  proc.pivots = NULL;
+  proc.n_pivots = 0;
+  proc.weight_format = wv ? *var_get_print_format (wv) : F_8_0;
 
-  proc->mode = proc->n_variables ? INTEGER : GENERAL;
+  if (!parse_crosstabs (lexer, ds, &cmd, &proc))
+    {
+      result = CMD_FAILURE;
+      goto exit;
+    }
+
+  proc.mode = proc.n_variables ? INTEGER : GENERAL;
 
   /* CELLS. */
   if (!cmd.sbc_cells)
-    proc->cells = 1u << CRS_CL_COUNT;
+    proc.cells = 1u << CRS_CL_COUNT;
   else if (cmd.a_cells[CRS_CL_ALL])
-    proc->cells = UINT_MAX;
+    proc.cells = UINT_MAX;
   else
     {
-      proc->cells = 0;
+      proc.cells = 0;
       for (i = 0; i < CRS_CL_count; i++)
        if (cmd.a_cells[i])
-         proc->cells |= 1u << i;
-      if (proc->cells == 0)
-        proc->cells = ((1u << CRS_CL_COUNT)
+         proc.cells |= 1u << i;
+      if (proc.cells == 0)
+        proc.cells = ((1u << CRS_CL_COUNT)
                        | (1u << CRS_CL_ROW)
                        | (1u << CRS_CL_COLUMN)
                        | (1u << CRS_CL_TOTAL));
     }
-  proc->cells &= ((1u << CRS_CL_count) - 1);
-  proc->cells &= ~((1u << CRS_CL_NONE) | (1u << CRS_CL_ALL));
-  proc->n_cells = 0;
+  proc.cells &= ((1u << CRS_CL_count) - 1);
+  proc.cells &= ~((1u << CRS_CL_NONE) | (1u << CRS_CL_ALL));
+  proc.n_cells = 0;
   for (i = 0; i < CRS_CL_count; i++)
-    if (proc->cells & (1u << i))
-      proc->a_cells[proc->n_cells++] = i;
+    if (proc.cells & (1u << i))
+      proc.a_cells[proc.n_cells++] = i;
 
   /* STATISTICS. */
   if (cmd.a_statistics[CRS_ST_ALL])
-    proc->statistics = UINT_MAX;
+    proc.statistics = UINT_MAX;
   else if (cmd.sbc_statistics)
     {
       int i;
 
-      proc->statistics = 0;
+      proc.statistics = 0;
       for (i = 0; i < CRS_ST_count; i++)
        if (cmd.a_statistics[i])
-         proc->statistics |= 1u << i;
-      if (proc->statistics == 0)
-        proc->statistics |= 1u << CRS_ST_CHISQ;
+         proc.statistics |= 1u << i;
+      if (proc.statistics == 0)
+        proc.statistics |= 1u << CRS_ST_CHISQ;
     }
   else
-    proc->statistics = 0;
+    proc.statistics = 0;
 
   /* MISSING. */
-  proc->exclude = (cmd.miss == CRS_TABLE ? MV_ANY
+  proc.exclude = (cmd.miss == CRS_TABLE ? MV_ANY
                    : cmd.miss == CRS_INCLUDE ? MV_SYSTEM
                    : MV_NEVER);
-  if (proc->mode == GENERAL && proc->mode == MV_NEVER)
+  if (proc.mode == GENERAL && proc.mode == MV_NEVER)
     {
       msg (SE, _("Missing mode REPORT not allowed in general mode.  "
                 "Assuming MISSING=TABLE."));
-      proc->mode = MV_ANY;
+      proc.mode = MV_ANY;
     }
 
   /* PIVOT. */
-  proc->pivot = cmd.pivot == CRS_PIVOT;
+  proc.pivot = cmd.pivot == CRS_PIVOT;
 
-  input = casereader_create_filter_weight (proc_open (ds), dataset_dict (ds),
-                                           NULL, NULL);
-  grouper = casegrouper_create_splits (input, dataset_dict (ds));
-  while (casegrouper_get_next_group (grouper, &group))
+  for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
     {
-      struct ccase *c;
+      struct caseproto *proto;
+      struct subcase sort;
+
+      subcase_init_empty (&pt->src_sc);
+      subcase_add_vars_always (&pt->src_sc, dict_get_split_vars (dict),
+                               n_splits, SC_ASCEND);
+      subcase_add_vars_always (&pt->src_sc, pt->vars, pt->n_vars, SC_ASCEND);
+
+      subcase_clone (&pt->dst_sc, &pt->src_sc);
+      subcase_project (&pt->dst_sc, 0);
 
-      /* Output SPLIT FILE variables. */
-      c = casereader_peek (group, 0);
-      if (c != NULL)
+      subcase_init_empty (&sort);
+      for (i = 0; i < n_splits; i++)
+        subcase_add_always (&sort, i, subcase_get_width (&pt->src_sc, i),
+                            SC_ASCEND);
+      for (i = 0; i < pt->n_vars; i++)
         {
-          output_split_file_values (ds, c);
-          case_unref (c);
+          size_t var_idx = n_splits + (i == pt->n_vars - 2 ? ROW_VAR
+                                       : i == pt->n_vars - 1 ? COL_VAR
+                                       : i + 2);
+          subcase_add_always (&sort, var_idx,
+                              subcase_get_width (&pt->src_sc, var_idx),
+                              SC_ASCEND);
         }
 
-      /* Initialize hash tables. */
-      for (pt = &proc->pivots[0]; pt < &proc->pivots[proc->n_pivots]; pt++)
-        hmap_init (&pt->data);
+      proto = caseproto_ref (subcase_get_proto (&pt->dst_sc));
+      proto = caseproto_add_width (proto, 0);
+      pt->sorter = sort_distinct_create_writer (&sort, proto,
+                                                crs_combine_cases, NULL, NULL);
+      caseproto_unref (proto);
+    }
 
-      /* Tabulate. */
-      for (; (c = casereader_read (group)) != NULL; case_unref (c))
-        for (pt = &proc->pivots[0]; pt < &proc->pivots[proc->n_pivots]; pt++)
-          {
-            double weight = dict_get_case_weight (dataset_dict (ds), c,
-                                                  &proc->bad_warn);
-            if (should_tabulate_case (pt, c, proc->exclude))
+  input = casereader_create_filter_weight (proc_open (ds), dict, NULL, NULL);
+  for (; (c = casereader_read (input)) != NULL; case_unref (c))
+    {
+      for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
+        {
+          const struct caseproto *proto = casewriter_get_proto (pt->sorter);
+          struct ccase *pt_case;
+
+          pt_case = case_create (proto);
+
+          subcase_copy (&pt->src_sc, c, &pt->dst_sc, pt_case);
+          if (should_tabulate_case (pt, pt_case, proc.exclude, n_splits)
+              && proc.mode == INTEGER)
+            {
+              for (i = 0; i < pt->n_vars; i++)
+                {
+                  double *d = &case_data_rw_idx (pt_case, i + n_splits)->f;
+                  *d = (int) *d;
+                }
+            }
+
+          case_data_rw_idx (pt_case, caseproto_get_n_widths (proto) - 1)->f
+             = dict_get_case_weight (dict, c, &proc.bad_warn);
+
+          casewriter_write (pt->sorter, pt_case);
+        }
+    }
+  ok = casereader_destroy (input);
+  ok = proc_commit (ds) && ok;
+
+  for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
+    {
+      const struct caseproto *proto;
+      struct casegrouper *grouper;
+      struct casereader *data;
+      struct subcase group_sc;
+      struct casereader *group;
+
+      subcase_init_vars (&group_sc, dict_get_split_vars (dict),
+                         dict_get_split_cnt (dict));
+      subcase_project (&group_sc, 0);
+
+      data = casewriter_make_reader (pt->sorter);
+      proto = casereader_get_proto (data);
+      grouper = casegrouper_create_subcase (data, &group_sc);
+      subcase_destroy (&group_sc);
+
+      for (; casegrouper_get_next_group (grouper, &group);
+           casereader_destroy (group))
+        {
+          casenumber n_entries;
+          casenumber i;
+
+          c = casereader_peek (group, 0);
+          if (c != NULL)
+            {
+              /* XXX output_split_file_values (ds, c); */
+              case_unref (c);
+            }
+
+          n_entries = casereader_count_cases (group);
+          if (n_entries > 1000000)
+            {
+              msg (SW, _("Omitting analysis of crosstabulation that has %lu "
+                         "nonempty cells."),
+                   (unsigned long int) n_entries);
+              continue;
+            }
+
+          pt->entries = xmalloc (n_entries * sizeof *pt->entries);
+          pt->n_entries = 0;
+          pt->missing = 0.0;
+          for (; (c = casereader_read (group)) != NULL; case_unref (c))
+            if (should_tabulate_case (pt, c, proc.exclude, n_splits))
               {
-                if (proc->mode == GENERAL)
-                  tabulate_general_case (pt, c, weight);
-                else
-                  tabulate_integer_case (pt, c, weight);
+                struct table_entry *e;
+
+                e = xmalloc (table_entry_size (pt->n_vars));
+                for (i = 0; i < pt->n_vars; i++)
+                  value_clone (&e->values[i], case_data_idx (c, i + n_splits),
+                               caseproto_get_width (proto, i + n_splits));
+                e->freq = case_num_idx (c, pt->n_vars + n_splits);
+
+                pt->entries[pt->n_entries++] = e;
               }
             else
-              pt->missing += weight;
-          }
-      casereader_destroy (group);
+              pt->missing += case_num_idx (c, pt->n_vars + n_splits);
 
-      /* Output. */
-      postcalc (proc);
+          postcalc (&proc);
+        }
+      ok = casegrouper_destroy (grouper) && ok;
     }
-  ok = casegrouper_destroy (grouper);
-  ok = proc_commit (ds) && ok;
 
-  return ok ? CMD_SUCCESS : CMD_CASCADING_FAILURE;
+  result = ok ? CMD_SUCCESS : CMD_CASCADING_FAILURE;
+
+exit:
+  free (proc.variables);
+  for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
+    {
+      free (pt->vars);
+      free (pt->const_vars);
+      /* We must not call value_destroy on const_values because
+         it is a wild pointer; it never pointed to anything owned
+         by the pivot_table.
+
+         The rest of the data was allocated and destroyed at a
+         lower level already. */
+    }
+  free (proc.pivots);
+
+  return result;
 }
 
 /* Parses the TABLES subcommand. */
@@ -559,105 +638,27 @@ crs_custom_variables (struct lexer *lexer, struct dataset *ds,
 
 static bool
 should_tabulate_case (const struct pivot_table *pt, const struct ccase *c,
-                      enum mv_class exclude)
+                      enum mv_class exclude, size_t n_splits)
 {
   int j;
   for (j = 0; j < pt->n_vars; j++)
     {
       const struct variable *var = pt->vars[j];
       struct var_range *range = get_var_range (var);
+      const union value *value = case_data_idx (c, j + n_splits);
 
-      if (var_is_value_missing (var, case_data (c, var), exclude))
+      if (var_is_value_missing (var, value, exclude))
         return false;
 
       if (range != NULL)
         {
-          double num = case_num (c, var);
+          double num = value->f;
           if (num < range->min || num > range->max)
             return false;
         }
     }
   return true;
 }
-
-static void
-tabulate_integer_case (struct pivot_table *pt, const struct ccase *c,
-                       double weight)
-{
-  struct table_entry *te;
-  size_t hash;
-  int j;
-
-  hash = 0;
-  for (j = 0; j < pt->n_vars; j++)
-    {
-      /* Throw away fractional parts of values. */
-      hash = hash_int (case_num (c, pt->vars[j]), hash);
-    }
-
-  HMAP_FOR_EACH_WITH_HASH (te, struct table_entry, node, hash, &pt->data)
-    {
-      for (j = 0; j < pt->n_vars; j++)
-        if ((int) case_num (c, pt->vars[j]) != (int) te->values[j].f)
-          goto no_match;
-
-      /* Found an existing entry. */
-      te->freq += weight;
-      return;
-
-    no_match: ;
-    }
-
-  /* No existing entry.  Create a new one. */
-  te = xmalloc (table_entry_size (pt->n_vars));
-  te->freq = weight;
-  for (j = 0; j < pt->n_vars; j++)
-    te->values[j].f = (int) case_num (c, pt->vars[j]);
-  hmap_insert (&pt->data, &te->node, hash);
-}
-
-static void
-tabulate_general_case (struct pivot_table *pt, const struct ccase *c,
-                       double weight)
-{
-  struct table_entry *te;
-  size_t hash;
-  int j;
-
-  hash = 0;
-  for (j = 0; j < pt->n_vars; j++)
-    {
-      const struct variable *var = pt->vars[j];
-      hash = value_hash (case_data (c, var), var_get_width (var), hash);
-    }
-
-  HMAP_FOR_EACH_WITH_HASH (te, struct table_entry, node, hash, &pt->data)
-    {
-      for (j = 0; j < pt->n_vars; j++)
-        {
-          const struct variable *var = pt->vars[j];
-          if (!value_equal (case_data (c, var), &te->values[j],
-                            var_get_width (var)))
-            goto no_match;
-        }
-
-      /* Found an existing entry. */
-      te->freq += weight;
-      return;
-
-    no_match: ;
-    }
-
-  /* No existing entry.  Create a new one. */
-  te = xmalloc (table_entry_size (pt->n_vars));
-  te->freq = weight;
-  for (j = 0; j < pt->n_vars; j++)
-    {
-      const struct variable *var = pt->vars[j];
-      value_clone (&te->values[j], case_data (c, var), var_get_width (var));
-    }
-  hmap_insert (&pt->data, &te->node, hash);
-}
 \f
 /* Post-data reading calculations. */
 
@@ -665,8 +666,6 @@ static int compare_table_entry_vars_3way (const struct table_entry *a,
                                           const struct table_entry *b,
                                           const struct pivot_table *pt,
                                           int idx0, int idx1);
-static int compare_table_entry_3way (const void *ap_, const void *bp_,
-                                     const void *pt_);
 static void enum_var_values (const struct pivot_table *, int var_idx,
                              union value **valuesp, int *n_values);
 static void output_pivot_table (struct crosstabs_proc *,
@@ -682,23 +681,6 @@ postcalc (struct crosstabs_proc *proc)
 {
   struct pivot_table *pt;
 
-  /* Convert hash tables into sorted arrays of entries. */
-  for (pt = &proc->pivots[0]; pt < &proc->pivots[proc->n_pivots]; pt++)
-    {
-      struct table_entry *e;
-      size_t i;
-
-      pt->n_entries = hmap_count (&pt->data);
-      pt->entries = xnmalloc (pt->n_entries, sizeof *pt->entries);
-      i = 0;
-      HMAP_FOR_EACH (e, struct table_entry, node, &pt->data)
-        pt->entries[i++] = e;
-      hmap_destroy (&pt->data);
-
-      sort (pt->entries, pt->n_entries, sizeof *pt->entries,
-            compare_table_entry_3way, pt);
-    }
-
   make_summary_table (proc);
 
   /* Output each pivot table. */
@@ -723,8 +705,6 @@ postcalc (struct crosstabs_proc *proc)
     {
       size_t i;
 
-      pt->missing = 0.0;
-
       /* Free only the members that were allocated in this
          function.  The other pointer members are either both
          allocated and destroyed at a lower level (in
@@ -783,29 +763,6 @@ compare_table_entry_vars_3way (const struct table_entry *a,
   return 0;
 }
 
-/* Compare the struct table_entry at *AP to the one at *BP and
-   return a strcmp()-type result. */
-static int
-compare_table_entry_3way (const void *ap_, const void *bp_, const void *pt_)
-{
-  const struct table_entry *const *ap = ap_;
-  const struct table_entry *const *bp = bp_;
-  const struct table_entry *a = *ap;
-  const struct table_entry *b = *bp;
-  const struct pivot_table *pt = pt_;
-  int cmp;
-
-  cmp = compare_table_entry_vars_3way (a, b, pt, 2, pt->n_vars);
-  if (cmp != 0)
-    return cmp;
-
-  cmp = compare_table_entry_var_3way (a, b, pt, ROW_VAR);
-  if (cmp != 0)
-    return cmp;
-
-  return compare_table_entry_var_3way (a, b, pt, COL_VAR);
-}
-
 static int
 find_first_difference (const struct pivot_table *pt, size_t row)
 {
@@ -1191,7 +1148,8 @@ create_crosstab_table (struct crosstabs_proc *proc, struct pivot_table *pt)
       /* Insert the formatted value of the variable, then trim
          leading spaces in what was just inserted. */
       ofs = ds_length (&title);
-      s = data_out (&pt->const_values[i], dict_get_encoding (proc->dict), var_get_print_format (var));
+      s = data_out (&pt->const_values[i], var_get_encoding (var),
+                    var_get_print_format (var));
       ds_put_cstr (&title, s);
       free (s);
       ds_remove (&title, ofs, ss_cspan (ds_substr (&title, ofs, SIZE_MAX),