caseinit: Introduce new caseinit_translate_casereader_to_init_vars().
authorBen Pfaff <blp@cs.stanford.edu>
Thu, 2 Mar 2023 22:25:30 +0000 (14:25 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sun, 5 Mar 2023 18:50:42 +0000 (10:50 -0800)
src/data/case.c
src/data/caseinit.c
src/data/caseinit.h
src/data/caseproto.c
src/data/caseproto.h
src/data/casereader.c
src/data/casewriter.c
src/data/dataset.c
src/language/commands/inpt-pgm.c

index 742d1f0f22ca46e6743f0c0a6d808b6fd28250a6..a70179f855bf9114df1c8eb667ac630f892e376b 100644 (file)
@@ -201,8 +201,8 @@ case_copy (struct ccase *dst, size_t dst_idx,
   assert (!case_is_shared (dst));
   assert (caseproto_range_is_valid (dst->proto, dst_idx, n_values));
   assert (caseproto_range_is_valid (src->proto, src_idx, n_values));
-  assert (caseproto_equal (dst->proto, dst_idx, src->proto, src_idx,
-                           n_values));
+  assert (caseproto_range_equal (dst->proto, dst_idx, src->proto, src_idx,
+                                 n_values));
 
   if (dst != src)
     {
index 815041fd1e6cb4453356f6248a7030c8503a9bd6..3efe128efd205c7f618e712a3d6386d5e127c912 100644 (file)
@@ -23,6 +23,7 @@
 #include <string.h>
 
 #include "data/case.h"
+#include "data/casereader.h"
 #include "data/dictionary.h"
 #include "data/value.h"
 #include "data/variable.h"
@@ -66,19 +67,17 @@ init_list_create (struct init_list *list)
 }
 
 /* Initializes NEW as a copy of OLD. */
-static void
-init_list_clone (struct init_list *new, const struct init_list *old)
+static struct init_list
+init_list_clone (const struct init_list *old)
 {
-  size_t i;
-
-  new->values = xmemdup (old->values, old->n * sizeof *old->values);
-  new->n = old->n;
-
-  for (i = 0; i < new->n; i++)
+  struct init_value *values = xmemdup (old->values,
+                                       old->n * sizeof *old->values);
+  for (size_t i = 0; i < old->n; i++)
     {
-      struct init_value *iv = &new->values[i];
+      struct init_value *iv = &values[i];
       value_clone (&iv->value, &iv->value, iv->width);
     }
+  return (struct init_list) { .values = values, .n = old->n };
 }
 
 /* Frees the storage associated with LIST. */
@@ -218,9 +217,11 @@ struct caseinit *
 caseinit_clone (struct caseinit *old)
 {
   struct caseinit *new = xmalloc (sizeof *new);
-  init_list_clone (&new->preinited_values, &old->preinited_values);
-  init_list_clone (&new->reinit_values, &old->reinit_values);
-  init_list_clone (&new->left_values, &old->left_values);
+  *new = (struct caseinit) {
+    .preinited_values = init_list_clone (&old->preinited_values),
+    .reinit_values = init_list_clone (&old->reinit_values),
+    .left_values = init_list_clone (&old->left_values),
+  };
   return new;
 }
 
@@ -272,15 +273,75 @@ void
 caseinit_init_vars (const struct caseinit *ci, struct ccase *c)
 {
   init_list_init (&ci->reinit_values, c);
+}
+
+/* Copies the left vars from CI into C. */
+void
+caseinit_restore_left_vars (struct caseinit *ci, struct ccase *c)
+{
   init_list_init (&ci->left_values, c);
 }
 
-/* Updates the left vars in CI from the data in C, so that the
-   next call to caseinit_init_vars will store those values in the
-   next case. */
+/* Copies the left vars from C into CI. */
 void
-caseinit_update_left_vars (struct caseinit *ci, const struct ccase *c)
+caseinit_save_left_vars (struct caseinit *ci, const struct ccase *c)
 {
   init_list_update (&ci->left_values, c);
 }
+\f
+struct caseinit_translator
+  {
+    struct init_list reinit_values;
+    struct caseproto *proto;
+  };
+
+static struct ccase *
+translate_caseinit (struct ccase *c, void *cit_)
+{
+  const struct caseinit_translator *cit = cit_;
+
+  c = case_unshare_and_resize (c, cit->proto);
+  init_list_init (&cit->reinit_values, c);
+  return c;
+}
+
+static bool
+translate_destroy (void *cit_)
+{
+  struct caseinit_translator *cit = cit_;
+
+  init_list_destroy (&cit->reinit_values);
+  caseproto_unref (cit->proto);
+  free (cit);
+
+  return true;
+}
+
+/* Returns a new casereader that yields each case from R, resized to match
+   OUTPUT_PROTO and initialized from CI as if with caseinit_init_vars().  Takes
+   ownership of R.
+
+   OUTPUT_PROTO must be conformable with R's prototype.  */
+struct casereader *
+caseinit_translate_casereader_to_init_vars (struct caseinit *ci,
+                                            const struct caseproto *output_proto,
+                                            struct casereader *r)
+{
+  assert (caseproto_is_conformable (casereader_get_proto (r), output_proto));
+  if (caseproto_equal (output_proto, casereader_get_proto (r))
+      && ci->reinit_values.n == 0)
+    return casereader_rename (r);
+
+  struct caseinit_translator *cit = xmalloc (sizeof *cit);
+  *cit = (struct caseinit_translator) {
+    .reinit_values = init_list_clone (&ci->reinit_values),
+    .proto = caseproto_ref (output_proto),
+  };
+
+  static const struct casereader_translator_class class = {
+    .translate = translate_caseinit,
+    .destroy = translate_destroy,
+  };
+  return casereader_translate_stateless (r, output_proto, &class, cit);
+}
 
index 9f566218428f45c354e13c09a86c083a2e83fabf..ff79a947fed8e6dac0200400047b60820af17d97 100644 (file)
@@ -31,6 +31,7 @@
 #ifndef DATA_CASEINIT_H
 #define DATA_CASEINIT_H 1
 
+struct caseproto;
 struct dictionary;
 struct ccase;
 
@@ -46,6 +47,12 @@ void caseinit_mark_for_init (struct caseinit *, const struct dictionary *);
 
 /* Initialize data and copy data from case to case. */
 void caseinit_init_vars (const struct caseinit *, struct ccase *);
-void caseinit_update_left_vars (struct caseinit *, const struct ccase *);
+void caseinit_save_left_vars (struct caseinit *, const struct ccase *);
+void caseinit_restore_left_vars (struct caseinit *, struct ccase *);
+
+/* Translate. */
+struct casereader *caseinit_translate_casereader_to_init_vars (
+  struct caseinit *, const struct caseproto *output_proto,
+  struct casereader *);
 
 #endif /* data/caseinit.h */
index f390c3f87cfd113ad1262504da30867181c60c8b..c47b6ae98221930d7a201c946704c1a78cb33f75 100644 (file)
@@ -207,9 +207,9 @@ caseproto_is_conformable (const struct caseproto *a, const struct caseproto *b)
    same as the N widths starting at B_START in B, false if any of
    the corresponding widths differ. */
 bool
-caseproto_equal (const struct caseproto *a, size_t a_start,
-                 const struct caseproto *b, size_t b_start,
-                 size_t n)
+caseproto_range_equal (const struct caseproto *a, size_t a_start,
+                       const struct caseproto *b, size_t b_start,
+                       size_t n)
 {
   size_t i;
 
@@ -221,6 +221,15 @@ caseproto_equal (const struct caseproto *a, size_t a_start,
   return true;
 }
 
+/* Returns true if A and B have the same widths, false otherwise. */
+bool
+caseproto_equal (const struct caseproto *a, const struct caseproto *b)
+{
+  return (a == b ? true
+          : a->n_widths != b->n_widths ? false
+          : caseproto_range_equal (a, 0, b, 0, a->n_widths));
+}
+
 /* Returns true if an array of values that is to be used for
    data of the format specified in PROTO needs to be initialized
    by calling caseproto_init_values, false if that step may be
index e6921888f58baf7e12332810f241f2ac0cc60b42..37c4e19f1126e9b1f13eaada63e9f571da8fcc1f 100644 (file)
@@ -130,9 +130,10 @@ bool caseproto_range_is_valid (const struct caseproto *,
                                size_t ofs, size_t count);
 bool caseproto_is_conformable (const struct caseproto *a,
                                const struct caseproto *b);
-bool caseproto_equal (const struct caseproto *a, size_t a_start,
-                      const struct caseproto *b, size_t b_start,
-                      size_t n);
+bool caseproto_range_equal (const struct caseproto *a, size_t a_start,
+                            const struct caseproto *b, size_t b_start,
+                            size_t n);
+bool caseproto_equal (const struct caseproto *, const struct caseproto *);
 \f
 /* Creation and destruction. */
 
index a410afcbcbfab2c826735d9ef94e4daffdca10f5..9a488385a2509cd558eb6e900ba9f83e773e585c 100644 (file)
@@ -73,8 +73,8 @@ casereader_read (struct casereader *reader)
         {
           size_t n_widths UNUSED = caseproto_get_n_widths (reader->proto);
           assert (case_get_n_values (c) >= n_widths);
-          expensive_assert (caseproto_equal (case_get_proto (c), 0,
-                                             reader->proto, 0, n_widths));
+          expensive_assert (caseproto_range_equal (case_get_proto (c), 0,
+                                                   reader->proto, 0, n_widths));
           return c;
         }
     }
index 768a515e16e7a4bf15cd91c36528f8e7ffcd3d6e..51d5e10bfba4f762750c2b3687e4cefea16c1817 100644 (file)
@@ -52,8 +52,8 @@ casewriter_write (struct casewriter *writer, struct ccase *c)
 {
   size_t n_widths UNUSED = caseproto_get_n_widths (writer->proto);
   assert (case_get_n_values (c) >= n_widths);
-  expensive_assert (caseproto_equal (case_get_proto (c), 0,
-                                     writer->proto, 0, n_widths));
+  expensive_assert (caseproto_range_equal (case_get_proto (c), 0,
+                                           writer->proto, 0, n_widths));
   writer->class->write (writer, writer->aux, c);
 }
 
index a816570c0ed86f777dcac670dfe9560fd3abc78c..6ade4b24c997be69acc68b7e07994a2ca0014fbf 100644 (file)
@@ -448,6 +448,8 @@ proc_open_filtering (struct dataset *ds, bool filter)
   update_last_proc_invocation (ds);
 
   caseinit_mark_for_init (ds->caseinit, ds->dict);
+  ds->source = caseinit_translate_casereader_to_init_vars (
+    ds->caseinit, dict_get_proto (ds->dict), ds->source);
 
   /* Finish up the collection of transformations. */
   add_case_limit_trns (ds);
@@ -549,12 +551,12 @@ proc_casereader_read (struct casereader *reader UNUSED, void *ds_)
       if (c == NULL)
         return NULL;
       c = case_unshare_and_resize (c, dict_get_proto (ds->dict));
-      caseinit_init_vars (ds->caseinit, c);
+      caseinit_restore_left_vars (ds->caseinit, c);
 
       /* Execute permanent transformations.  */
       casenumber case_nr = ds->cases_written + 1;
       retval = trns_chain_execute (&ds->permanent_trns_chain, case_nr, &c);
-      caseinit_update_left_vars (ds->caseinit, c);
+      caseinit_save_left_vars (ds->caseinit, c);
       if (retval != TRNS_CONTINUE)
         continue;
 
index a0e3bab63b3f19d79046ff27d2d83734a3137287..b5bd041d79ffa2d81201113333688743c3080c5b 100644 (file)
@@ -184,6 +184,7 @@ input_program_casereader_read (struct casereader *reader UNUSED, void *inp_)
 
   struct ccase *c = case_create (inp->proto);
   caseinit_init_vars (inp->init, c);
+  caseinit_restore_left_vars (inp->init, c);
 
   for (size_t i = inp->idx < inp->xforms.n ? inp->idx : 0; ; i++)
     {
@@ -191,7 +192,7 @@ input_program_casereader_read (struct casereader *reader UNUSED, void *inp_)
         {
           i = 0;
           c = case_unshare (c);
-          caseinit_update_left_vars (inp->init, c);
+          caseinit_save_left_vars (inp->init, c);
           caseinit_init_vars (inp->init, c);
         }