MATRIX DATA: Rewrite to canonical rowtype values
[pspp] / src / language / data-io / matrix-data.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2017 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/case.h"
20 #include "data/casereader.h"
21 #include "data/casewriter.h"
22 #include "data/dataset.h"
23 #include "data/dictionary.h"
24 #include "data/format.h"
25 #include "data/transformations.h"
26 #include "data/variable.h"
27 #include "language/command.h"
28 #include "language/data-io/data-parser.h"
29 #include "language/data-io/data-reader.h"
30 #include "language/data-io/file-handle.h"
31 #include "language/data-io/inpt-pgm.h"
32 #include "language/data-io/placement-parser.h"
33 #include "language/lexer/lexer.h"
34 #include "language/lexer/variable-parser.h"
35 #include "libpspp/i18n.h"
36 #include "libpspp/message.h"
37
38 #include "gl/xsize.h"
39 #include "gl/xalloc.h"
40
41 #include "gettext.h"
42 #define _(msgid) gettext (msgid)
43 \f
44 /* DATA LIST transformation data. */
45 struct data_list_trns
46   {
47     struct data_parser *parser; /* Parser. */
48     struct dfm_reader *reader;  /* Data file reader. */
49     struct variable *end;       /* Variable specified on END subcommand. */
50   };
51
52 static trns_free_func data_list_trns_free;
53 static trns_proc_func data_list_trns_proc;
54
55 enum diagonal
56   {
57     DIAGONAL,
58     NO_DIAGONAL
59   };
60
61 enum triangle
62   {
63     LOWER,
64     UPPER,
65     FULL
66   };
67
68 struct matrix_format
69 {
70   enum triangle triangle;
71   enum diagonal diagonal;
72   const struct variable *rowtype;
73   const struct variable *varname;
74   int n_continuous_vars;
75 };
76
77 /*
78 valid rowtype_ values:
79   CORR,
80   COV,
81   MAT,
82
83
84   MSE,
85   DFE,
86   MEAN,
87   STDDEV (or SD),
88   N_VECTOR (or N),
89   N_SCALAR,
90   N_MATRIX,
91   COUNT,
92   PROX.
93 */
94
95 /* Sets the value of OUTCASE which corresponds to MFORMAT's varname variable
96    to the string STR. VAR must be of type string.
97  */
98 static void
99 set_varname_column (struct ccase *outcase, const struct matrix_format *mformat,
100      const char *str, int len)
101 {
102   const struct variable *var = mformat->varname;
103   uint8_t *s = value_str_rw (case_data_rw (outcase, var), len);
104
105   strncpy ((char *) s, str, len);
106 }
107
108
109 static struct casereader *
110 preprocess (struct casereader *casereader0, const struct dictionary *dict, void *aux)
111 {
112   struct matrix_format *mformat = aux;
113   const struct caseproto *proto = casereader_get_proto (casereader0);
114   struct casewriter *writer;
115   writer = autopaging_writer_create (proto);
116
117   double *temp_matrix =
118     xcalloc (sizeof (*temp_matrix),
119              mformat->n_continuous_vars * mformat->n_continuous_vars);
120
121   /* Make an initial pass to populate our temporary matrix */
122   struct casereader *pass0 = casereader_clone (casereader0);
123   struct ccase *c;
124   int row = (mformat->triangle == LOWER && mformat->diagonal == NO_DIAGONAL) ? 1 : 0;
125   for (; (c = casereader_read (pass0)) != NULL; case_unref (c))
126     {
127       int c_offset = (mformat->triangle == UPPER) ? row : 0;
128       if (mformat->triangle == UPPER && mformat->diagonal == NO_DIAGONAL)
129         c_offset++;
130       const union value *v = case_data (c, mformat->rowtype);
131       const char *val = (const char *) value_str (v, 8);
132       if (0 == strncasecmp (val, "corr    ", 8) ||
133           0 == strncasecmp (val, "cov     ", 8))
134         {
135           int col;
136           for (col = c_offset; col < mformat->n_continuous_vars; ++col)
137             {
138               const struct variable *var =
139                 dict_get_var (dict,
140                               1 + col - c_offset + var_get_dict_index (mformat->varname));
141
142               double e = case_data (c, var)->f;
143               if (e == SYSMIS)
144                 continue;
145               temp_matrix [col + mformat->n_continuous_vars * row] = e;
146               temp_matrix [row + mformat->n_continuous_vars * col] = e;
147             }
148           row++;
149         }
150     }
151   casereader_destroy (pass0);
152
153   /* Now make a second pass to fill in the other triangle from our
154      temporary matrix */
155   const int idx = var_get_dict_index (mformat->varname);
156   row = 0;
157   struct ccase *prev_case = NULL;
158   for (; (c = casereader_read (casereader0)) != NULL; prev_case = c)
159     {
160       case_unref (prev_case);
161       struct ccase *outcase = case_create (proto);
162       case_copy (outcase, 0, c, 0, caseproto_get_n_widths (proto));
163       const union value *v = case_data (c, mformat->rowtype);
164       const char *val = (const char *) value_str (v, 8);
165       if (0 == strncasecmp (val, "corr    ", 8) ||
166           0 == strncasecmp (val, "cov     ", 8))
167         {
168           int col;
169           const struct variable *var = dict_get_var (dict, idx + 1 + row);
170           set_varname_column (outcase, mformat, var_get_name (var), 8);
171           value_copy (case_data_rw (outcase, mformat->rowtype), v, 8);
172
173           for (col = 0; col < mformat->n_continuous_vars; ++col)
174             {
175               union value *dest_val =
176                 case_data_rw_idx (outcase,
177                                   1 + col + var_get_dict_index (mformat->varname));
178               dest_val->f = temp_matrix [col + mformat->n_continuous_vars * row];
179               if (col == row && mformat->diagonal == NO_DIAGONAL)
180                 dest_val->f = 1.0;
181             }
182           row++;
183         }
184       else
185         {
186           set_varname_column (outcase, mformat, "        ", 8);
187         }
188
189       /* Special case for SD and N_VECTOR: Rewrite as STDDEV and N respectively */
190       if (0 == strncasecmp (val, "sd      ", 8))
191         {
192           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), 8,
193                                (uint8_t *) "STDDEV", 6, ' ');
194         }
195       else if (0 == strncasecmp (val, "n_vector", 8))
196         {
197           value_copy_buf_rpad (case_data_rw (outcase, mformat->rowtype), 8,
198                                (uint8_t *) "N", 1, ' ');
199         }
200
201       casewriter_write (writer, outcase);
202     }
203
204   /* If NODIAGONAL is specified, then a final case must be written */
205   if (mformat->diagonal == NO_DIAGONAL)
206     {
207       int col;
208       struct ccase *outcase = case_create (proto);
209
210       if (prev_case)
211         case_copy (outcase, 0, prev_case, 0, caseproto_get_n_widths (proto));
212
213
214       const struct variable *var = dict_get_var (dict, idx + 1 + row);
215       set_varname_column (outcase, mformat, var_get_name (var), 8);
216
217       for (col = 0; col < mformat->n_continuous_vars; ++col)
218         {
219           union value *dest_val =
220             case_data_rw_idx (outcase, 1 + col +
221                               var_get_dict_index (mformat->varname));
222           dest_val->f = temp_matrix [col + mformat->n_continuous_vars * row];
223           if (col == row && mformat->diagonal == NO_DIAGONAL)
224             dest_val->f = 1.0;
225         }
226
227       casewriter_write (writer, outcase);
228     }
229
230   if (prev_case)
231     case_unref (prev_case);
232
233   free (temp_matrix);
234   struct casereader *reader1 = casewriter_make_reader (writer);
235   casereader_destroy (casereader0);
236   return reader1;
237 }
238
239 int
240 cmd_matrix (struct lexer *lexer, struct dataset *ds)
241 {
242   struct dictionary *dict;
243   struct data_parser *parser;
244   struct dfm_reader *reader;
245   struct file_handle *fh = NULL;
246   char *encoding = NULL;
247   struct matrix_format mformat;
248   int i;
249   size_t n_names;
250   char **names = NULL;
251
252   mformat.triangle = LOWER;
253   mformat.diagonal = DIAGONAL;
254
255   dict = (in_input_program ()
256           ? dataset_dict (ds)
257           : dict_create (get_default_encoding ()));
258   parser = data_parser_create (dict);
259   reader = NULL;
260
261   data_parser_set_type (parser, DP_DELIMITED);
262   data_parser_set_warn_missing_fields (parser, false);
263   data_parser_set_span (parser, false);
264
265   mformat.rowtype = dict_create_var (dict, "ROWTYPE_", 8);
266   mformat.varname = dict_create_var (dict, "VARNAME_", 8);
267
268   mformat.n_continuous_vars = 0;
269
270   if (! lex_force_match_id (lexer, "VARIABLES"))
271     goto error;
272
273   lex_match (lexer, T_EQUALS);
274
275   if (! parse_mixed_vars (lexer, dict, &names, &n_names, 0))
276     {
277       int i;
278       for (i = 0; i < n_names; ++i)
279         free (names[i]);
280       free (names);
281       goto error;
282     }
283
284   for (i = 0; i < n_names; ++i)
285     {
286       if (0 == strcasecmp (names[i], "ROWTYPE_"))
287         {
288           const struct fmt_spec fmt = fmt_for_input (FMT_A, 8, 0);
289           data_parser_add_delimited_field (parser,
290                                            &fmt,
291                                            var_get_case_index (mformat.rowtype),
292                                            "ROWTYPE_");
293         }
294       else
295         {
296           const struct fmt_spec fmt = fmt_for_input (FMT_F, 10, 4);
297           struct variable *v = dict_create_var (dict, names[i], 0);
298           var_set_both_formats (v, &fmt);
299           data_parser_add_delimited_field (parser,
300                                            &fmt,
301                                            var_get_case_index (mformat.varname) +
302                                            ++mformat.n_continuous_vars,
303                                            names[i]);
304         }
305     }
306   for (i = 0; i < n_names; ++i)
307     free (names[i]);
308   free (names);
309
310   while (lex_token (lexer) != T_ENDCMD)
311     {
312       if (! lex_force_match (lexer, T_SLASH))
313         goto error;
314
315       if (lex_match_id (lexer, "FORMAT"))
316         {
317           lex_match (lexer, T_EQUALS);
318
319           while (lex_token (lexer) != T_SLASH && (lex_token (lexer) != T_ENDCMD))
320             {
321               if (lex_match_id (lexer, "LIST"))
322                 {
323                   data_parser_set_span (parser, false);
324                 }
325               else if (lex_match_id (lexer, "FREE"))
326                 {
327                   data_parser_set_span (parser, true);
328                 }
329               else if (lex_match_id (lexer, "UPPER"))
330                 {
331                   mformat.triangle = UPPER;
332                 }
333               else if (lex_match_id (lexer, "LOWER"))
334                 {
335                   mformat.triangle = LOWER;
336                 }
337               else if (lex_match_id (lexer, "FULL"))
338                 {
339                   mformat.triangle = FULL;
340                 }
341               else if (lex_match_id (lexer, "DIAGONAL"))
342                 {
343                   mformat.diagonal = DIAGONAL;
344                 }
345               else if (lex_match_id (lexer, "NODIAGONAL"))
346                 {
347                   mformat.diagonal = NO_DIAGONAL;
348                 }
349               else
350                 {
351                   lex_error (lexer, NULL);
352                   goto error;
353                 }
354             }
355         }
356       else if (lex_match_id (lexer, "FILE"))
357         {
358           lex_match (lexer, T_EQUALS);
359           fh_unref (fh);
360           fh = fh_parse (lexer, FH_REF_FILE | FH_REF_INLINE, NULL);
361           if (fh == NULL)
362             goto error;
363         }
364       else if (lex_match_id (lexer, "SPLIT"))
365         {
366           lex_match (lexer, T_EQUALS);
367           struct variable **split_vars = NULL;
368           size_t n_split_vars;
369           if (! parse_variables (lexer, dict, &split_vars, &n_split_vars, 0))
370             {
371               free (split_vars);
372               goto error;
373             }
374           int i;
375           for (i = 0; i < n_split_vars; ++i)
376             {
377               const struct fmt_spec fmt = fmt_for_input (FMT_F, 4, 0);
378               var_set_both_formats (split_vars[i], &fmt);
379             }
380           dict_reorder_vars (dict, split_vars, n_split_vars);
381           mformat.n_continuous_vars -= n_split_vars;
382           free (split_vars);
383         }
384       else
385         {
386           lex_error (lexer, NULL);
387           goto error;
388         }
389     }
390
391   if (mformat.diagonal == NO_DIAGONAL && mformat.triangle == FULL)
392     {
393       msg (SE, _("FORMAT = FULL and FORMAT = NODIAGONAL are mutually exclusive."));
394       goto error;
395     }
396
397   if (fh == NULL)
398     fh = fh_inline_file ();
399   fh_set_default_handle (fh);
400
401   if (!data_parser_any_fields (parser))
402     {
403       msg (SE, _("At least one variable must be specified."));
404       goto error;
405     }
406
407   if (lex_end_of_command (lexer) != CMD_SUCCESS)
408     goto error;
409
410   reader = dfm_open_reader (fh, lexer, encoding);
411   if (reader == NULL)
412     goto error;
413
414   if (in_input_program ())
415     {
416       struct data_list_trns *trns = xmalloc (sizeof *trns);
417       trns->parser = parser;
418       trns->reader = reader;
419       trns->end = NULL;
420       add_transformation (ds, data_list_trns_proc, data_list_trns_free, trns);
421     }
422   else
423     {
424       data_parser_make_active_file (parser, ds, reader, dict, preprocess, &mformat);
425     }
426
427   fh_unref (fh);
428   free (encoding);
429
430   return CMD_DATA_LIST;
431
432  error:
433   data_parser_destroy (parser);
434   if (!in_input_program ())
435     dict_destroy (dict);
436   fh_unref (fh);
437   free (encoding);
438   return CMD_CASCADING_FAILURE;
439 }
440
441 \f
442 /* Input procedure. */
443
444 /* Destroys DATA LIST transformation TRNS.
445    Returns true if successful, false if an I/O error occurred. */
446 static bool
447 data_list_trns_free (void *trns_)
448 {
449   struct data_list_trns *trns = trns_;
450   data_parser_destroy (trns->parser);
451   dfm_close_reader (trns->reader);
452   free (trns);
453   return true;
454 }
455
456 /* Handle DATA LIST transformation TRNS, parsing data into *C. */
457 static int
458 data_list_trns_proc (void *trns_, struct ccase **c, casenumber case_num UNUSED)
459 {
460   struct data_list_trns *trns = trns_;
461   int retval;
462
463   *c = case_unshare (*c);
464   if (data_parser_parse (trns->parser, trns->reader, *c))
465     retval = TRNS_CONTINUE;
466   else if (dfm_reader_error (trns->reader) || dfm_eof (trns->reader) > 1)
467     {
468       /* An I/O error, or encountering end of file for a second
469          time, should be escalated into a more serious error. */
470       retval = TRNS_ERROR;
471     }
472   else
473     retval = TRNS_END_FILE;
474
475   /* If there was an END subcommand handle it. */
476   if (trns->end != NULL)
477     {
478       double *end = &case_data_rw (*c, trns->end)->f;
479       if (retval == TRNS_END_FILE)
480         {
481           *end = 1.0;
482           retval = TRNS_CONTINUE;
483         }
484       else
485         *end = 0.0;
486     }
487
488   return retval;
489 }
490