matrix-data: Only use as many bytes as necessary to initialize string.
[pspp] / src / language / data-io / matrix-data.c
index 7855caf4646d37d71091f9461e6d6971bb0ef155..81e687983c4110aa290d231944f0960d64916119 100644 (file)
@@ -34,6 +34,7 @@
 #include "language/lexer/variable-parser.h"
 #include "libpspp/i18n.h"
 #include "libpspp/message.h"
+#include "libpspp/misc.h"
 
 #include "gl/xsize.h"
 #include "gl/xalloc.h"
@@ -65,6 +66,8 @@ enum triangle
     FULL
   };
 
+static const int ROWTYPE_WIDTH = 8;
+
 struct matrix_format
 {
   enum triangle triangle;
@@ -72,6 +75,9 @@ struct matrix_format
   const struct variable *rowtype;
   const struct variable *varname;
   int n_continuous_vars;
+  struct variable **split_vars;
+  size_t n_split_vars;
+  long n;
 };
 
 /*
@@ -92,58 +98,105 @@ valid rowtype_ values:
   PROX.
 */
 
-/* Sets the value of OUTCASE which corresponds to MFORMAT's varname variable
-   to the string STR. VAR must be of type string.
+/* Sets the value of OUTCASE which corresponds to VNAME
+   to the value STR.  VNAME must be of type string.
  */
 static void
-set_varname_column (struct ccase *outcase, const struct matrix_format *mformat,
-     const char *str, int len)
+set_varname_column (struct ccase *outcase, const struct variable *vname,
+     const char *str)
 {
-  const struct variable *var = mformat->varname;
-  uint8_t *s = value_str_rw (case_data_rw (outcase, var), len);
+  int len = var_get_width (vname);
+  uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
 
   strncpy ((char *) s, str, len);
 }
 
