CROSSTABS: make sort-based crosstabs work
authorBen Pfaff <blp@cs.stanford.edu>
Tue, 30 Mar 2010 04:05:10 +0000 (21:05 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Mon, 3 May 2010 04:29:22 +0000 (21:29 -0700)
still needs some touch-up regarding SPLIT FILE and memory leaks

src/language/stats/crosstabs.q

index acec9769b48d2ecf11eff6ef837ffb15aaa027b2..59befd1ebcaa862acdf0bbb9ba886f32b83bb4b5 100644 (file)
 #include <stdlib.h>
 #include <stdio.h>
 
-#include <data/case.h>
-#include <data/casegrouper.h>
-#include <data/casereader.h>
-#include <data/data-out.h>
-#include <data/dictionary.h>
-#include <data/format.h>
-#include <data/procedure.h>
-#include <data/value-labels.h>
-#include <data/variable.h>
-#include <language/command.h>
-#include <language/dictionary/split-file.h>
-#include <language/lexer/lexer.h>
-#include <language/lexer/variable-parser.h>
-#include <libpspp/array.h>
-#include <libpspp/assertion.h>
-#include <libpspp/compiler.h>
-#include <libpspp/hash.h>
-#include <libpspp/hmap.h>
-#include <libpspp/hmapx.h>
-#include <libpspp/message.h>
-#include <libpspp/misc.h>
-#include <libpspp/pool.h>
-#include <libpspp/str.h>
-#include <output/tab.h>
-
-#include "minmax.h"
-#include "xalloc.h"
-#include "xsize.h"
+#include "data/case.h"
+#include "data/casegrouper.h"
+#include "data/casereader.h"
+#include "data/casewriter.h"
+#include "data/data-out.h"
+#include "data/dictionary.h"
+#include "data/format.h"
+#include "data/procedure.h"
+#include "data/subcase.h"
+#include "data/value-labels.h"
+#include "data/variable.h"
+#include "language/command.h"
+#include "language/dictionary/split-file.h"
+#include "language/lexer/lexer.h"
+#include "language/lexer/variable-parser.h"
+#include "libpspp/array.h"
+#include "libpspp/assertion.h"
+#include "libpspp/compiler.h"
+#include "libpspp/hash.h"
+#include "libpspp/hmap.h"
+#include "libpspp/hmapx.h"
+#include "libpspp/message.h"
+#include "libpspp/misc.h"
+#include "libpspp/pool.h"
+#include "libpspp/str.h"
+#include "math/sort.h"
+#include "output/tab.h"
+
+#include "gl/minmax.h"
+#include "gl/xalloc.h"
+#include "gl/xsize.h"
 
 #include "gettext.h"
 #define _(msgid) gettext (msgid)
 /* A single table entry for general mode. */
 struct table_entry
   {
-    struct hmap_node node;      /* Entry in hash table. */
     double freq;                /* Frequency count. */
     union value values[1];     /* Values. */
   };
@@ -137,7 +139,8 @@ struct pivot_table
     union value *const_values;
 
     /* Data. */
-    struct hmap data;
+    struct subcase src_sc, dst_sc;
+    struct casewriter *sorter;
     struct table_entry **entries;
     size_t n_entries;
 
@@ -201,24 +204,35 @@ struct crosstabs_proc
   };
 
 static bool should_tabulate_case (const struct pivot_table *,
-                                  const struct ccase *, enum mv_class exclude);
-static void tabulate_general_case (struct pivot_table *, const struct ccase *,
-                                   double weight);
-static void tabulate_integer_case (struct pivot_table *, const struct ccase *,
-                                   double weight);
+                                  const struct ccase *, enum mv_class exclude,
+                                  size_t n_splits);
 static void postcalc (struct crosstabs_proc *);
 static void submit (struct pivot_table *, struct tab_table *);
 
