MATRIX: Factor out write_file, like read_file.
authorBen Pfaff <blp@cs.stanford.edu>
Sat, 6 Nov 2021 05:40:28 +0000 (22:40 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 6 Nov 2021 05:40:28 +0000 (22:40 -0700)
src/language/stats/matrix.c

index ca14562a9c797510ddf9307d9674ba8f9c2b2954..07179cefaf4bb97e1bb247a8c0d6d97cb268a654 100644 (file)
@@ -96,6 +96,13 @@ struct read_file
     char *encoding;
   };
 
+struct write_file
+  {
+    struct file_handle *file;
+    struct dfm_writer *writer;
+    char *encoding;
+  };
+
 struct matrix_state
   {
     struct dataset *dataset;
@@ -104,12 +111,15 @@ struct matrix_state
     struct hmap vars;
     bool in_loop;
     struct file_handle *prev_save_outfile;
-    struct file_handle *prev_write_outfile;
     struct msave_common *common;
 
     struct file_handle *prev_read_file;
     struct read_file **read_files;
     size_t n_read_files;
+
+    struct file_handle *prev_write_file;
+    struct write_file **write_files;
+    size_t n_write_files;
   };
 
 static struct matrix_var *
@@ -1691,6 +1701,48 @@ read_file_destroy (struct read_file *rf)
     }
 }
 
+static struct write_file *
+write_file_create (struct matrix_state *s, struct file_handle *fh)
+{
+  for (size_t i = 0; i < s->n_write_files; i++)
+    {
+      struct write_file *wf = s->write_files[i];
+      if (wf->file == fh)
+        {
+          fh_unref (fh);
+          return wf;
+        }
+    }
+
+  struct write_file *wf = xmalloc (sizeof *wf);
+  *wf = (struct write_file) { .file = fh };
+
+  s->write_files = xrealloc (s->write_files,
+                             (s->n_write_files + 1) * sizeof *s->write_files);
+  s->write_files[s->n_write_files++] = wf;
+  return wf;
+}
+
+static struct dfm_writer *
+write_file_open (struct write_file *wf)
+{
+  if (!wf->writer)
+    wf->writer = dfm_open_writer (wf->file, wf->encoding);
+  return wf->writer;
+}
+
+static void
+write_file_destroy (struct write_file *wf)
+{
+  if (wf)
+    {
+      fh_unref (wf->file);
+      dfm_close_writer (wf->writer);
+      free (wf->encoding);
+      free (wf);
+    }
+}
+
 static bool
 matrix_parse_function (struct matrix_state *s, const char *token,
                        struct matrix_expr **exprp)
@@ -3352,9 +3404,8 @@ struct matrix_cmd
 
         struct write_command
           {
+            struct write_file *wf;
             struct matrix_expr *expression;
-            struct file_handle *outfile;
-            char *encoding;
             int c1, c2;
             enum fmt_type format;
             int w;
@@ -4660,6 +4711,7 @@ matrix_parse_read (struct matrix_state *s)
   s->prev_read_file = fh_ref (fh);
 
   read->rf = read_file_create (s, fh);
+  fh = NULL;
   if (encoding)
     {
       free (read->rf->encoding);
@@ -4922,6 +4974,8 @@ matrix_parse_write (struct matrix_state *s)
     .write = { .format = FMT_F },
   };
 
+  struct file_handle *fh = NULL;
+  char *encoding = NULL;
   struct write_command *write = &cmd->write;
   write->expression = matrix_parse_exp (s);
   if (!write->expression)
@@ -4937,9 +4991,9 @@ matrix_parse_write (struct matrix_state *s)
         {
           lex_match (s->lexer, T_EQUALS);
 
-          fh_unref (write->outfile);
-          write->outfile = fh_parse (s->lexer, FH_REF_FILE, NULL);
-          if (!write->outfile)
+          fh_unref (fh);
+          fh = fh_parse (s->lexer, FH_REF_FILE, NULL);
+          if (!fh)
             goto error;
         }
       else if (lex_match_id (s->lexer, "ENCODING"))
@@ -4948,8 +5002,8 @@ matrix_parse_write (struct matrix_state *s)
          if (!lex_force_string (s->lexer))
            goto error;
 
-          free (write->encoding);
-          write->encoding = ss_xstrdup (lex_tokss (s->lexer));
+          free (encoding);
+          encoding = ss_xstrdup (lex_tokss (s->lexer));
 
          lex_get (s->lexer);
        }
@@ -5051,18 +5105,27 @@ matrix_parse_write (struct matrix_state *s)
       goto error;
     }
 
-  if (!write->outfile)
+  if (!fh)
     {
-      if (s->prev_write_outfile)
-        write->outfile = fh_ref (s->prev_write_outfile);
+      if (s->prev_write_file)
+        fh = fh_ref (s->prev_write_file);
       else
         {
           lex_sbc_missing ("OUTFILE");
           goto error;
         }
     }
-  fh_unref (s->prev_write_outfile);
-  s->prev_write_outfile = fh_ref (write->outfile);
+  fh_unref (s->prev_write_file);
+  s->prev_write_file = fh_ref (fh);
+
+  write->wf = write_file_create (s, fh);
+  fh = NULL;
+  if (encoding)
+    {
+      free (write->wf->encoding);
+      write->wf->encoding = encoding;
+      encoding = NULL;
+    }
 
   /* Field width may be specified in multiple ways:
 
@@ -5100,6 +5163,7 @@ matrix_parse_write (struct matrix_state *s)
   return cmd;
 
 error:
+  fh_unref (fh);
   matrix_cmd_destroy (cmd);
   return NULL;
 }
@@ -5120,7 +5184,7 @@ matrix_cmd_execute_write (struct write_command *write)
       return;
     }
 
-  struct dfm_writer *writer = dfm_open_writer (write->outfile, write->encoding);
+  struct dfm_writer *writer = write_file_open (write->wf);
   if (!writer)
     {
       gsl_matrix_free (m);
@@ -6466,7 +6530,6 @@ matrix_cmd_destroy (struct matrix_cmd *cmd)
 
     case MCMD_WRITE:
       matrix_expr_destroy (cmd->write.expression);
-      free (cmd->write.encoding);
       break;
 
     case MCMD_GET:
@@ -6610,7 +6673,6 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
     }
   hmap_destroy (&state.vars);
   fh_unref (state.prev_save_outfile);
-  fh_unref (state.prev_write_outfile);
   if (state.common)
     {
       dict_unref (state.common->dict);
@@ -6621,6 +6683,10 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
   for (size_t i = 0; i < state.n_read_files; i++)
     read_file_destroy (state.read_files[i]);
   free (state.read_files);
+  fh_unref (state.prev_write_file);
+  for (size_t i = 0; i < state.n_write_files; i++)
+    write_file_destroy (state.write_files[i]);
+  free (state.write_files);
 
   return CMD_SUCCESS;
 }