matrix-data: Only use as many bytes as necessary to initialize string.
[pspp] / src / language / data-io / matrix-data.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2017 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/case.h"
20 #include "data/casereader.h"
21 #include "data/casewriter.h"
22 #include "data/dataset.h"
23 #include "data/dictionary.h"
24 #include "data/format.h"
25 #include "data/transformations.h"
26 #include "data/variable.h"
27 #include "language/command.h"
28 #include "language/data-io/data-parser.h"
29 #include "language/data-io/data-reader.h"
30 #include "language/data-io/file-handle.h"
31 #include "language/data-io/inpt-pgm.h"
32 #include "language/data-io/placement-parser.h"
33 #include "language/lexer/lexer.h"
34 #include "language/lexer/variable-parser.h"
35 #include "libpspp/i18n.h"
36 #include "libpspp/message.h"
37 #include "libpspp/misc.h"
38
39 #include "gl/xsize.h"
40 #include "gl/xalloc.h"
41
42 #include "gettext.h"
43 #define _(msgid) gettext (msgid)
44 \f
45 /* DATA LIST transformation data. */
46 struct data_list_trns
47   {
48     struct data_parser *parser; /* Parser. */
49     struct dfm_reader *reader;  /* Data file reader. */
50     struct variable *end;       /* Variable specified on END subcommand. */
51   };
52
53 static trns_free_func data_list_trns_free;
54 static trns_proc_func data_list_trns_proc;
55
56 enum diagonal
57   {
58     DIAGONAL,
59     NO_DIAGONAL
60   };
61
62 enum triangle
63   {
64     LOWER,
65     UPPER,
66     FULL
67   };
68
69 static const int ROWTYPE_WIDTH = 8;
70
71 struct matrix_format
72 {
73   enum triangle triangle;
74   enum diagonal diagonal;
75   const struct variable *rowtype;
76   const struct variable *varname;
77   int n_continuous_vars;
78   struct variable **split_vars;
79   size_t n_split_vars;
80   long n;
81 };
82
83 /*
84 valid rowtype_ values:
85   CORR,
86   COV,
87   MAT,
88
89
90   MSE,
91   DFE,
92   MEAN,
93   STDDEV (or SD),
94   N_VECTOR (or N),
95   N_SCALAR,
96   N_MATRIX,
97   COUNT,
98   PROX.
99 */
100
101 /* Sets the value of OUTCASE which corresponds to VNAME
102    to the value STR.  VNAME must be of type string.
103  */
104 static void
105 set_varname_column (struct ccase *outcase, const struct variable *vname,
106      const char *str)
107 {
108   int len = var_get_width (vname);
109   uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
110
111   strncpy ((char *) s, str, len);
112 }
113
114 static void
115 blank_varname_column (struct ccase *outcase, const struct variable *vname)
116 {
117   int len = var_get_width (vname);
118   uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
119
120   memset (s, ' ', len);
121 }
122
123 static struct casereader *
124 preprocess (struct casereader *casereader0, const struct dictionary *dict, void *aux)
125 {
126   struct matrix_format *mformat = aux;
127   const struct caseproto *proto = casereader_get_proto (casereader0);
128   struct casewriter *writer = autopaging_writer_create (proto);
129   struct ccase *prev_case = NULL;
130   double **matrices = NULL;
131   size_t n_splits = 0;
132
133   const size_t sizeof_matrix =
134     sizeof (double) * mformat->n_continuous_vars * mformat->n_continuous_vars;
135
136
137   /* Make an initial pass to populate our temporary matrix */
138   struct casereader *pass0 = casereader_clone (casereader0);
139   struct ccase *c;
140   unsigned int prev_split_hash = 1;
141   int row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ? 1 : 0;
142   for (; (c = casereader_read (pass0)) != NULL; case_unref (c))
143     {
144       int s;
145       unsigned int split_hash = 0;
146       for (s = 0; s < mformat->n_split_vars; ++s)
147         {
148           const struct variable *svar = mformat->split_vars[s];
149           const union value *sv = case_data (c, svar);
150           split_hash = value_hash (sv, var_get_width (svar), split_hash);
151         }
152
153       if (matrices == NULL || prev_split_hash != split_hash)
154         {
155           row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ?
156             1 : 0;
157
158           n_splits++;
159           matrices = xrealloc (matrices, sizeof (double*)  * n_splits);
160           matrices[n_splits - 1] = xmalloc (sizeof_matrix);
161         }
162
163       prev_split_hash = split_hash;
164
165       int c_offset = (mformat->triangle == UPPER) ? row : 0;
166       if (mformat->triangle == UPPER && mformat->diagonal == NO_DIAGONAL)
167         c_offset++;
168       const union value *v = case_data (c, mformat->rowtype);
169       const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
170       if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
171           0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
172         {
173           if (row >= mformat->n_continuous_vars)
174             {
175               msg (SE,
176                    _("There are %d variable declared but the data has at least %d matrix rows."),
177                    mformat->n_continuous_vars, row + 1);
178               case_unref (c);
179               casereader_destroy (pass0);
180               goto error;
181             }
182           int col;
183           for (col = c_offset; col < mformat->n_continuous_vars; ++col)
184             {
185               const struct variable *var =
186                 dict_get_var (dict,
187                               1 + col - c_offset +
188                               var_get_dict_index (mformat->varname));
189
190               double e = case_data (c, var)->f;
191               if (e == SYSMIS)
192                 continue;
193
194               /* Fill in the lower triangle */
195               (matrices[n_splits-1])[col + mformat->n_continuous_vars * row] = e;
196
197               if (mformat->triangle != FULL)
198                 /* Fill in the upper triangle */
199                 (matrices[n_splits-1]) [row + mformat->n_continuous_vars * col] = e;
200             }
201           row++;
202         }
203     }
204   casereader_destroy (pass0);
205
206   /* Now make a second pass to fill in the other triangle from our
207      temporary matrix */
208   const int idx = var_get_dict_index (mformat->varname);
209   row = 0;
210
211   if (mformat->n >= 0)
212     {
213       int col;
214       struct ccase *outcase = case_create (proto);
215       union value *v = case_data_rw (outcase, mformat->rowtype);
216       uint8_t *n = value_str_rw (v, ROWTYPE_WIDTH);
217       memcpy (n, "N       ", ROWTYPE_WIDTH);
218       blank_varname_column (outcase, mformat->varname);
219       for (col = 0; col < mformat->n_continuous_vars; ++col)
220         {
221           union value *dest_val =
222             case_data_rw_idx (outcase,
223                               1 + col + var_get_dict_index (mformat->varname));
224           dest_val->f = mformat->n;
225         }
226       casewriter_write (writer, outcase);
227     }
228
229   prev_split_hash = 1;
230   n_splits = 0;
231   for (; (c = casereader_read (casereader0)) != NULL; prev_case = c)
232     {
233       int s;
234       unsigned int split_hash = 0;
235       for (s = 0; s < mformat->n_split_vars; ++s)
236         {
237           const struct variable *svar = mformat->split_vars[s];
238           const union value *sv = case_data (c, svar);
239           split_hash = value_hash (sv, var_get_width (svar), split_hash);
240         }
241       if (prev_split_hash != split_hash)
242         {
243           n_splits++;
244           row = 0;
245         }
246
247       prev_split_hash = split_hash;
248       case_unref (prev_case);
249       const union value *v = case_data (c, mformat->rowtype);
250       const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
251       if (mformat->n >= 0)
252         {
253           if (0 == strncasecmp (val, "n       ", ROWTYPE_WIDTH) ||
254               0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
255             {
256               msg (SW,
257                    _("The N subcommand was specified, but a N record was also found in the data.  The N record will be ignored."));
258               continue;
259             }
260         }
261
262       struct ccase *outcase = case_create (proto);
263       case_copy (outcase, 0, c, 0, caseproto_get_n_widths (proto));
264
265       if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
266           0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
267         {
268           int col;
269           const struct variable *var = dict_get_var (dict, idx + 1 + row);
270           set_varname_column (outcase, mformat->varname, var_get_name (var));
271           value_copy (case_data_rw (outcase, mformat->rowtype), v, ROWTYPE_WIDTH);
272
273           for (col = 0; col < mformat->n_continuous_vars; ++col)
274             {
275               union value *dest_val =
276                 case_data_rw_idx (outcase,
277                                   1 + col + var_get_dict_index (mformat->varname));
278               dest_val->f = (matrices[n_splits - 1])[col + mformat->n_continuous_vars * row];
279               if (col == row && mformat->diagonal == NO_DIAGONAL)
280                 dest_val->f = 1.0;
281             }
282           row++;
283         }
284       else
285         {
286           blank_varname_column (outcase, mformat->varname);
287         }
288
289       /* Special case for SD and N_VECTOR: Rewrite as STDDEV and N respectively */
290       if (0 == strncasecmp (val, "sd      ", ROWTYPE_WIDTH))
291         {
292           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
293                                (uint8_t *) "STDDEV", 6, ' ');
294         }
295       else if (0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
296         {
297           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
298                                (uint8_t *) "N", 1, ' ');
299         }
300
301       casewriter_write (writer, outcase);
302     }
303
304   /* If NODIAGONAL is specified, then a final case must be written */
305   if (mformat->diagonal == NO_DIAGONAL)
306     {
307       int col;
308       struct ccase *outcase = case_create (proto);
309
310       if (prev_case)
311         case_copy (outcase, 0, prev_case, 0, caseproto_get_n_widths (proto));
312
313       const struct variable *var = dict_get_var (dict, idx + 1 + row);
314       set_varname_column (outcase, mformat->varname, var_get_name (var));
315
316       for (col = 0; col < mformat->n_continuous_vars; ++col)
317         {
318           union value *dest_val =
319             case_data_rw_idx (outcase, 1 + col +
320                               var_get_dict_index (mformat->varname));
321           dest_val->f = (matrices[n_splits - 1]) [col + mformat->n_continuous_vars * row];
322           if (col == row && mformat->diagonal == NO_DIAGONAL)
323             dest_val->f = 1.0;
324         }
325
326       casewriter_write (writer, outcase);
327     }
328
329
330   if (prev_case)
331     case_unref (prev_case);
332
333   int i;
334   for (i = 0 ; i < n_splits; ++i)
335     free (matrices[i]);
336   free (matrices);
337   struct casereader *reader1 = casewriter_make_reader (writer);
338   casereader_destroy (casereader0);
339   return reader1;
340
341
342 error:
343   if (prev_case)
344     case_unref (prev_case);
345
346   for (i = 0 ; i < n_splits; ++i)
347     free (matrices[i]);
348   free (matrices);
349   casereader_destroy (casereader0);
350   casewriter_destroy (writer);
351   return NULL;
352 }
353
354 int
355 cmd_matrix (struct lexer *lexer, struct dataset *ds)
356 {
357   struct dictionary *dict;
358   struct data_parser *parser;
359   struct dfm_reader *reader;
360   struct file_handle *fh = NULL;
361   char *encoding = NULL;
362   struct matrix_format mformat;
363   int i;
364   size_t n_names;
365   char **names = NULL;
366
367   mformat.triangle = LOWER;
368   mformat.diagonal = DIAGONAL;
369   mformat.n_split_vars = 0;
370   mformat.split_vars = NULL;
371   mformat.n = -1;
372
373   dict = (in_input_program ()
374           ? dataset_dict (ds)
375           : dict_create (get_default_encoding ()));
376   parser = data_parser_create (dict);
377   reader = NULL;
378
379   data_parser_set_type (parser, DP_DELIMITED);
380   data_parser_set_warn_missing_fields (parser, false);
381   data_parser_set_span (parser, false);
382
383   mformat.rowtype = dict_create_var (dict, "ROWTYPE_", ROWTYPE_WIDTH);
384
385   mformat.n_continuous_vars = 0;
386   mformat.n_split_vars = 0;
387
388   if (! lex_force_match_id (lexer, "VARIABLES"))
389     goto error;
390
391   lex_match (lexer, T_EQUALS);
392
393   if (! parse_mixed_vars (lexer, dict, &names, &n_names, PV_NO_DUPLICATE))
394     {
395       int i;
396       for (i = 0; i < n_names; ++i)
397         free (names[i]);
398       free (names);
399       goto error;
400     }
401
402   int longest_name = 0;
403   for (i = 0; i < n_names; ++i)
404     {
405       maximize_int (&longest_name, strlen (names[i]));
406     }
407
408   mformat.varname = dict_create_var (dict, "VARNAME_",
409                                      8 * DIV_RND_UP (longest_name, 8));
410
411   for (i = 0; i < n_names; ++i)
412     {
413       if (0 == strcasecmp (names[i], "ROWTYPE_"))
414         {
415           const struct fmt_spec fmt = fmt_for_input (FMT_A, 8, 0);
416           data_parser_add_delimited_field (parser,
417                                            &fmt,
418                                            var_get_case_index (mformat.rowtype),
419                                            "ROWTYPE_");
420         }
421       else
422         {
423           const struct fmt_spec fmt = fmt_for_input (FMT_F, 10, 4);
424           struct variable *v = dict_create_var (dict, names[i], 0);
425           var_set_both_formats (v, &fmt);
426           data_parser_add_delimited_field (parser,
427                                            &fmt,
428                                            var_get_case_index (mformat.varname) +
429                                            ++mformat.n_continuous_vars,
430                                            names[i]);
431         }
432     }
433   for (i = 0; i < n_names; ++i)
434     free (names[i]);
435   free (names);
436
437   while (lex_token (lexer) != T_ENDCMD)
438     {
439       if (! lex_force_match (lexer, T_SLASH))
440         goto error;
441
442       if (lex_match_id (lexer, "N"))
443         {
444           lex_match (lexer, T_EQUALS);
445
446           if (! lex_force_int (lexer))
447             goto error;
448
449           mformat.n = lex_integer (lexer);
450           if (mformat.n < 0)
451             {
452               msg (SE, _("%s must not be negative."), "N");
453               goto error;
454             }
455           lex_get (lexer);
456         }
457       else if (lex_match_id (lexer, "FORMAT"))
458         {
459           lex_match (lexer, T_EQUALS);
460
461           while (lex_token (lexer) != T_SLASH && (lex_token (lexer) != T_ENDCMD))
462             {
463               if (lex_match_id (lexer, "LIST"))
464                 {
465                   data_parser_set_span (parser, false);
466                 }
467               else if (lex_match_id (lexer, "FREE"))
468                 {
469                   data_parser_set_span (parser, true);
470                 }
471               else if (lex_match_id (lexer, "UPPER"))
472                 {
473                   mformat.triangle = UPPER;
474                 }
475               else if (lex_match_id (lexer, "LOWER"))
476                 {
477                   mformat.triangle = LOWER;
478                 }
479               else if (lex_match_id (lexer, "FULL"))
480                 {
481                   mformat.triangle = FULL;
482                 }
483               else if (lex_match_id (lexer, "DIAGONAL"))
484                 {
485                   mformat.diagonal = DIAGONAL;
486                 }
487               else if (lex_match_id (lexer, "NODIAGONAL"))
488                 {
489                   mformat.diagonal = NO_DIAGONAL;
490                 }
491               else
492                 {
493                   lex_error (lexer, NULL);
494                   goto error;
495                 }
496             }
497         }
498       else if (lex_match_id (lexer, "FILE"))
499         {
500           lex_match (lexer, T_EQUALS);
501           fh_unref (fh);
502           fh = fh_parse (lexer, FH_REF_FILE | FH_REF_INLINE, NULL);
503           if (fh == NULL)
504             goto error;
505         }
506       else if (lex_match_id (lexer, "SPLIT"))
507         {
508           lex_match (lexer, T_EQUALS);
509           if (! parse_variables (lexer, dict, &mformat.split_vars, &mformat.n_split_vars, 0))
510             {
511               free (mformat.split_vars);
512               goto error;
513             }
514           int i;
515           for (i = 0; i < mformat.n_split_vars; ++i)
516             {
517               const struct fmt_spec fmt = fmt_for_input (FMT_F, 4, 0);
518               var_set_both_formats (mformat.split_vars[i], &fmt);
519             }
520           dict_reorder_vars (dict, mformat.split_vars, mformat.n_split_vars);
521           mformat.n_continuous_vars -= mformat.n_split_vars;
522         }
523       else
524         {
525           lex_error (lexer, NULL);
526           goto error;
527         }
528     }
529
530   if (mformat.diagonal == NO_DIAGONAL && mformat.triangle == FULL)
531     {
532       msg (SE, _("FORMAT = FULL and FORMAT = NODIAGONAL are mutually exclusive."));
533       goto error;
534     }
535
536   if (fh == NULL)
537     fh = fh_inline_file ();
538   fh_set_default_handle (fh);
539
540   if (!data_parser_any_fields (parser))
541     {
542       msg (SE, _("At least one variable must be specified."));
543       goto error;
544     }
545
546   if (lex_end_of_command (lexer) != CMD_SUCCESS)
547     goto error;
548
549   reader = dfm_open_reader (fh, lexer, encoding);
550   if (reader == NULL)
551     goto error;
552
553   if (in_input_program ())
554     {
555       struct data_list_trns *trns = xmalloc (sizeof *trns);
556       trns->parser = parser;
557       trns->reader = reader;
558       trns->end = NULL;
559       add_transformation (ds, data_list_trns_proc, data_list_trns_free, trns);
560     }
561   else
562     {
563       data_parser_make_active_file (parser, ds, reader, dict, preprocess,
564                                     &mformat);
565     }
566
567   fh_unref (fh);
568   free (encoding);
569   free (mformat.split_vars);
570
571   return CMD_DATA_LIST;
572
573  error:
574   data_parser_destroy (parser);
575   if (!in_input_program ())
576     dict_unref (dict);
577   fh_unref (fh);
578   free (encoding);
579   free (mformat.split_vars);
580   return CMD_CASCADING_FAILURE;
581 }
582
583 \f
584 /* Input procedure. */
585
586 /* Destroys DATA LIST transformation TRNS.
587    Returns true if successful, false if an I/O error occurred. */
588 static bool
589 data_list_trns_free (void *trns_)
590 {
591   struct data_list_trns *trns = trns_;
592   data_parser_destroy (trns->parser);
593   dfm_close_reader (trns->reader);
594   free (trns);
595   return true;
596 }
597
598 /* Handle DATA LIST transformation TRNS, parsing data into *C. */
599 static int
600 data_list_trns_proc (void *trns_, struct ccase **c, casenumber case_num UNUSED)
601 {
602   struct data_list_trns *trns = trns_;
603   int retval;
604
605   *c = case_unshare (*c);
606   if (data_parser_parse (trns->parser, trns->reader, *c))
607     retval = TRNS_CONTINUE;
608   else if (dfm_reader_error (trns->reader) || dfm_eof (trns->reader) > 1)
609     {
610       /* An I/O error, or encountering end of file for a second
611          time, should be escalated into a more serious error. */
612       retval = TRNS_ERROR;
613     }
614   else
615     retval = TRNS_END_FILE;
616
617   /* If there was an END subcommand handle it. */
618   if (trns->end != NULL)
619     {
620       double *end = &case_data_rw (*c, trns->end)->f;
621       if (retval == TRNS_END_FILE)
622         {
623           *end = 1.0;
624           retval = TRNS_CONTINUE;
625         }
626       else
627         *end = 0.0;
628     }
629
630   return retval;
631 }
632