Matrix Data: Identify splits explicitly instead of with hashes.
[pspp] / src / language / data-io / matrix-data.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2017 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/case.h"
20 #include "data/casereader.h"
21 #include "data/casewriter.h"
22 #include "data/dataset.h"
23 #include "data/dictionary.h"
24 #include "data/format.h"
25 #include "data/transformations.h"
26 #include "data/variable.h"
27 #include "language/command.h"
28 #include "language/data-io/data-parser.h"
29 #include "language/data-io/data-reader.h"
30 #include "language/data-io/file-handle.h"
31 #include "language/data-io/inpt-pgm.h"
32 #include "language/data-io/placement-parser.h"
33 #include "language/lexer/lexer.h"
34 #include "language/lexer/variable-parser.h"
35 #include "libpspp/i18n.h"
36 #include "libpspp/message.h"
37 #include "libpspp/misc.h"
38
39 #include "gl/xsize.h"
40 #include "gl/xalloc.h"
41
42 #include "gettext.h"
43 #define _(msgid) gettext (msgid)
44 \f
45 /* DATA LIST transformation data. */
46 struct data_list_trns
47   {
48     struct data_parser *parser; /* Parser. */
49     struct dfm_reader *reader;  /* Data file reader. */
50     struct variable *end;       /* Variable specified on END subcommand. */
51   };
52
53 static trns_free_func data_list_trns_free;
54 static trns_proc_func data_list_trns_proc;
55
56 enum diagonal
57   {
58     DIAGONAL,
59     NO_DIAGONAL
60   };
61
62 enum triangle
63   {
64     LOWER,
65     UPPER,
66     FULL
67   };
68
69 static const int ROWTYPE_WIDTH = 8;
70
71 struct matrix_format
72 {
73   enum triangle triangle;
74   enum diagonal diagonal;
75   const struct variable *rowtype;
76   const struct variable *varname;
77   int n_continuous_vars;
78   struct variable **split_vars;
79   size_t n_split_vars;
80   long n;
81 };
82
83 /*
84 valid rowtype_ values:
85   CORR,
86   COV,
87   MAT,
88
89
90   MSE,
91   DFE,
92   MEAN,
93   STDDEV (or SD),
94   N_VECTOR (or N),
95   N_SCALAR,
96   N_MATRIX,
97   COUNT,
98   PROX.
99 */
100
101 /* Sets the value of OUTCASE which corresponds to VNAME
102    to the value STR.  VNAME must be of type string.
103  */
104 static void
105 set_varname_column (struct ccase *outcase, const struct variable *vname,
106      const char *str)
107 {
108   int len = var_get_width (vname);
109   uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
110
111   strncpy ((char *) s, str, len);
112 }
113
114 static void
115 blank_varname_column (struct ccase *outcase, const struct variable *vname)
116 {
117   int len = var_get_width (vname);
118   uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
119
120   memset (s, ' ', len);
121 }
122
123 static struct casereader *
124 preprocess (struct casereader *casereader0, const struct dictionary *dict, void *aux)
125 {
126   struct matrix_format *mformat = aux;
127   const struct caseproto *proto = casereader_get_proto (casereader0);
128   struct casewriter *writer = autopaging_writer_create (proto);
129   struct ccase *prev_case = NULL;
130   double **matrices = NULL;
131   size_t n_splits = 0;
132
133   const size_t sizeof_matrix =
134     sizeof (double) * mformat->n_continuous_vars * mformat->n_continuous_vars;
135
136
137   /* Make an initial pass to populate our temporary matrix */
138   struct casereader *pass0 = casereader_clone (casereader0);
139   struct ccase *c;
140   union value *prev_values = xcalloc (mformat->n_split_vars, sizeof *prev_values);
141   int row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ? 1 : 0;
142   bool first_case = true;
143   for (; (c = casereader_read (pass0)) != NULL; case_unref (c))
144     {
145       int s;
146       bool match = false;
147       if (!first_case)
148         {
149           match = true;
150           for (s = 0; s < mformat->n_split_vars; ++s)
151             {
152               const struct variable *svar = mformat->split_vars[s];
153               const union value *sv = case_data (c, svar);
154               if (! value_equal (prev_values + s, sv, var_get_width (svar)))
155                 {
156                   match = false;
157                   break;
158                 }
159             }
160         }
161       first_case = false;
162
163       if (matrices == NULL || ! match)
164         {
165           row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ?
166             1 : 0;
167
168           n_splits++;
169           matrices = xrealloc (matrices, sizeof (double*)  * n_splits);
170           matrices[n_splits - 1] = xmalloc (sizeof_matrix);
171         }
172
173       for (s = 0; s < mformat->n_split_vars; ++s)
174         {
175           const struct variable *svar = mformat->split_vars[s];
176           const union value *sv = case_data (c, svar);
177           value_clone (prev_values + s, sv, var_get_width (svar));
178         }
179
180       int c_offset = (mformat->triangle == UPPER) ? row : 0;
181       if (mformat->triangle == UPPER && mformat->diagonal == NO_DIAGONAL)
182         c_offset++;
183       const union value *v = case_data (c, mformat->rowtype);
184       const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
185       if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
186           0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
187         {
188           if (row >= mformat->n_continuous_vars)
189             {
190               msg (SE,
191                    _("There are %d variable declared but the data has at least %d matrix rows."),
192                    mformat->n_continuous_vars, row + 1);
193               case_unref (c);
194               casereader_destroy (pass0);
195               free (prev_values);
196               goto error;
197             }
198           int col;
199           for (col = c_offset; col < mformat->n_continuous_vars; ++col)
200             {
201               const struct variable *var =
202                 dict_get_var (dict,
203                               1 + col - c_offset +
204                               var_get_dict_index (mformat->varname));
205
206               double e = case_data (c, var)->f;
207               if (e == SYSMIS)
208                 continue;
209
210               /* Fill in the lower triangle */
211               (matrices[n_splits-1])[col + mformat->n_continuous_vars * row] = e;
212
213               if (mformat->triangle != FULL)
214                 /* Fill in the upper triangle */
215                 (matrices[n_splits-1]) [row + mformat->n_continuous_vars * col] = e;
216             }
217           row++;
218         }
219     }
220   casereader_destroy (pass0);
221   free (prev_values);
222
223   /* Now make a second pass to fill in the other triangle from our
224      temporary matrix */
225   const int idx = var_get_dict_index (mformat->varname);
226   row = 0;
227
228   if (mformat->n >= 0)
229     {
230       int col;
231       struct ccase *outcase = case_create (proto);
232       union value *v = case_data_rw (outcase, mformat->rowtype);
233       uint8_t *n = value_str_rw (v, ROWTYPE_WIDTH);
234       memcpy (n, "N       ", ROWTYPE_WIDTH);
235       blank_varname_column (outcase, mformat->varname);
236       for (col = 0; col < mformat->n_continuous_vars; ++col)
237         {
238           union value *dest_val =
239             case_data_rw_idx (outcase,
240                               1 + col + var_get_dict_index (mformat->varname));
241           dest_val->f = mformat->n;
242         }
243       casewriter_write (writer, outcase);
244     }
245
246   n_splits = 0;
247   prev_values = xcalloc (mformat->n_split_vars, sizeof *prev_values);
248   first_case = true;
249   for (; (c = casereader_read (casereader0)) != NULL; prev_case = c)
250     {
251       int s;
252       bool match = false;
253       if (!first_case)
254         {
255           match = true;
256           for (s = 0; s < mformat->n_split_vars; ++s)
257             {
258               const struct variable *svar = mformat->split_vars[s];
259               const union value *sv = case_data (c, svar);
260               if (! value_equal (prev_values + s, sv, var_get_width (svar)))
261                 {
262                   match = false;
263                   break;
264                 }
265             }
266         }
267       first_case = false;
268       if (! match)
269         {
270           n_splits++;
271           row = 0;
272         }
273
274       for (s = 0; s < mformat->n_split_vars; ++s)
275         {
276           const struct variable *svar = mformat->split_vars[s];
277           const union value *sv = case_data (c, svar);
278           value_clone (prev_values + s, sv, var_get_width (svar));
279         }
280
281       case_unref (prev_case);
282       const union value *v = case_data (c, mformat->rowtype);
283       const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
284       if (mformat->n >= 0)
285         {
286           if (0 == strncasecmp (val, "n       ", ROWTYPE_WIDTH) ||
287               0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
288             {
289               msg (SW,
290                    _("The N subcommand was specified, but a N record was also found in the data.  The N record will be ignored."));
291               continue;
292             }
293         }
294
295       struct ccase *outcase = case_create (proto);
296       case_copy (outcase, 0, c, 0, caseproto_get_n_widths (proto));
297
298       if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
299           0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
300         {
301           int col;
302           const struct variable *var = dict_get_var (dict, idx + 1 + row);
303           set_varname_column (outcase, mformat->varname, var_get_name (var));
304           value_copy (case_data_rw (outcase, mformat->rowtype), v, ROWTYPE_WIDTH);
305
306           for (col = 0; col < mformat->n_continuous_vars; ++col)
307             {
308               union value *dest_val =
309                 case_data_rw_idx (outcase,
310                                   1 + col + var_get_dict_index (mformat->varname));
311               dest_val->f = (matrices[n_splits - 1])[col + mformat->n_continuous_vars * row];
312               if (col == row && mformat->diagonal == NO_DIAGONAL)
313                 dest_val->f = 1.0;
314             }
315           row++;
316         }
317       else
318         {
319           blank_varname_column (outcase, mformat->varname);
320         }
321
322       /* Special case for SD and N_VECTOR: Rewrite as STDDEV and N respectively */
323       if (0 == strncasecmp (val, "sd      ", ROWTYPE_WIDTH))
324         {
325           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
326                                (uint8_t *) "STDDEV", 6, ' ');
327         }
328       else if (0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
329         {
330           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
331                                (uint8_t *) "N", 1, ' ');
332         }
333
334       casewriter_write (writer, outcase);
335     }
336
337   /* If NODIAGONAL is specified, then a final case must be written */
338   if (mformat->diagonal == NO_DIAGONAL)
339     {
340       int col;
341       struct ccase *outcase = case_create (proto);
342
343       if (prev_case)
344         case_copy (outcase, 0, prev_case, 0, caseproto_get_n_widths (proto));
345
346       const struct variable *var = dict_get_var (dict, idx + 1 + row);
347       set_varname_column (outcase, mformat->varname, var_get_name (var));
348
349       for (col = 0; col < mformat->n_continuous_vars; ++col)
350         {
351           union value *dest_val =
352             case_data_rw_idx (outcase, 1 + col +
353                               var_get_dict_index (mformat->varname));
354           dest_val->f = (matrices[n_splits - 1]) [col + mformat->n_continuous_vars * row];
355           if (col == row && mformat->diagonal == NO_DIAGONAL)
356             dest_val->f = 1.0;
357         }
358
359       casewriter_write (writer, outcase);
360     }
361   free (prev_values);
362
363   if (prev_case)
364     case_unref (prev_case);
365
366   int i;
367   for (i = 0 ; i < n_splits; ++i)
368     free (matrices[i]);
369   free (matrices);
370   struct casereader *reader1 = casewriter_make_reader (writer);
371   casereader_destroy (casereader0);
372   return reader1;
373
374
375 error:
376   if (prev_case)
377     case_unref (prev_case);
378
379   for (i = 0 ; i < n_splits; ++i)
380     free (matrices[i]);
381   free (matrices);
382   casereader_destroy (casereader0);
383   casewriter_destroy (writer);
384   return NULL;
385 }
386
387 int
388 cmd_matrix (struct lexer *lexer, struct dataset *ds)
389 {
390   struct dictionary *dict;
391   struct data_parser *parser;
392   struct dfm_reader *reader;
393   struct file_handle *fh = NULL;
394   char *encoding = NULL;
395   struct matrix_format mformat;
396   int i;
397   size_t n_names;
398   char **names = NULL;
399
400   mformat.triangle = LOWER;
401   mformat.diagonal = DIAGONAL;
402   mformat.n_split_vars = 0;
403   mformat.split_vars = NULL;
404   mformat.n = -1;
405
406   dict = (in_input_program ()
407           ? dataset_dict (ds)
408           : dict_create (get_default_encoding ()));
409   parser = data_parser_create (dict);
410   reader = NULL;
411
412   data_parser_set_type (parser, DP_DELIMITED);
413   data_parser_set_warn_missing_fields (parser, false);
414   data_parser_set_span (parser, false);
415
416   mformat.rowtype = dict_create_var (dict, "ROWTYPE_", ROWTYPE_WIDTH);
417
418   mformat.n_continuous_vars = 0;
419   mformat.n_split_vars = 0;
420
421   if (! lex_force_match_id (lexer, "VARIABLES"))
422     goto error;
423
424   lex_match (lexer, T_EQUALS);
425
426   if (! parse_mixed_vars (lexer, dict, &names, &n_names, PV_NO_DUPLICATE))
427     {
428       int i;
429       for (i = 0; i < n_names; ++i)
430         free (names[i]);
431       free (names);
432       goto error;
433     }
434
435   int longest_name = 0;
436   for (i = 0; i < n_names; ++i)
437     {
438       maximize_int (&longest_name, strlen (names[i]));
439     }
440
441   mformat.varname = dict_create_var (dict, "VARNAME_",
442                                      8 * DIV_RND_UP (longest_name, 8));
443
444   for (i = 0; i < n_names; ++i)
445     {
446       if (0 == strcasecmp (names[i], "ROWTYPE_"))
447         {
448           const struct fmt_spec fmt = fmt_for_input (FMT_A, 8, 0);
449           data_parser_add_delimited_field (parser,
450                                            &fmt,
451                                            var_get_case_index (mformat.rowtype),
452                                            "ROWTYPE_");
453         }
454       else
455         {
456           const struct fmt_spec fmt = fmt_for_input (FMT_F, 10, 4);
457           struct variable *v = dict_create_var (dict, names[i], 0);
458           var_set_both_formats (v, &fmt);
459           data_parser_add_delimited_field (parser,
460                                            &fmt,
461                                            var_get_case_index (mformat.varname) +
462                                            ++mformat.n_continuous_vars,
463                                            names[i]);
464         }
465     }
466   for (i = 0; i < n_names; ++i)
467     free (names[i]);
468   free (names);
469
470   while (lex_token (lexer) != T_ENDCMD)
471     {
472       if (! lex_force_match (lexer, T_SLASH))
473         goto error;
474
475       if (lex_match_id (lexer, "N"))
476         {
477           lex_match (lexer, T_EQUALS);
478
479           if (! lex_force_int (lexer))
480             goto error;
481
482           mformat.n = lex_integer (lexer);
483           if (mformat.n < 0)
484             {
485               msg (SE, _("%s must not be negative."), "N");
486               goto error;
487             }
488           lex_get (lexer);
489         }
490       else if (lex_match_id (lexer, "FORMAT"))
491         {
492           lex_match (lexer, T_EQUALS);
493
494           while (lex_token (lexer) != T_SLASH && (lex_token (lexer) != T_ENDCMD))
495             {
496               if (lex_match_id (lexer, "LIST"))
497                 {
498                   data_parser_set_span (parser, false);
499                 }
500               else if (lex_match_id (lexer, "FREE"))
501                 {
502                   data_parser_set_span (parser, true);
503                 }
504               else if (lex_match_id (lexer, "UPPER"))
505                 {
506                   mformat.triangle = UPPER;
507                 }
508               else if (lex_match_id (lexer, "LOWER"))
509                 {
510                   mformat.triangle = LOWER;
511                 }
512               else if (lex_match_id (lexer, "FULL"))
513                 {
514                   mformat.triangle = FULL;
515                 }
516               else if (lex_match_id (lexer, "DIAGONAL"))
517                 {
518                   mformat.diagonal = DIAGONAL;
519                 }
520               else if (lex_match_id (lexer, "NODIAGONAL"))
521                 {
522                   mformat.diagonal = NO_DIAGONAL;
523                 }
524               else
525                 {
526                   lex_error (lexer, NULL);
527                   goto error;
528                 }
529             }
530         }
531       else if (lex_match_id (lexer, "FILE"))
532         {
533           lex_match (lexer, T_EQUALS);
534           fh_unref (fh);
535           fh = fh_parse (lexer, FH_REF_FILE | FH_REF_INLINE, NULL);
536           if (fh == NULL)
537             goto error;
538         }
539       else if (lex_match_id (lexer, "SPLIT"))
540         {
541           lex_match (lexer, T_EQUALS);
542           if (! parse_variables (lexer, dict, &mformat.split_vars, &mformat.n_split_vars, 0))
543             {
544               free (mformat.split_vars);
545               goto error;
546             }
547           int i;
548           for (i = 0; i < mformat.n_split_vars; ++i)
549             {
550               const struct fmt_spec fmt = fmt_for_input (FMT_F, 4, 0);
551               var_set_both_formats (mformat.split_vars[i], &fmt);
552             }
553           dict_reorder_vars (dict, mformat.split_vars, mformat.n_split_vars);
554           mformat.n_continuous_vars -= mformat.n_split_vars;
555         }
556       else
557         {
558           lex_error (lexer, NULL);
559           goto error;
560         }
561     }
562
563   if (mformat.diagonal == NO_DIAGONAL && mformat.triangle == FULL)
564     {
565       msg (SE, _("FORMAT = FULL and FORMAT = NODIAGONAL are mutually exclusive."));
566       goto error;
567     }
568
569   if (fh == NULL)
570     fh = fh_inline_file ();
571   fh_set_default_handle (fh);
572
573   if (!data_parser_any_fields (parser))
574     {
575       msg (SE, _("At least one variable must be specified."));
576       goto error;
577     }
578
579   if (lex_end_of_command (lexer) != CMD_SUCCESS)
580     goto error;
581
582   reader = dfm_open_reader (fh, lexer, encoding);
583   if (reader == NULL)
584     goto error;
585
586   if (in_input_program ())
587     {
588       struct data_list_trns *trns = xmalloc (sizeof *trns);
589       trns->parser = parser;
590       trns->reader = reader;
591       trns->end = NULL;
592       add_transformation (ds, data_list_trns_proc, data_list_trns_free, trns);
593     }
594   else
595     {
596       data_parser_make_active_file (parser, ds, reader, dict, preprocess,
597                                     &mformat);
598     }
599
600   fh_unref (fh);
601   free (encoding);
602   free (mformat.split_vars);
603
604   return CMD_DATA_LIST;
605
606  error:
607   data_parser_destroy (parser);
608   if (!in_input_program ())
609     dict_unref (dict);
610   fh_unref (fh);
611   free (encoding);
612   free (mformat.split_vars);
613   return CMD_CASCADING_FAILURE;
614 }
615
616 \f
617 /* Input procedure. */
618
619 /* Destroys DATA LIST transformation TRNS.
620    Returns true if successful, false if an I/O error occurred. */
621 static bool
622 data_list_trns_free (void *trns_)
623 {
624   struct data_list_trns *trns = trns_;
625   data_parser_destroy (trns->parser);
626   dfm_close_reader (trns->reader);
627   free (trns);
628   return true;
629 }
630
631 /* Handle DATA LIST transformation TRNS, parsing data into *C. */
632 static int
633 data_list_trns_proc (void *trns_, struct ccase **c, casenumber case_num UNUSED)
634 {
635   struct data_list_trns *trns = trns_;
636   int retval;
637
638   *c = case_unshare (*c);
639   if (data_parser_parse (trns->parser, trns->reader, *c))
640     retval = TRNS_CONTINUE;
641   else if (dfm_reader_error (trns->reader) || dfm_eof (trns->reader) > 1)
642     {
643       /* An I/O error, or encountering end of file for a second
644          time, should be escalated into a more serious error. */
645       retval = TRNS_ERROR;
646     }
647   else
648     retval = TRNS_END_FILE;
649
650   /* If there was an END subcommand handle it. */
651   if (trns->end != NULL)
652     {
653       double *end = &case_data_rw (*c, trns->end)->f;
654       if (retval == TRNS_END_FILE)
655         {
656           *end = 1.0;
657           retval = TRNS_CONTINUE;
658         }
659       else
660         *end = 0.0;
661     }
662
663   return retval;
664 }