+static struct ccase *
+crs_combine_cases (struct ccase *a, struct ccase *b, void *aux UNUSED)
+{
+  size_t weight_idx = caseproto_get_n_widths (case_get_proto (a)) - 1;
+
+  a = case_unshare (a);
+  case_data_rw_idx (a, weight_idx)->f += case_data_idx (b, weight_idx)->f;
+  case_unref (b);
+
+  return a;
+}
+
 /* Parses and executes the CROSSTABS procedure. */
 int
 cmd_crosstabs (struct lexer *lexer, struct dataset *ds)
 {
   const struct variable *wv = dict_get_weight (dataset_dict (ds));
+  const struct dictionary *dict = dataset_dict (ds);
+  size_t n_splits = dict_get_split_cnt (dict);
   struct crosstabs_proc proc;
-  struct casegrouper *grouper;
-  struct casereader *input, *group;
+  struct casereader *input;
   struct cmd_crosstabs cmd;
   struct pivot_table *pt;
+  struct ccase *c;
   int result;
   bool ok;
   int i;
@@ -294,48 +308,132 @@ cmd_crosstabs (struct lexer *lexer, struct dataset *ds)
   /* PIVOT. */
   proc.pivot = cmd.pivot == CRS_PIVOT;
 
-  input = casereader_create_filter_weight (proc_open (ds), dataset_dict (ds),
-                                           NULL, NULL);
-  grouper = casegrouper_create_splits (input, dataset_dict (ds));
-  while (casegrouper_get_next_group (grouper, &group))
+  for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
     {
-      struct ccase *c;
+      struct caseproto *proto;
+      struct subcase sort;
+
+      subcase_init_empty (&pt->src_sc);
+      subcase_add_vars_always (&pt->src_sc, dict_get_split_vars (dict),
+                               n_splits, SC_ASCEND);
+      subcase_add_vars_always (&pt->src_sc, pt->vars, pt->n_vars, SC_ASCEND);
 
-      /* Output SPLIT FILE variables. */
-      c = casereader_peek (group, 0);
-      if (c != NULL)
+      subcase_clone (&pt->dst_sc, &pt->src_sc);
+      subcase_project (&pt->dst_sc, 0);
+
+      subcase_init_empty (&sort);
+      for (i = 0; i < n_splits; i++)
+        subcase_add_always (&sort, i, subcase_get_width (&pt->src_sc, i),
+                            SC_ASCEND);
+      for (i = 0; i < pt->n_vars; i++)
         {
-          output_split_file_values (ds, c);
-          case_unref (c);
+          size_t var_idx = n_splits + (i == pt->n_vars - 2 ? ROW_VAR
+                                       : i == pt->n_vars - 1 ? COL_VAR
+                                       : i + 2);
+          subcase_add_always (&sort, var_idx,
+                              subcase_get_width (&pt->src_sc, var_idx),
+                              SC_ASCEND);
         }
 
-      /* Initialize hash tables. */
+      proto = caseproto_ref (subcase_get_proto (&pt->dst_sc));
+      proto = caseproto_add_width (proto, 0);
+      pt->sorter = sort_distinct_create_writer (&sort, proto,
+                                                crs_combine_cases, NULL, NULL);
+      caseproto_unref (proto);
+    }
+
+  input = casereader_create_filter_weight (proc_open (ds), dict, NULL, NULL);
+  for (; (c = casereader_read (input)) != NULL; case_unref (c))
+    {
       for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
-        hmap_init (&pt->data);
+        {
+          const struct caseproto *proto = casewriter_get_proto (pt->sorter);
+          struct ccase *pt_case;
 
-      /* Tabulate. */
-      for (; (c = casereader_read (group)) != NULL; case_unref (c))
-        for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
-          {
-            double weight = dict_get_case_weight (dataset_dict (ds), c,
-                                                  &proc.bad_warn);
-            if (should_tabulate_case (pt, c, proc.exclude))
+          pt_case = case_create (proto);
+
+          subcase_copy (&pt->src_sc, c, &pt->dst_sc, pt_case);
+          if (should_tabulate_case (pt, pt_case, proc.exclude, n_splits)
+              && proc.mode == INTEGER)
+            {
+              for (i = 0; i < pt->n_vars; i++)
+                {
+                  double *d = &case_data_rw_idx (pt_case, i + n_splits)->f;
+                  *d = (int) *d;
+                }
+            }
+
+          case_data_rw_idx (pt_case, caseproto_get_n_widths (proto) - 1)->f
+             = dict_get_case_weight (dict, c, &proc.bad_warn);
+
+          casewriter_write (pt->sorter, pt_case);
+        }
+    }
+  ok = casereader_destroy (input);
+  ok = proc_commit (ds) && ok;
+
+  for (pt = &proc.pivots[0]; pt < &proc.pivots[proc.n_pivots]; pt++)
+    {
+      const struct caseproto *proto;
+      struct casegrouper *grouper;
+      struct casereader *data;
+      struct subcase group_sc;
+      struct casereader *group;
+
+      subcase_init_vars (&group_sc, dict_get_split_vars (dict),
+                         dict_get_split_cnt (dict));
+      subcase_project (&group_sc, 0);
+
+      data = casewriter_make_reader (pt->sorter);
+      proto = casereader_get_proto (data);
+      grouper = casegrouper_create_subcase (data, &group_sc);
+      subcase_destroy (&group_sc);
+
+      for (; casegrouper_get_next_group (grouper, &group);
+           casereader_destroy (group))
+        {
+          casenumber n_entries;
+          casenumber i;
+
+          c = casereader_peek (group, 0);
+          if (c != NULL)
+            {
+              /* XXX output_split_file_values (ds, c); */
+              case_unref (c);
+            }
+
+          n_entries = casereader_count_cases (group);
+          if (n_entries > 1000000)
+            {
+              msg (SW, _("Omitting analysis of crosstabulation that has %lu "
+                         "nonempty cells."),
+                   (unsigned long int) n_entries);
+              continue;
+            }
+
+          pt->entries = xmalloc (n_entries * sizeof *pt->entries);
+          pt->n_entries = 0;
+          pt->missing = 0.0;
+          for (; (c = casereader_read (group)) != NULL; case_unref (c))
+            if (should_tabulate_case (pt, c, proc.exclude, n_splits))
               {
-                if (proc.mode == GENERAL)
-                  tabulate_general_case (pt, c, weight);
-                else
-                  tabulate_integer_case (pt, c, weight);
+                struct table_entry *e;
+
+                e = xmalloc (table_entry_size (pt->n_vars));
+                for (i = 0; i < pt->n_vars; i++)
+                  value_clone (&e->values[i], case_data_idx (c, i + n_splits),
+                               caseproto_get_width (proto, i + n_splits));
+                e->freq = case_num_idx (c, pt->n_vars + n_splits);
+
+                pt->entries[pt->n_entries++] = e;
               }
             else
-              pt->missing += weight;
-          }
-      casereader_destroy (group);
+              pt->missing += case_num_idx (c, pt->n_vars + n_splits);
 
-      /* Output. */
-      postcalc (&proc);
+          postcalc (&proc);
+        }
+      ok = casegrouper_destroy (grouper) && ok;
     }
-  ok = casegrouper_destroy (grouper);
-  ok = proc_commit (ds) && ok;
 
   result = ok ? CMD_SUCCESS : CMD_CASCADING_FAILURE;
 
@@ -540,105 +638,27 @@ crs_custom_variables (struct lexer *lexer, struct dataset *ds,
 
 static bool
 should_tabulate_case (const struct pivot_table *pt, const struct ccase *c,
-                      enum mv_class exclude)
+                      enum mv_class exclude, size_t n_splits)
 {
   int j;
   for (j = 0; j < pt->n_vars; j++)
     {
       const struct variable *var = pt->vars[j];
       struct var_range *range = get_var_range (var);
+      const union value *value = case_data_idx (c, j + n_splits);
 
-      if (var_is_value_missing (var, case_data (c, var), exclude))
+      if (var_is_value_missing (var, value, exclude))
         return false;
 
       if (range != NULL)
         {
-          double num = case_num (c, var);
+          double num = value->f;
           if (num < range->min || num > range->max)
             return false;
         }
     }
   return true;
 }
-
-static void
-tabulate_integer_case (struct pivot_table *pt, const struct ccase *c,
-                       double weight)
-{
-  struct table_entry *te;
-  size_t hash;
-  int j;
-
-  hash = 0;
-  for (j = 0; j < pt->n_vars; j++)
-    {
-      /* Throw away fractional parts of values. */
-      hash = hash_int (case_num (c, pt->vars[j]), hash);
-    }
-
-  HMAP_FOR_EACH_WITH_HASH (te, struct table_entry, node, hash, &pt->data)
-    {
-      for (j = 0; j < pt->n_vars; j++)
-        if ((int) case_num (c, pt->vars[j]) != (int) te->values[j].f)
-          goto no_match;
-
-      /* Found an existing entry. */
-      te->freq += weight;
-      return;
-
-    no_match: ;
-    }
-
-  /* No existing entry.  Create a new one. */
-  te = xmalloc (table_entry_size (pt->n_vars));
-  te->freq = weight;
-  for (j = 0; j < pt->n_vars; j++)
-    te->values[j].f = (int) case_num (c, pt->vars[j]);
-  hmap_insert (&pt->data, &te->node, hash);
-}
-
-static void
-tabulate_general_case (struct pivot_table *pt, const struct ccase *c,
-                       double weight)
-{
-  struct table_entry *te;
-  size_t hash;
-  int j;
-
-  hash = 0;
-  for (j = 0; j < pt->n_vars; j++)
-    {
-      const struct variable *var = pt->vars[j];
-      hash = value_hash (case_data (c, var), var_get_width (var), hash);
-    }
-
-  HMAP_FOR_EACH_WITH_HASH (te, struct table_entry, node, hash, &pt->data)
-    {
-      for (j = 0; j < pt->n_vars; j++)
-        {
-          const struct variable *var = pt->vars[j];
-          if (!value_equal (case_data (c, var), &te->values[j],
-                            var_get_width (var)))
-            goto no_match;
-        }
-
-      /* Found an existing entry. */
-      te->freq += weight;
-      return;
-
-    no_match: ;
-    }
-
-  /* No existing entry.  Create a new one. */
-  te = xmalloc (table_entry_size (pt->n_vars));
-  te->freq = weight;
-  for (j = 0; j < pt->n_vars; j++)
-    {
-      const struct variable *var = pt->vars[j];
-      value_clone (&te->values[j], case_data (c, var), var_get_width (var));
-    }
-  hmap_insert (&pt->data, &te->node, hash);
-}
 \f
 /* Post-data reading calculations. */
 
@@ -646,8 +666,6 @@ static int compare_table_entry_vars_3way (const struct table_entry *a,
                                           const struct table_entry *b,
                                           const struct pivot_table *pt,
                                           int idx0, int idx1);
-static int compare_table_entry_3way (const void *ap_, const void *bp_,
-                                     const void *pt_);
 static void enum_var_values (const struct pivot_table *, int var_idx,
                              union value **valuesp, int *n_values);
 static void output_pivot_table (struct crosstabs_proc *,
@@ -663,23 +681,6 @@ postcalc (struct crosstabs_proc *proc)
 {
   struct pivot_table *pt;
 
-  /* Convert hash tables into sorted arrays of entries. */
-  for (pt = &proc->pivots[0]; pt < &proc->pivots[proc->n_pivots]; pt++)
-    {
-      struct table_entry *e;
-      size_t i;
-
-      pt->n_entries = hmap_count (&pt->data);
-      pt->entries = xnmalloc (pt->n_entries, sizeof *pt->entries);
-      i = 0;
-      HMAP_FOR_EACH (e, struct table_entry, node, &pt->data)
-        pt->entries[i++] = e;
-      hmap_destroy (&pt->data);
-
-      sort (pt->entries, pt->n_entries, sizeof *pt->entries,
-            compare_table_entry_3way, pt);
-    }
-
   make_summary_table (proc);
 
   /* Output each pivot table. */
@@ -704,8 +705,6 @@ postcalc (struct crosstabs_proc *proc)
     {
       size_t i;
 
-      pt->missing = 0.0;
-
       /* Free only the members that were allocated in this
          function.  The other pointer members are either both
          allocated and destroyed at a lower level (in
@@ -764,29 +763,6 @@ compare_table_entry_vars_3way (const struct table_entry *a,
   return 0;
 }
 
-/* Compare the struct table_entry at *AP to the one at *BP and
-   return a strcmp()-type result. */
-static int
-compare_table_entry_3way (const void *ap_, const void *bp_, const void *pt_)
-{
-  const struct table_entry *const *ap = ap_;
-  const struct table_entry *const *bp = bp_;
-  const struct table_entry *a = *ap;
-  const struct table_entry *b = *bp;
-  const struct pivot_table *pt = pt_;
-  int cmp;
-
-  cmp = compare_table_entry_vars_3way (a, b, pt, 2, pt->n_vars);
-  if (cmp != 0)
-    return cmp;
-
-  cmp = compare_table_entry_var_3way (a, b, pt, ROW_VAR);
-  if (cmp != 0)
-    return cmp;
-
-  return compare_table_entry_var_3way (a, b, pt, COL_VAR);
-}
-
 static int
 find_first_difference (const struct pivot_table *pt, size_t row)
 {