+static void
+blank_varname_column (struct ccase *outcase, const struct variable *vname)
+{
+  int len = var_get_width (vname);
+  uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
+
+  memset (s, ' ', len);
+}
 
 static struct casereader *
 preprocess (struct casereader *casereader0, const struct dictionary *dict, void *aux)
 {
   struct matrix_format *mformat = aux;
   const struct caseproto *proto = casereader_get_proto (casereader0);
-  struct casewriter *writer;
-  writer = autopaging_writer_create (proto);
+  struct casewriter *writer = autopaging_writer_create (proto);
+  struct ccase *prev_case = NULL;
+  double **matrices = NULL;
+  size_t n_splits = 0;
+
+  const size_t sizeof_matrix =
+    sizeof (double) * mformat->n_continuous_vars * mformat->n_continuous_vars;
 
-  double *temp_matrix =
-    xcalloc (sizeof (*temp_matrix),
-            mformat->n_continuous_vars * mformat->n_continuous_vars);
 
   /* Make an initial pass to populate our temporary matrix */
   struct casereader *pass0 = casereader_clone (casereader0);
   struct ccase *c;
+  unsigned int prev_split_hash = 1;
   int row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ? 1 : 0;
   for (; (c = casereader_read (pass0)) != NULL; case_unref (c))
     {
+      int s;
+      unsigned int split_hash = 0;
+      for (s = 0; s < mformat->n_split_vars; ++s)
+       {
+         const struct variable *svar = mformat->split_vars[s];
+         const union value *sv = case_data (c, svar);
+         split_hash = value_hash (sv, var_get_width (svar), split_hash);
+       }
+
+      if (matrices == NULL || prev_split_hash != split_hash)
+       {
+         row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ?
+           1 : 0;
+
+         n_splits++;
+         matrices = xrealloc (matrices, sizeof (double*)  * n_splits);
+         matrices[n_splits - 1] = xmalloc (sizeof_matrix);
+       }
+
+      prev_split_hash = split_hash;
+
       int c_offset = (mformat->triangle == UPPER) ? row : 0;
       if (mformat->triangle == UPPER && mformat->diagonal == NO_DIAGONAL)
        c_offset++;
       const union value *v = case_data (c, mformat->rowtype);
-      const char *val = (const char *) value_str (v, 8);
-      if (0 == strncasecmp (val, "corr    ", 8) ||
-         0 == strncasecmp (val, "cov     ", 8))
+      const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
+      if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
+         0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
        {
+         if (row >= mformat->n_continuous_vars)
+           {
+             msg (SE,
+                  _("There are %d variable declared but the data has at least %d matrix rows."),
+                  mformat->n_continuous_vars, row + 1);
+             case_unref (c);
+             casereader_destroy (pass0);
+             goto error;
+           }
          int col;
          for (col = c_offset; col < mformat->n_continuous_vars; ++col)
            {
              const struct variable *var =
                dict_get_var (dict,
-                             1 + col - c_offset + var_get_dict_index (mformat->varname));
+                             1 + col - c_offset +
+                             var_get_dict_index (mformat->varname));
 
              double e = case_data (c, var)->f;
              if (e == SYSMIS)
                continue;
-             temp_matrix [col + mformat->n_continuous_vars * row] = e;
-             temp_matrix [row + mformat->n_continuous_vars * col] = e;
+
+             /* Fill in the lower triangle */
+             (matrices[n_splits-1])[col + mformat->n_continuous_vars * row] = e;
+
+             if (mformat->triangle != FULL)
+               /* Fill in the upper triangle */
+               (matrices[n_splits-1]) [row + mformat->n_continuous_vars * col] = e;
            }
          row++;
        }
@@ -154,28 +207,75 @@ preprocess (struct casereader *casereader0, const struct dictionary *dict, void
      temporary matrix */
   const int idx = var_get_dict_index (mformat->varname);
   row = 0;
-  struct ccase *prev_case = NULL;
+
+  if (mformat->n >= 0)
+    {
+      int col;
+      struct ccase *outcase = case_create (proto);
+      union value *v = case_data_rw (outcase, mformat->rowtype);
+      uint8_t *n = value_str_rw (v, ROWTYPE_WIDTH);
+      memcpy (n, "N       ", ROWTYPE_WIDTH);
+      blank_varname_column (outcase, mformat->varname);
+      for (col = 0; col < mformat->n_continuous_vars; ++col)
+       {
+         union value *dest_val =
+           case_data_rw_idx (outcase,
+                             1 + col + var_get_dict_index (mformat->varname));
+         dest_val->f = mformat->n;
+       }
+      casewriter_write (writer, outcase);
+    }
+
+  prev_split_hash = 1;
+  n_splits = 0;
   for (; (c = casereader_read (casereader0)) != NULL; prev_case = c)
     {
+      int s;
+      unsigned int split_hash = 0;
+      for (s = 0; s < mformat->n_split_vars; ++s)
+       {
+         const struct variable *svar = mformat->split_vars[s];
+         const union value *sv = case_data (c, svar);
+         split_hash = value_hash (sv, var_get_width (svar), split_hash);
+       }
+      if (prev_split_hash != split_hash)
+       {
+         n_splits++;
+         row = 0;
+       }
+
+      prev_split_hash = split_hash;
       case_unref (prev_case);
+      const union value *v = case_data (c, mformat->rowtype);
+      const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
+      if (mformat->n >= 0)
+       {
+         if (0 == strncasecmp (val, "n       ", ROWTYPE_WIDTH) ||
+             0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
+           {
+             msg (SW,
+                  _("The N subcommand was specified, but a N record was also found in the data.  The N record will be ignored."));
+             continue;
+           }
+       }
+
       struct ccase *outcase = case_create (proto);
       case_copy (outcase, 0, c, 0, caseproto_get_n_widths (proto));
-      const union value *v = case_data (c, mformat->rowtype);
-      const char *val = (const char *) value_str (v, 8);
-      if (0 == strncasecmp (val, "corr    ", 8) ||
-         0 == strncasecmp (val, "cov     ", 8))
+
+      if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
+         0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
        {
          int col;
          const struct variable *var = dict_get_var (dict, idx + 1 + row);
-         set_varname_column (outcase, mformat, var_get_name (var), 8);
-         value_copy (case_data_rw (outcase, mformat->rowtype), v, 8);
+         set_varname_column (outcase, mformat->varname, var_get_name (var));
+         value_copy (case_data_rw (outcase, mformat->rowtype), v, ROWTYPE_WIDTH);
 
          for (col = 0; col < mformat->n_continuous_vars; ++col)
            {
              union value *dest_val =
                case_data_rw_idx (outcase,
                                  1 + col + var_get_dict_index (mformat->varname));
-             dest_val->f = temp_matrix [col + mformat->n_continuous_vars * row];
+             dest_val->f = (matrices[n_splits - 1])[col + mformat->n_continuous_vars * row];
              if (col == row && mformat->diagonal == NO_DIAGONAL)
                dest_val->f = 1.0;
            }
@@ -183,18 +283,18 @@ preprocess (struct casereader *casereader0, const struct dictionary *dict, void
        }
       else
        {
-         set_varname_column (outcase, mformat, "        ", 8);
+         blank_varname_column (outcase, mformat->varname);
        }
 
       /* Special case for SD and N_VECTOR: Rewrite as STDDEV and N respectively */
-      if (0 == strncasecmp (val, "sd      ", 8))
+      if (0 == strncasecmp (val, "sd      ", ROWTYPE_WIDTH))
        {
-         value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), 8,
+         value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
                               (uint8_t *) "STDDEV", 6, ' ');
        }
-      else if (0 == strncasecmp (val, "n_vector", 8))
+      else if (0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
        {
-         value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), 8,
+         value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
                               (uint8_t *) "N", 1, ' ');
        }
 
@@ -210,16 +310,15 @@ preprocess (struct casereader *casereader0, const struct dictionary *dict, void
       if (prev_case)
        case_copy (outcase, 0, prev_case, 0, caseproto_get_n_widths (proto));
 
-
       const struct variable *var = dict_get_var (dict, idx + 1 + row);
-      set_varname_column (outcase, mformat, var_get_name (var), 8);
+      set_varname_column (outcase, mformat->varname, var_get_name (var));
 
       for (col = 0; col < mformat->n_continuous_vars; ++col)
        {
          union value *dest_val =
            case_data_rw_idx (outcase, 1 + col +
                              var_get_dict_index (mformat->varname));
-         dest_val->f = temp_matrix [col + mformat->n_continuous_vars * row];
+         dest_val->f = (matrices[n_splits - 1]) [col + mformat->n_continuous_vars * row];
          if (col == row && mformat->diagonal == NO_DIAGONAL)
            dest_val->f = 1.0;
        }
@@ -227,13 +326,29 @@ preprocess (struct casereader *casereader0, const struct dictionary *dict, void
       casewriter_write (writer, outcase);
     }
 
+
   if (prev_case)
     case_unref (prev_case);
 
-  free (temp_matrix);
+  int i;
+  for (i = 0 ; i < n_splits; ++i)
+    free (matrices[i]);
+  free (matrices);
   struct casereader *reader1 = casewriter_make_reader (writer);
   casereader_destroy (casereader0);
   return reader1;
+
+
+error:
+  if (prev_case)
+    case_unref (prev_case);
+
+  for (i = 0 ; i < n_splits; ++i)
+    free (matrices[i]);
+  free (matrices);
+  casereader_destroy (casereader0);
+  casewriter_destroy (writer);
+  return NULL;
 }
 
 int
@@ -251,6 +366,9 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
 
   mformat.triangle = LOWER;
   mformat.diagonal = DIAGONAL;
+  mformat.n_split_vars = 0;
+  mformat.split_vars = NULL;
+  mformat.n = -1;
 
   dict = (in_input_program ()
           ? dataset_dict (ds)
@@ -262,17 +380,17 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
   data_parser_set_warn_missing_fields (parser, false);
   data_parser_set_span (parser, false);
 
-  mformat.rowtype = dict_create_var (dict, "ROWTYPE_", 8);
-  mformat.varname = dict_create_var (dict, "VARNAME_", 8);
+  mformat.rowtype = dict_create_var (dict, "ROWTYPE_", ROWTYPE_WIDTH);
 
   mformat.n_continuous_vars = 0;
+  mformat.n_split_vars = 0;
 
   if (! lex_force_match_id (lexer, "VARIABLES"))
     goto error;
 
   lex_match (lexer, T_EQUALS);
 
-  if (! parse_mixed_vars (lexer, dict, &names, &n_names, 0))
+  if (! parse_mixed_vars (lexer, dict, &names, &n_names, PV_NO_DUPLICATE))
     {
       int i;
       for (i = 0; i < n_names; ++i)
@@ -281,6 +399,15 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
       goto error;
     }
 
+  int longest_name = 0;
+  for (i = 0; i < n_names; ++i)
+    {
+      maximize_int (&longest_name, strlen (names[i]));
+    }
+
+  mformat.varname = dict_create_var (dict, "VARNAME_",
+                                    8 * DIV_RND_UP (longest_name, 8));
+
   for (i = 0; i < n_names; ++i)
     {
       if (0 == strcasecmp (names[i], "ROWTYPE_"))
@@ -312,7 +439,22 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
       if (! lex_force_match (lexer, T_SLASH))
        goto error;
 
-      if (lex_match_id (lexer, "FORMAT"))
+      if (lex_match_id (lexer, "N"))
+       {
+         lex_match (lexer, T_EQUALS);
+
+         if (! lex_force_int (lexer))
+           goto error;
+
+         mformat.n = lex_integer (lexer);
+         if (mformat.n < 0)
+           {
+             msg (SE, _("%s must not be negative."), "N");
+             goto error;
+           }
+         lex_get (lexer);
+       }
+      else if (lex_match_id (lexer, "FORMAT"))
        {
          lex_match (lexer, T_EQUALS);
 
@@ -364,22 +506,19 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
       else if (lex_match_id (lexer, "SPLIT"))
        {
          lex_match (lexer, T_EQUALS);
-         struct variable **split_vars = NULL;
-         size_t n_split_vars;
-         if (! parse_variables (lexer, dict, &split_vars, &n_split_vars, 0))
+         if (! parse_variables (lexer, dict, &mformat.split_vars, &mformat.n_split_vars, 0))
            {
-             free (split_vars);
+             free (mformat.split_vars);
              goto error;
            }
          int i;
-         for (i = 0; i < n_split_vars; ++i)
+         for (i = 0; i < mformat.n_split_vars; ++i)
            {
              const struct fmt_spec fmt = fmt_for_input (FMT_F, 4, 0);
-             var_set_both_formats (split_vars[i], &fmt);
+             var_set_both_formats (mformat.split_vars[i], &fmt);
            }
-         dict_reorder_vars (dict, split_vars, n_split_vars);
-         mformat.n_continuous_vars -= n_split_vars;
-         free (split_vars);
+         dict_reorder_vars (dict, mformat.split_vars, mformat.n_split_vars);
+         mformat.n_continuous_vars -= mformat.n_split_vars;
        }
       else
        {
@@ -421,20 +560,23 @@ cmd_matrix (struct lexer *lexer, struct dataset *ds)
     }
   else
     {
-      data_parser_make_active_file (parser, ds, reader, dict, preprocess, &mformat);
+      data_parser_make_active_file (parser, ds, reader, dict, preprocess,
+                                   &mformat);
     }
 
   fh_unref (fh);
   free (encoding);
+  free (mformat.split_vars);
 
   return CMD_DATA_LIST;
 
  error:
   data_parser_destroy (parser);
   if (!in_input_program ())
-    dict_destroy (dict);
+    dict_unref (dict);
   fh_unref (fh);
   free (encoding);
+  free (mformat.split_vars);
   return CMD_CASCADING_FAILURE;
 }