609e5eb1fd5156365093ac9c70851ab2dbcfce5b
[pspp] / src / language / data-io / matrix-data.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2017 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/case.h"
20 #include "data/casereader.h"
21 #include "data/casewriter.h"
22 #include "data/dataset.h"
23 #include "data/dictionary.h"
24 #include "data/format.h"
25 #include "data/transformations.h"
26 #include "data/variable.h"
27 #include "language/command.h"
28 #include "language/data-io/data-parser.h"
29 #include "language/data-io/data-reader.h"
30 #include "language/data-io/file-handle.h"
31 #include "language/data-io/inpt-pgm.h"
32 #include "language/data-io/placement-parser.h"
33 #include "language/lexer/lexer.h"
34 #include "language/lexer/variable-parser.h"
35 #include "libpspp/i18n.h"
36 #include "libpspp/message.h"
37 #include "libpspp/misc.h"
38
39 #include "gl/xsize.h"
40 #include "gl/xalloc.h"
41
42 #include "gettext.h"
43 #define _(msgid) gettext (msgid)
44 \f
45 /* DATA LIST transformation data. */
46 struct data_list_trns
47   {
48     struct data_parser *parser; /* Parser. */
49     struct dfm_reader *reader;  /* Data file reader. */
50     struct variable *end;       /* Variable specified on END subcommand. */
51   };
52
53 static trns_free_func data_list_trns_free;
54 static trns_proc_func data_list_trns_proc;
55
56 enum diagonal
57   {
58     DIAGONAL,
59     NO_DIAGONAL
60   };
61
62 enum triangle
63   {
64     LOWER,
65     UPPER,
66     FULL
67   };
68
69 static const int ROWTYPE_WIDTH = 8;
70
71 struct matrix_format
72 {
73   enum triangle triangle;
74   enum diagonal diagonal;
75   const struct variable *rowtype;
76   const struct variable *varname;
77   int n_continuous_vars;
78   struct variable **split_vars;
79   size_t n_split_vars;
80   long n;
81 };
82
83 /*
84 valid rowtype_ values:
85   CORR,
86   COV,
87   MAT,
88
89
90   MSE,
91   DFE,
92   MEAN,
93   STDDEV (or SD),
94   N_VECTOR (or N),
95   N_SCALAR,
96   N_MATRIX,
97   COUNT,
98   PROX.
99 */
100
101 /* Sets the value of OUTCASE which corresponds to VNAME
102    to the value STR.  VNAME must be of type string.
103  */
104 static void
105 set_varname_column (struct ccase *outcase, const struct variable *vname,
106      const char *str)
107 {
108   int len = var_get_width (vname);
109   uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
110
111   strncpy ((char *) s, str, len);
112 }
113
114 static void
115 blank_varname_column (struct ccase *outcase, const struct variable *vname)
116 {
117   int len = var_get_width (vname);
118   uint8_t *s = value_str_rw (case_data_rw (outcase, vname), len);
119
120   memset (s, ' ', len);
121 }
122
123 static struct casereader *
124 preprocess (struct casereader *casereader0, const struct dictionary *dict, void *aux)
125 {
126   struct matrix_format *mformat = aux;
127   const struct caseproto *proto = casereader_get_proto (casereader0);
128   struct casewriter *writer;
129   writer = autopaging_writer_create (proto);
130   struct ccase *prev_case = NULL;
131   double **matrices = NULL;
132   size_t n_splits = 0;
133
134   const size_t sizeof_matrix =
135     sizeof (double) * mformat->n_continuous_vars * mformat->n_continuous_vars;
136
137
138   /* Make an initial pass to populate our temporary matrix */
139   struct casereader *pass0 = casereader_clone (casereader0);
140   struct ccase *c;
141   unsigned int prev_split_hash = 1;
142   int row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ? 1 : 0;
143   for (; (c = casereader_read (pass0)) != NULL; case_unref (c))
144     {
145       int s;
146       unsigned int split_hash = 0;
147       for (s = 0; s < mformat->n_split_vars; ++s)
148         {
149           const struct variable *svar = mformat->split_vars[s];
150           const union value *sv = case_data (c, svar);
151           split_hash = value_hash (sv, var_get_width (svar), split_hash);
152         }
153
154       if (matrices == NULL || prev_split_hash != split_hash)
155         {
156           row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ?
157             1 : 0;
158
159           n_splits++;
160           matrices = xrealloc (matrices, sizeof (double*)  * n_splits);
161           matrices[n_splits - 1] = xmalloc (sizeof_matrix);
162         }
163
164       prev_split_hash = split_hash;
165
166       int c_offset = (mformat->triangle == UPPER) ? row : 0;
167       if (mformat->triangle == UPPER && mformat->diagonal == NO_DIAGONAL)
168         c_offset++;
169       const union value *v = case_data (c, mformat->rowtype);
170       const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
171       if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
172           0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
173         {
174           if (row >= mformat->n_continuous_vars)
175             {
176               msg (SE,
177                    _("There are %d variable declared but the data has at least %d matrix rows."),
178                    mformat->n_continuous_vars, row + 1);
179               goto error;
180             }
181           int col;
182           for (col = c_offset; col < mformat->n_continuous_vars; ++col)
183             {
184               const struct variable *var =
185                 dict_get_var (dict,
186                               1 + col - c_offset +
187                               var_get_dict_index (mformat->varname));
188
189               double e = case_data (c, var)->f;
190               if (e == SYSMIS)
191                 continue;
192
193               /* Fill in the lower triangle */
194               (matrices[n_splits-1])[col + mformat->n_continuous_vars * row] = e;
195
196               if (mformat->triangle != FULL)
197                 /* Fill in the upper triangle */
198                 (matrices[n_splits-1]) [row + mformat->n_continuous_vars * col] = e;
199             }
200           row++;
201         }
202     }
203   casereader_destroy (pass0);
204
205   /* Now make a second pass to fill in the other triangle from our
206      temporary matrix */
207   const int idx = var_get_dict_index (mformat->varname);
208   row = 0;
209
210   if (mformat->n >= 0)
211     {
212       int col;
213       struct ccase *outcase = case_create (proto);
214       union value *v = case_data_rw (outcase, mformat->rowtype);
215       uint8_t *n = value_str_rw (v, ROWTYPE_WIDTH);
216       strncpy ((char *) n, "N        ", ROWTYPE_WIDTH);
217       blank_varname_column (outcase, mformat->varname);
218       for (col = 0; col < mformat->n_continuous_vars; ++col)
219         {
220           union value *dest_val =
221             case_data_rw_idx (outcase,
222                               1 + col + var_get_dict_index (mformat->varname));
223           dest_val->f = mformat->n;
224         }
225       casewriter_write (writer, outcase);
226     }
227
228   prev_split_hash = 1;
229   n_splits = 0;
230   for (; (c = casereader_read (casereader0)) != NULL; prev_case = c)
231     {
232       int s;
233       unsigned int split_hash = 0;
234       for (s = 0; s < mformat->n_split_vars; ++s)
235         {
236           const struct variable *svar = mformat->split_vars[s];
237           const union value *sv = case_data (c, svar);
238           split_hash = value_hash (sv, var_get_width (svar), split_hash);
239         }
240       if (prev_split_hash != split_hash)
241         {
242           n_splits++;
243           row = 0;
244         }
245
246       prev_split_hash = split_hash;
247       case_unref (prev_case);
248       const union value *v = case_data (c, mformat->rowtype);
249       const char *val = (const char *) value_str (v, ROWTYPE_WIDTH);
250       if (mformat->n >= 0)
251         {
252           if (0 == strncasecmp (val, "n       ", ROWTYPE_WIDTH) ||
253               0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
254             {
255               msg (SW,
256                    _("The N subcommand was specified, but a N record was also found in the data.  The N record will be ignored."));
257               continue;
258             }
259         }
260
261       struct ccase *outcase = case_create (proto);
262       case_copy (outcase, 0, c, 0, caseproto_get_n_widths (proto));
263
264       if (0 == strncasecmp (val, "corr    ", ROWTYPE_WIDTH) ||
265           0 == strncasecmp (val, "cov     ", ROWTYPE_WIDTH))
266         {
267           int col;
268           const struct variable *var = dict_get_var (dict, idx + 1 + row);
269           set_varname_column (outcase, mformat->varname, var_get_name (var));
270           value_copy (case_data_rw (outcase, mformat->rowtype), v, ROWTYPE_WIDTH);
271
272           for (col = 0; col < mformat->n_continuous_vars; ++col)
273             {
274               union value *dest_val =
275                 case_data_rw_idx (outcase,
276                                   1 + col + var_get_dict_index (mformat->varname));
277               dest_val->f = (matrices[n_splits - 1])[col + mformat->n_continuous_vars * row];
278               if (col == row && mformat->diagonal == NO_DIAGONAL)
279                 dest_val->f = 1.0;
280             }
281           row++;
282         }
283       else
284         {
285           blank_varname_column (outcase, mformat->varname);
286         }
287
288       /* Special case for SD and N_VECTOR: Rewrite as STDDEV and N respectively */
289       if (0 == strncasecmp (val, "sd      ", ROWTYPE_WIDTH))
290         {
291           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
292                                (uint8_t *) "STDDEV", 6, ' ');
293         }
294       else if (0 == strncasecmp (val, "n_vector", ROWTYPE_WIDTH))
295         {
296           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), ROWTYPE_WIDTH,
297                                (uint8_t *) "N", 1, ' ');
298         }
299
300       casewriter_write (writer, outcase);
301     }
302
303   /* If NODIAGONAL is specified, then a final case must be written */
304   if (mformat->diagonal == NO_DIAGONAL)
305     {
306       int col;
307       struct ccase *outcase = case_create (proto);
308
309       if (prev_case)
310         case_copy (outcase, 0, prev_case, 0, caseproto_get_n_widths (proto));
311
312       const struct variable *var = dict_get_var (dict, idx + 1 + row);
313       set_varname_column (outcase, mformat->varname, var_get_name (var));
314
315       for (col = 0; col < mformat->n_continuous_vars; ++col)
316         {
317           union value *dest_val =
318             case_data_rw_idx (outcase, 1 + col +
319                               var_get_dict_index (mformat->varname));
320           dest_val->f = (matrices[n_splits - 1]) [col + mformat->n_continuous_vars * row];
321           if (col == row && mformat->diagonal == NO_DIAGONAL)
322             dest_val->f = 1.0;
323         }
324
325       casewriter_write (writer, outcase);
326     }
327
328
329   if (prev_case)
330     case_unref (prev_case);
331
332   int i;
333   for (i = 0 ; i < n_splits; ++i)
334     free (matrices[i]);
335   free (matrices);
336   struct casereader *reader1 = casewriter_make_reader (writer);
337   casereader_destroy (casereader0);
338   return reader1;
339
340
341 error:
342   if (prev_case)
343     case_unref (prev_case);
344
345   for (i = 0 ; i < n_splits; ++i)
346     free (matrices[i]);
347   free (matrices);
348   casereader_destroy (casereader0);
349   return NULL;
350 }
351
352 int
353 cmd_matrix (struct lexer *lexer, struct dataset *ds)
354 {
355   struct dictionary *dict;
356   struct data_parser *parser;
357   struct dfm_reader *reader;
358   struct file_handle *fh = NULL;
359   char *encoding = NULL;
360   struct matrix_format mformat;
361   int i;
362   size_t n_names;
363   char **names = NULL;
364
365   mformat.triangle = LOWER;
366   mformat.diagonal = DIAGONAL;
367   mformat.n_split_vars = 0;
368   mformat.split_vars = NULL;
369   mformat.n = -1;
370
371   dict = (in_input_program ()
372           ? dataset_dict (ds)
373           : dict_create (get_default_encoding ()));
374   parser = data_parser_create (dict);
375   reader = NULL;
376
377   data_parser_set_type (parser, DP_DELIMITED);
378   data_parser_set_warn_missing_fields (parser, false);
379   data_parser_set_span (parser, false);
380
381   mformat.rowtype = dict_create_var (dict, "ROWTYPE_", ROWTYPE_WIDTH);
382
383   mformat.n_continuous_vars = 0;
384   mformat.n_split_vars = 0;
385
386   if (! lex_force_match_id (lexer, "VARIABLES"))
387     goto error;
388
389   lex_match (lexer, T_EQUALS);
390
391   if (! parse_mixed_vars (lexer, dict, &names, &n_names, PV_NO_DUPLICATE))
392     {
393       int i;
394       for (i = 0; i < n_names; ++i)
395         free (names[i]);
396       free (names);
397       goto error;
398     }
399
400   int longest_name = 0;
401   for (i = 0; i < n_names; ++i)
402     {
403       maximize_int (&longest_name, strlen (names[i]));
404     }
405
406   mformat.varname = dict_create_var (dict, "VARNAME_",
407                                      8 * DIV_RND_UP (longest_name, 8));
408
409   for (i = 0; i < n_names; ++i)
410     {
411       if (0 == strcasecmp (names[i], "ROWTYPE_"))
412         {
413           const struct fmt_spec fmt = fmt_for_input (FMT_A, 8, 0);
414           data_parser_add_delimited_field (parser,
415                                            &fmt,
416                                            var_get_case_index (mformat.rowtype),
417                                            "ROWTYPE_");
418         }
419       else
420         {
421           const struct fmt_spec fmt = fmt_for_input (FMT_F, 10, 4);
422           struct variable *v = dict_create_var (dict, names[i], 0);
423           var_set_both_formats (v, &fmt);
424           data_parser_add_delimited_field (parser,
425                                            &fmt,
426                                            var_get_case_index (mformat.varname) +
427                                            ++mformat.n_continuous_vars,
428                                            names[i]);
429         }
430     }
431   for (i = 0; i < n_names; ++i)
432     free (names[i]);
433   free (names);
434
435   while (lex_token (lexer) != T_ENDCMD)
436     {
437       if (! lex_force_match (lexer, T_SLASH))
438         goto error;
439
440       if (lex_match_id (lexer, "N"))
441         {
442           lex_match (lexer, T_EQUALS);
443
444           if (! lex_force_int (lexer))
445             goto error;
446
447           mformat.n = lex_integer (lexer);
448           if (mformat.n < 0)
449             {
450               msg (SE, _("%s must not be negative."), "N");
451               goto error;
452             }
453           lex_get (lexer);
454         }
455       else if (lex_match_id (lexer, "FORMAT"))
456         {
457           lex_match (lexer, T_EQUALS);
458
459           while (lex_token (lexer) != T_SLASH && (lex_token (lexer) != T_ENDCMD))
460             {
461               if (lex_match_id (lexer, "LIST"))
462                 {
463                   data_parser_set_span (parser, false);
464                 }
465               else if (lex_match_id (lexer, "FREE"))
466                 {
467                   data_parser_set_span (parser, true);
468                 }
469               else if (lex_match_id (lexer, "UPPER"))
470                 {
471                   mformat.triangle = UPPER;
472                 }
473               else if (lex_match_id (lexer, "LOWER"))
474                 {
475                   mformat.triangle = LOWER;
476                 }
477               else if (lex_match_id (lexer, "FULL"))
478                 {
479                   mformat.triangle = FULL;
480                 }
481               else if (lex_match_id (lexer, "DIAGONAL"))
482                 {
483                   mformat.diagonal = DIAGONAL;
484                 }
485               else if (lex_match_id (lexer, "NODIAGONAL"))
486                 {
487                   mformat.diagonal = NO_DIAGONAL;
488                 }
489               else
490                 {
491                   lex_error (lexer, NULL);
492                   goto error;
493                 }
494             }
495         }
496       else if (lex_match_id (lexer, "FILE"))
497         {
498           lex_match (lexer, T_EQUALS);
499           fh_unref (fh);
500           fh = fh_parse (lexer, FH_REF_FILE | FH_REF_INLINE, NULL);
501           if (fh == NULL)
502             goto error;
503         }
504       else if (lex_match_id (lexer, "SPLIT"))
505         {
506           lex_match (lexer, T_EQUALS);
507           if (! parse_variables (lexer, dict, &mformat.split_vars, &mformat.n_split_vars, 0))
508             {
509               free (mformat.split_vars);
510               goto error;
511             }
512           int i;
513           for (i = 0; i < mformat.n_split_vars; ++i)
514             {
515               const struct fmt_spec fmt = fmt_for_input (FMT_F, 4, 0);
516               var_set_both_formats (mformat.split_vars[i], &fmt);
517             }
518           dict_reorder_vars (dict, mformat.split_vars, mformat.n_split_vars);
519           mformat.n_continuous_vars -= mformat.n_split_vars;
520         }
521       else
522         {
523           lex_error (lexer, NULL);
524           goto error;
525         }
526     }
527
528   if (mformat.diagonal == NO_DIAGONAL && mformat.triangle == FULL)
529     {
530       msg (SE, _("FORMAT = FULL and FORMAT = NODIAGONAL are mutually exclusive."));
531       goto error;
532     }
533
534   if (fh == NULL)
535     fh = fh_inline_file ();
536   fh_set_default_handle (fh);
537
538   if (!data_parser_any_fields (parser))
539     {
540       msg (SE, _("At least one variable must be specified."));
541       goto error;
542     }
543
544   if (lex_end_of_command (lexer) != CMD_SUCCESS)
545     goto error;
546
547   reader = dfm_open_reader (fh, lexer, encoding);
548   if (reader == NULL)
549     goto error;
550
551   if (in_input_program ())
552     {
553       struct data_list_trns *trns = xmalloc (sizeof *trns);
554       trns->parser = parser;
555       trns->reader = reader;
556       trns->end = NULL;
557       add_transformation (ds, data_list_trns_proc, data_list_trns_free, trns);
558     }
559   else
560     {
561       data_parser_make_active_file (parser, ds, reader, dict, preprocess, &mformat);
562     }
563
564   fh_unref (fh);
565   free (encoding);
566   free (mformat.split_vars);
567
568   return CMD_DATA_LIST;
569
570  error:
571   data_parser_destroy (parser);
572   if (!in_input_program ())
573     dict_destroy (dict);
574   fh_unref (fh);
575   free (encoding);
576   free (mformat.split_vars);
577   return CMD_CASCADING_FAILURE;
578 }
579
580 \f
581 /* Input procedure. */
582
583 /* Destroys DATA LIST transformation TRNS.
584    Returns true if successful, false if an I/O error occurred. */
585 static bool
586 data_list_trns_free (void *trns_)
587 {
588   struct data_list_trns *trns = trns_;
589   data_parser_destroy (trns->parser);
590   dfm_close_reader (trns->reader);
591   free (trns);
592   return true;
593 }
594
595 /* Handle DATA LIST transformation TRNS, parsing data into *C. */
596 static int
597 data_list_trns_proc (void *trns_, struct ccase **c, casenumber case_num UNUSED)
598 {
599   struct data_list_trns *trns = trns_;
600   int retval;
601
602   *c = case_unshare (*c);
603   if (data_parser_parse (trns->parser, trns->reader, *c))
604     retval = TRNS_CONTINUE;
605   else if (dfm_reader_error (trns->reader) || dfm_eof (trns->reader) > 1)
606     {
607       /* An I/O error, or encountering end of file for a second
608          time, should be escalated into a more serious error. */
609       retval = TRNS_ERROR;
610     }
611   else
612     retval = TRNS_END_FILE;
613
614   /* If there was an END subcommand handle it. */
615   if (trns->end != NULL)
616     {
617       double *end = &case_data_rw (*c, trns->end)->f;
618       if (retval == TRNS_END_FILE)
619         {
620           *end = 1.0;
621           retval = TRNS_CONTINUE;
622         }
623       else
624         *end = 0.0;
625     }
626
627   return retval;
628 }
629