Fix potential crash if matrix file variables are of unexpected types
[pspp] / src / language / data-io / matrix-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2017 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "matrix-reader.h"
20
21 #include <stdbool.h>
22
23 #include <libpspp/hash-functions.h>
24 #include <libpspp/message.h>
25 #include <data/casegrouper.h>
26 #include <data/casereader.h>
27 #include <data/dictionary.h>
28 #include <data/variable.h>
29
30 #include "gettext.h"
31 #define _(msgid) gettext (msgid)
32 #define N_(msgid) msgid
33
34
35 /*
36 This module interprets a "data matrix", typically generated by the command
37 MATRIX DATA.  The dictionary of such a matrix takes the form:
38
39  s_0, s_1, ... s_m, ROWTYPE_, VARNAME_, v_0, v_1, .... v_n
40
41 where s_0, s_1 ... s_m are the variables defining the splits, and
42 v_0, v_1 ... v_n are the continuous variables.
43
44 m >= 0; n >= 0
45
46 The ROWTYPE_ variable is of type A8.
47 The VARNAME_ variable is a string type whose width is not predetermined.
48 The variables s_x are of type F4.0 (although this reader accepts any type),
49 and v_x are of any numeric type.
50
51 The values of the ROWTYPE_ variable are in the set {MEAN, STDDEV, N, CORR, COV}
52 and determine the purpose of that case.
53 The values of the VARNAME_ variable must correspond to the names of the varibles
54 in {v_0, v_1 ... v_n} and indicate the rows of the correlation or covariance
55 matrices.
56
57
58
59 A typical example is as follows:
60
61 s_0 ROWTYPE_   VARNAME_   v_0         v_1         v_2
62
63 0   MEAN                5.0000       4.0000       3.0000
64 0   STDDEV              1.0000       2.0000       3.0000
65 0   N                   9.0000       9.0000       9.0000
66 0   CORR       V1       1.0000        .6000        .7000
67 0   CORR       V2        .6000       1.0000        .8000
68 0   CORR       V3        .7000        .8000       1.0000
69 1   MEAN                9.0000       8.0000       7.0000
70 1   STDDEV              5.0000       6.0000       7.0000
71 1   N                   9.0000       9.0000       9.0000
72 1   CORR       V1       1.0000        .4000        .3000
73 1   CORR       V2        .4000       1.0000        .2000
74 1   CORR       V3        .3000        .2000       1.0000
75
76 */
77
78 struct matrix_reader
79 {
80   const struct variable *varname;
81   const struct variable *rowtype;
82   struct casegrouper *grouper;
83
84   gsl_matrix *n_vectors;
85   gsl_matrix *mean_vectors;
86   gsl_matrix *var_vectors;
87
88   //  gsl_matrix *correlation;
89   //  gsl_matrix *covariance;
90 };
91
92 struct matrix_reader *
93 create_matrix_reader_from_case_reader (const struct dictionary *dict, struct casereader *in_reader,
94                                        const struct variable ***vars, size_t *n_vars)
95 {
96   struct matrix_reader *mr = xzalloc (sizeof *mr);
97
98   mr->varname = dict_lookup_var (dict, "varname_");
99   if (mr->varname == NULL)
100     {
101       msg (ME, _("Matrix dataset lacks a variable called %s."), "VARNAME_");
102       free (mr);
103       return NULL;
104     }
105
106   if (!var_is_alpha (mr->varname))
107     {
108       msg (ME, _("Matrix dataset variable %s should be of string type."),
109            "VARNAME_");
110       free (mr);
111       return NULL;
112     }
113
114   mr->rowtype = dict_lookup_var (dict, "rowtype_");
115   if (mr->rowtype == NULL)
116     {
117       msg (ME, _("Matrix dataset lacks a variable called %s."), "ROWTYPE_");
118       free (mr);
119       return NULL;
120     }
121
122   if (!var_is_alpha (mr->rowtype))
123     {
124       msg (ME, _("Matrix dataset variable %s should be of string type."),
125            "ROWTYPE_");
126       free (mr);
127       return NULL;
128     }
129
130   size_t dvarcnt;
131   const struct variable **dvars = NULL;
132   dict_get_vars (dict, &dvars, &dvarcnt, DC_SCRATCH);
133
134   if (n_vars)
135     *n_vars = dvarcnt - var_get_dict_index (mr->varname) - 1;
136
137   if (vars)
138     {
139       int i;
140       *vars = xcalloc (sizeof (struct variable **), *n_vars);
141
142       for (i = 0; i < *n_vars; ++i)
143         {
144           (*vars)[i] = dvars[i + var_get_dict_index (mr->varname) + 1];
145         }
146     }
147
148   /* All the variables before ROWTYPE_ (if any) are split variables */
149   mr->grouper = casegrouper_create_vars (in_reader, dvars, var_get_dict_index (mr->rowtype));
150
151   free (dvars);
152
153   return mr;
154 }
155
156 bool
157 destroy_matrix_reader (struct matrix_reader *mr)
158 {
159   if (mr == NULL)
160     return false;
161   bool ret = casegrouper_destroy (mr->grouper);
162   free (mr);
163   return ret;
164 }
165
166
167 /*
168    Allocates MATRIX if necessary,
169    and populates row MROW, from the data in C corresponding to
170    variables in VARS. N_VARS is the length of VARS.
171 */
172 static void
173 matrix_fill_row (gsl_matrix **matrix,
174       const struct ccase *c, int mrow,
175       const struct variable **vars, size_t n_vars)
176 {
177   int col;
178   if (*matrix == NULL)
179     *matrix = gsl_matrix_alloc (n_vars, n_vars);
180
181   for (col = 0; col < n_vars; ++col)
182     {
183       const struct variable *cv = vars [col];
184       double x = case_data (c, cv)->f;
185       assert (col  < (*matrix)->size2);
186       assert (mrow < (*matrix)->size1);
187       gsl_matrix_set (*matrix, mrow, col, x);
188     }
189 }
190
191 bool
192 next_matrix_from_reader (struct matrix_material *mm,
193                          struct matrix_reader *mr,
194                          const struct variable **vars, int n_vars)
195 {
196   struct casereader *group;
197
198   assert (vars);
199
200   gsl_matrix_free (mr->n_vectors);
201   gsl_matrix_free (mr->mean_vectors);
202   gsl_matrix_free (mr->var_vectors);
203
204   if (!casegrouper_get_next_group (mr->grouper, &group))
205     return false;
206
207   mr->n_vectors    = gsl_matrix_alloc (n_vars, n_vars);
208   mr->mean_vectors = gsl_matrix_alloc (n_vars, n_vars);
209   mr->var_vectors  = gsl_matrix_alloc (n_vars, n_vars);
210
211   mm->n = mr->n_vectors;
212   mm->mean_matrix = mr->mean_vectors;
213   mm->var_matrix = mr->var_vectors;
214
215   // FIXME: Make this into a hash table.
216   unsigned long *table = xmalloc (sizeof (*table) * n_vars);
217   int i;
218   for (i = 0; i < n_vars; ++i)
219     {
220       const int w = var_get_width (mr->varname);
221       char s[w];
222       memset (s, 0, w);
223       const char *name = var_get_name (vars[i]);
224       strncpy (s, name, w);
225       unsigned long h = hash_bytes (s, w, 0);
226       table[i] = h;
227     }
228
229   struct ccase *c;
230   for ( ; (c = casereader_read (group) ); case_unref (c))
231     {
232       const union value *uv  = case_data (c, mr->rowtype);
233       int col, row;
234       for (col = 0; col < n_vars; ++col)
235         {
236           const struct variable *cv = vars[col];
237           double x = case_data (c, cv)->f;
238           if (0 == strncasecmp ((char *)value_str (uv, 8), "N       ", 8))
239             for (row = 0; row < n_vars; ++row)
240               gsl_matrix_set (mr->n_vectors, row, col, x);
241           else if (0 == strncasecmp ((char *) value_str (uv, 8), "MEAN    ", 8))
242             for (row = 0; row < n_vars; ++row)
243               gsl_matrix_set (mr->mean_vectors, row, col, x);
244           else if (0 == strncasecmp ((char *) value_str (uv, 8), "STDDEV  ", 8))
245             for (row = 0; row < n_vars; ++row)
246               gsl_matrix_set (mr->var_vectors, row, col, x * x);
247         }
248
249       const union value *uvv  = case_data (c, mr->varname);
250       const uint8_t *vs = value_str (uvv, var_get_width (mr->varname));
251       int w = var_get_width (mr->varname);
252       unsigned long h = hash_bytes (vs, w, 0);
253
254       int mrow = -1;
255       for (i = 0; i < n_vars; ++i)
256         {
257           if (table[i] == h)
258             {
259               mrow = i;
260               break;
261             }
262         }
263
264       if (mrow == -1)
265         continue;
266
267       if (0 == strncasecmp ((char *) value_str (uv, 8), "CORR    ", 8))
268         {
269           matrix_fill_row (&mm->corr, c, mrow, vars, n_vars);
270         }
271       else if (0 == strncasecmp ((char *) value_str (uv, 8), "COV     ", 8))
272         {
273           matrix_fill_row (&mm->cov, c, mrow, vars, n_vars);
274         }
275     }
276
277   casereader_destroy (group);
278
279   free (table);
280
281   return true;
282 }