FACTOR: Allow an option '=' after MATRIX
[pspp] / src / language / data-io / matrix-data.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2017 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/case.h"
20 #include "data/casereader.h"
21 #include "data/casewriter.h"
22 #include "data/dataset.h"
23 #include "data/dictionary.h"
24 #include "data/format.h"
25 #include "data/transformations.h"
26 #include "data/variable.h"
27 #include "language/command.h"
28 #include "language/data-io/data-parser.h"
29 #include "language/data-io/data-reader.h"
30 #include "language/data-io/file-handle.h"
31 #include "language/data-io/inpt-pgm.h"
32 #include "language/data-io/placement-parser.h"
33 #include "language/lexer/lexer.h"
34 #include "language/lexer/variable-parser.h"
35 #include "libpspp/i18n.h"
36 #include "libpspp/message.h"
37
38 #include "gl/xsize.h"
39 #include "gl/xalloc.h"
40
41 #include "gettext.h"
42 #define _(msgid) gettext (msgid)
43 \f
44 /* DATA LIST transformation data. */
45 struct data_list_trns
46   {
47     struct data_parser *parser; /* Parser. */
48     struct dfm_reader *reader;  /* Data file reader. */
49     struct variable *end;       /* Variable specified on END subcommand. */
50   };
51
52 static trns_free_func data_list_trns_free;
53 static trns_proc_func data_list_trns_proc;
54
55 enum diagonal
56   {
57     DIAGONAL,
58     NO_DIAGONAL
59   };
60
61 enum triangle
62   {
63     LOWER,
64     UPPER,
65     FULL
66   };
67
68 struct matrix_format
69 {
70   enum triangle triangle;
71   enum diagonal diagonal;
72   const struct variable *rowtype;
73   const struct variable *varname;
74   int n_continuous_vars;
75   struct variable **split_vars;
76   size_t n_split_vars;
77 };
78
79 /*
80 valid rowtype_ values:
81   CORR,
82   COV,
83   MAT,
84
85
86   MSE,
87   DFE,
88   MEAN,
89   STDDEV (or SD),
90   N_VECTOR (or N),
91   N_SCALAR,
92   N_MATRIX,
93   COUNT,
94   PROX.
95 */
96
97 /* Sets the value of OUTCASE which corresponds to MFORMAT's varname variable
98    to the string STR. VAR must be of type string.
99  */
100 static void
101 set_varname_column (struct ccase *outcase, const struct matrix_format *mformat,
102      const char *str, int len)
103 {
104   const struct variable *var = mformat->varname;
105   uint8_t *s = value_str_rw (case_data_rw (outcase, var), len);
106
107   strncpy ((char *) s, str, len);
108 }
109
110
111 static struct casereader *
112 preprocess (struct casereader *casereader0, const struct dictionary *dict, void *aux)
113 {
114   struct matrix_format *mformat = aux;
115   const struct caseproto *proto = casereader_get_proto (casereader0);
116   struct casewriter *writer;
117   writer = autopaging_writer_create (proto);
118
119   double **matrices = NULL;
120   size_t n_splits = 0;
121
122   const size_t sizeof_matrix =
123     sizeof (double) * mformat->n_continuous_vars * mformat->n_continuous_vars;
124
125
126   /* Make an initial pass to populate our temporary matrix */
127   struct casereader *pass0 = casereader_clone (casereader0);
128   struct ccase *c;
129   unsigned int prev_split_hash = 1;
130   int row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ? 1 : 0;
131   for (; (c = casereader_read (pass0)) != NULL; case_unref (c))
132     {
133       int s;
134       unsigned int split_hash = 0;
135       for (s = 0; s < mformat->n_split_vars; ++s)
136         {
137           const struct variable *svar = mformat->split_vars[s];
138           const union value *sv = case_data (c, svar);
139           split_hash = value_hash (sv, var_get_width (svar), split_hash);
140         }
141
142       if (matrices == NULL || prev_split_hash != split_hash)
143         {
144           row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ?
145             1 : 0;
146
147           n_splits++;
148           matrices = xrealloc (matrices, sizeof (double*)  * n_splits);
149           matrices[n_splits - 1] = xmalloc (sizeof_matrix);
150         }
151
152       prev_split_hash = split_hash;
153
154       int c_offset = (mformat->triangle == UPPER) ? row : 0;
155       if (mformat->triangle == UPPER && mformat->diagonal == NO_DIAGONAL)
156         c_offset++;
157       const union value *v = case_data (c, mformat->rowtype);
158       const char *val = (const char *) value_str (v, 8);
159       if (0 == strncasecmp (val, "corr    ", 8) ||
160           0 == strncasecmp (val, "cov     ", 8))
161         {
162           int col;
163           for (col = c_offset; col < mformat->n_continuous_vars; ++col)
164             {
165               const struct variable *var =
166                 dict_get_var (dict,
167                               1 + col - c_offset +
168                               var_get_dict_index (mformat->varname));
169
170               double e = case_data (c, var)->f;
171               if (e == SYSMIS)
172                 continue;
173
174
175               (matrices[n_splits-1])[col + mformat->n_continuous_vars * row] = e;
176               (matrices[n_splits-1]) [row + mformat->n_continuous_vars * col] = e;
177             }
178           row++;
179         }
180     }
181   casereader_destroy (pass0);
182
183   /* Now make a second pass to fill in the other triangle from our
184      temporary matrix */
185   const int idx = var_get_dict_index (mformat->varname);
186   row = 0;
187   struct ccase *prev_case = NULL;
188   prev_split_hash = 1;
189   n_splits = 0;
190   for (; (c = casereader_read (casereader0)) != NULL; prev_case = c)
191     {
192       int s;
193       unsigned int split_hash = 0;
194       for (s = 0; s < mformat->n_split_vars; ++s)
195         {
196           const struct variable *svar = mformat->split_vars[s];
197           const union value *sv = case_data (c, svar);
198           split_hash = value_hash (sv, var_get_width (svar), split_hash);
199         }
200       if (prev_split_hash != split_hash)
201         {
202           n_splits++;
203           row = 0;
204         }
205
206       prev_split_hash = split_hash;
207
208       case_unref (prev_case);
209       struct ccase *outcase = case_create (proto);
210       case_copy (outcase, 0, c, 0, caseproto_get_n_widths (proto));
211       const union value *v = case_data (c, mformat->rowtype);
212       const char *val = (const char *) value_str (v, 8);
213       if (0 == strncasecmp (val, "corr    ", 8) ||
214           0 == strncasecmp (val, "cov     ", 8))
215         {
216           int col;
217           const struct variable *var = dict_get_var (dict, idx + 1 + row);
218           set_varname_column (outcase, mformat, var_get_name (var), 8);
219           value_copy (case_data_rw (outcase, mformat->rowtype), v, 8);
220
221           for (col = 0; col < mformat->n_continuous_vars; ++col)
222             {
223               union value *dest_val =
224                 case_data_rw_idx (outcase,
225                                   1 + col + var_get_dict_index (mformat->varname));
226               dest_val->f = (matrices[n_splits - 1])[col + mformat->n_continuous_vars * row];
227               if (col == row && mformat->diagonal == NO_DIAGONAL)
228                 dest_val->f = 1.0;
229             }
230           row++;
231         }
232       else
233         {
234           set_varname_column (outcase, mformat, "        ", 8);
235         }
236
237       /* Special case for SD and N_VECTOR: Rewrite as STDDEV and N respectively */
238       if (0 == strncasecmp (val, "sd      ", 8))
239         {
240           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), 8,
241                                (uint8_t *) "STDDEV", 6, ' ');
242         }
243       else if (0 == strncasecmp (val, "n_vector", 8))
244         {
245           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), 8,
246                                (uint8_t *) "N", 1, ' ');
247         }
248
249       casewriter_write (writer, outcase);
250     }
251
252   /* If NODIAGONAL is specified, then a final case must be written */
253   if (mformat->diagonal == NO_DIAGONAL)
254     {
255       int col;
256       struct ccase *outcase = case_create (proto);
257
258       if (prev_case)
259         case_copy (outcase, 0, prev_case, 0, caseproto_get_n_widths (proto));
260
261
262       const struct variable *var = dict_get_var (dict, idx + 1 + row);
263       set_varname_column (outcase, mformat, var_get_name (var), 8);
264
265       for (col = 0; col < mformat->n_continuous_vars; ++col)
266         {
267           union value *dest_val =
268             case_data_rw_idx (outcase, 1 + col +
269                               var_get_dict_index (mformat->varname));
270           dest_val->f = (matrices[n_splits - 1]) [col + mformat->n_continuous_vars * row];
271           if (col == row && mformat->diagonal == NO_DIAGONAL)
272             dest_val->f = 1.0;
273         }
274
275       casewriter_write (writer, outcase);
276     }
277
278   if (prev_case)
279     case_unref (prev_case);
280
281   int i;
282   for (i = 0 ; i < n_splits; ++i)
283     free (matrices[i]);
284   free (matrices);
285   struct casereader *reader1 = casewriter_make_reader (writer);
286   casereader_destroy (casereader0);
287   return reader1;
288 }
289
290 int
291 cmd_matrix (struct lexer *lexer, struct dataset *ds)
292 {
293   struct dictionary *dict;
294   struct data_parser *parser;
295   struct dfm_reader *reader;
296   struct file_handle *fh = NULL;
297   char *encoding = NULL;
298   struct matrix_format mformat;
299   int i;
300   size_t n_names;
301   char **names = NULL;
302
303   mformat.triangle = LOWER;
304   mformat.diagonal = DIAGONAL;
305   mformat.n_split_vars = 0;
306   mformat.split_vars = NULL;
307
308   dict = (in_input_program ()
309           ? dataset_dict (ds)
310           : dict_create (get_default_encoding ()));
311   parser = data_parser_create (dict);
312   reader = NULL;
313
314   data_parser_set_type (parser, DP_DELIMITED);
315   data_parser_set_warn_missing_fields (parser, false);
316   data_parser_set_span (parser, false);
317
318   mformat.rowtype = dict_create_var (dict, "ROWTYPE_", 8);
319   mformat.varname = dict_create_var (dict, "VARNAME_", 8);
320
321   mformat.n_continuous_vars = 0;
322   mformat.n_split_vars = 0;
323
324   if (! lex_force_match_id (lexer, "VARIABLES"))
325     goto error;
326
327   lex_match (lexer, T_EQUALS);
328
329   if (! parse_mixed_vars (lexer, dict, &names, &n_names, PV_NO_DUPLICATE))
330     {
331       int i;
332       for (i = 0; i < n_names; ++i)
333         free (names[i]);
334       free (names);
335       goto error;
336     }
337
338   for (i = 0; i < n_names; ++i)
339     {
340       if (0 == strcasecmp (names[i], "ROWTYPE_"))
341         {
342           const struct fmt_spec fmt = fmt_for_input (FMT_A, 8, 0);
343           data_parser_add_delimited_field (parser,
344                                            &fmt,
345                                            var_get_case_index (mformat.rowtype),
346                                            "ROWTYPE_");
347         }
348       else
349         {
350           const struct fmt_spec fmt = fmt_for_input (FMT_F, 10, 4);
351           struct variable *v = dict_create_var (dict, names[i], 0);
352           var_set_both_formats (v, &fmt);
353           data_parser_add_delimited_field (parser,
354                                            &fmt,
355                                            var_get_case_index (mformat.varname) +
356                                            ++mformat.n_continuous_vars,
357                                            names[i]);
358         }
359     }
360   for (i = 0; i < n_names; ++i)
361     free (names[i]);
362   free (names);
363
364   while (lex_token (lexer) != T_ENDCMD)
365     {
366       if (! lex_force_match (lexer, T_SLASH))
367         goto error;
368
369       if (lex_match_id (lexer, "FORMAT"))
370         {
371           lex_match (lexer, T_EQUALS);
372
373           while (lex_token (lexer) != T_SLASH && (lex_token (lexer) != T_ENDCMD))
374             {
375               if (lex_match_id (lexer, "LIST"))
376                 {
377                   data_parser_set_span (parser, false);
378                 }
379               else if (lex_match_id (lexer, "FREE"))
380                 {
381                   data_parser_set_span (parser, true);
382                 }
383               else if (lex_match_id (lexer, "UPPER"))
384                 {
385                   mformat.triangle = UPPER;
386                 }
387               else if (lex_match_id (lexer, "LOWER"))
388                 {
389                   mformat.triangle = LOWER;
390                 }
391               else if (lex_match_id (lexer, "FULL"))
392                 {
393                   mformat.triangle = FULL;
394                 }
395               else if (lex_match_id (lexer, "DIAGONAL"))
396                 {
397                   mformat.diagonal = DIAGONAL;
398                 }
399               else if (lex_match_id (lexer, "NODIAGONAL"))
400                 {
401                   mformat.diagonal = NO_DIAGONAL;
402                 }
403               else
404                 {
405                   lex_error (lexer, NULL);
406                   goto error;
407                 }
408             }
409         }
410       else if (lex_match_id (lexer, "FILE"))
411         {
412           lex_match (lexer, T_EQUALS);
413           fh_unref (fh);
414           fh = fh_parse (lexer, FH_REF_FILE | FH_REF_INLINE, NULL);
415           if (fh == NULL)
416             goto error;
417         }
418       else if (lex_match_id (lexer, "SPLIT"))
419         {
420           lex_match (lexer, T_EQUALS);
421           if (! parse_variables (lexer, dict, &mformat.split_vars, &mformat.n_split_vars, 0))
422             {
423               free (mformat.split_vars);
424               goto error;
425             }
426           int i;
427           for (i = 0; i < mformat.n_split_vars; ++i)
428             {
429               const struct fmt_spec fmt = fmt_for_input (FMT_F, 4, 0);
430               var_set_both_formats (mformat.split_vars[i], &fmt);
431             }
432           dict_reorder_vars (dict, mformat.split_vars, mformat.n_split_vars);
433           mformat.n_continuous_vars -= mformat.n_split_vars;
434         }
435       else
436         {
437           lex_error (lexer, NULL);
438           goto error;
439         }
440     }
441
442   if (mformat.diagonal == NO_DIAGONAL && mformat.triangle == FULL)
443     {
444       msg (SE, _("FORMAT = FULL and FORMAT = NODIAGONAL are mutually exclusive."));
445       goto error;
446     }
447
448   if (fh == NULL)
449     fh = fh_inline_file ();
450   fh_set_default_handle (fh);
451
452   if (!data_parser_any_fields (parser))
453     {
454       msg (SE, _("At least one variable must be specified."));
455       goto error;
456     }
457
458   if (lex_end_of_command (lexer) != CMD_SUCCESS)
459     goto error;
460
461   reader = dfm_open_reader (fh, lexer, encoding);
462   if (reader == NULL)
463     goto error;
464
465   if (in_input_program ())
466     {
467       struct data_list_trns *trns = xmalloc (sizeof *trns);
468       trns->parser = parser;
469       trns->reader = reader;
470       trns->end = NULL;
471       add_transformation (ds, data_list_trns_proc, data_list_trns_free, trns);
472     }
473   else
474     {
475       data_parser_make_active_file (parser, ds, reader, dict, preprocess, &mformat);
476     }
477
478   fh_unref (fh);
479   free (encoding);
480   free (mformat.split_vars);
481
482   return CMD_DATA_LIST;
483
484  error:
485   data_parser_destroy (parser);
486   if (!in_input_program ())
487     dict_destroy (dict);
488   fh_unref (fh);
489   free (encoding);
490   free (mformat.split_vars);
491   return CMD_CASCADING_FAILURE;
492 }
493
494 \f
495 /* Input procedure. */
496
497 /* Destroys DATA LIST transformation TRNS.
498    Returns true if successful, false if an I/O error occurred. */
499 static bool
500 data_list_trns_free (void *trns_)
501 {
502   struct data_list_trns *trns = trns_;
503   data_parser_destroy (trns->parser);
504   dfm_close_reader (trns->reader);
505   free (trns);
506   return true;
507 }
508
509 /* Handle DATA LIST transformation TRNS, parsing data into *C. */
510 static int
511 data_list_trns_proc (void *trns_, struct ccase **c, casenumber case_num UNUSED)
512 {
513   struct data_list_trns *trns = trns_;
514   int retval;
515
516   *c = case_unshare (*c);
517   if (data_parser_parse (trns->parser, trns->reader, *c))
518     retval = TRNS_CONTINUE;
519   else if (dfm_reader_error (trns->reader) || dfm_eof (trns->reader) > 1)
520     {
521       /* An I/O error, or encountering end of file for a second
522          time, should be escalated into a more serious error. */
523       retval = TRNS_ERROR;
524     }
525   else
526     retval = TRNS_END_FILE;
527
528   /* If there was an END subcommand handle it. */
529   if (trns->end != NULL)
530     {
531       double *end = &case_data_rw (*c, trns->end)->f;
532       if (retval == TRNS_END_FILE)
533         {
534           *end = 1.0;
535           retval = TRNS_CONTINUE;
536         }
537       else
538         *end = 0.0;
539     }
540
541   return retval;
542 }
543