1e24464709e09195b88b8dcb15f7445b471df60b
[pspp] / src / language / data-io / mconvert.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2021 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <math.h>
20
21 #include "data/any-reader.h"
22 #include "data/any-writer.h"
23 #include "data/casereader.h"
24 #include "data/casewriter.h"
25 #include "data/dataset.h"
26 #include "data/dictionary.h"
27 #include "language/data-io/file-handle.h"
28 #include "language/data-io/matrix-reader.h"
29 #include "language/lexer/lexer.h"
30 #include "language/command.h"
31
32 #include "gettext.h"
33 #define _(msgid) gettext (msgid)
34
35 int
36 cmd_mconvert (struct lexer *lexer, struct dataset *ds)
37 {
38   bool append = false;
39   struct file_handle *in = NULL;
40   struct file_handle *out = NULL;
41   while (lex_token (lexer) != T_ENDCMD)
42     {
43       if (lex_match_id (lexer, "APPEND"))
44         append = true;
45       else if (lex_match_id (lexer, "REPLACE"))
46         append = false;
47       else
48         {
49           if (lex_match_id (lexer, "MATRIX"))
50             lex_match (lexer, T_EQUALS);
51
52           struct file_handle **fhp = (lex_match_id (lexer, "IN") ? &in
53                                       : lex_match_id (lexer, "OUT") ? &out
54                                       : NULL);
55           if (!fhp)
56             {
57               lex_error_expecting (lexer, "IN", "OUT", "APPEND", "REPLACE");
58               goto error;
59             }
60           if (!lex_force_match (lexer, T_LPAREN))
61             goto error;
62
63           fh_unref (*fhp);
64           if (lex_match (lexer, T_ASTERISK))
65             *fhp = NULL;
66           else
67             {
68               *fhp = fh_parse (lexer, FH_REF_FILE, dataset_session (ds));
69               if (!*fhp)
70                 goto error;
71             }
72
73           if (!lex_force_match (lexer, T_RPAREN))
74             goto error;
75         }
76
77       lex_match (lexer, T_SLASH);
78     }
79
80   if (!in && !dataset_has_source (ds))
81     {
82       msg (SE, _("No active file is defined and no external file is "
83                  "specified on MATRIX=IN."));
84       goto error;
85     }
86
87   assert (in);
88   assert (out);
89
90   struct dictionary *d;
91   struct casereader *cr = any_reader_open_and_decode (in, NULL, &d, NULL);
92   if (!cr)
93     goto error;
94
95   struct matrix_reader *mr = matrix_reader_create (d, cr);
96   if (!mr)
97     {
98       casereader_destroy (cr);
99       dict_unref (d);
100       goto error;
101     }
102
103   struct casewriter *cw = any_writer_open (out, d);
104   if (!cw)
105     {
106       matrix_reader_destroy (mr);
107       casereader_destroy (cr);
108       dict_unref (d);
109       goto error;
110     }
111
112   for (;;)
113     {
114       struct matrix_material mm;
115       struct casereader *group;
116       if (!matrix_reader_next (&mm, mr, &group))
117         break;
118
119       bool add_corr = mm.cov && !mm.corr;
120       bool add_cov = mm.corr && !mm.cov;
121       bool remove_corr = add_cov && !append;
122       bool remove_cov = add_corr && !append;
123
124       struct ccase *model = casereader_peek (group, 0);
125       for (size_t i = 0; i < mr->n_fvars; i++)
126         *case_num_rw (model, mr->fvars[i]) = SYSMIS;
127
128       for (;;)
129         {
130           struct ccase *c = casereader_read (group);
131           if (!c)
132             break;
133
134           struct substring rowtype = matrix_reader_get_string (c, mr->rowtype);
135           if ((remove_cov && ss_equals_case (rowtype, ss_cstr ("COV")))
136               || (remove_corr && ss_equals_case (rowtype, ss_cstr ("CORR"))))
137             case_unref (c);
138           else
139             casewriter_write (cw, c);
140         }
141       casereader_destroy (group);
142
143       if (add_corr)
144         {
145           assert (mm.cov->size1 == mr->n_cvars);
146           assert (mm.cov->size2 == mr->n_cvars);
147
148           for (size_t y = 0; y < mr->n_cvars; y++)
149             {
150               struct ccase *c = case_clone (model);
151               for (size_t x = 0; x < mr->n_cvars; x++)
152                 {
153                   double d1 = gsl_matrix_get (mm.cov, x, x);
154                   double d2 = gsl_matrix_get (mm.cov, y, y);
155                   double cov = gsl_matrix_get (mm.cov, y, x);
156                   *case_num_rw (c, mr->cvars[x]) = cov / sqrt (d1 * d2);
157                 }
158               matrix_reader_set_string (c, mr->rowtype, ss_cstr ("CORR"));
159               matrix_reader_set_string (c, mr->varname,
160                                         ss_cstr (var_get_name (mr->cvars[y])));
161               casewriter_write (cw, c);
162             }
163
164           struct ccase *c = case_clone (model);
165           for (size_t x = 0; x < mr->n_cvars; x++)
166             {
167               double variance = gsl_matrix_get (mm.cov, x, x);
168               *case_num_rw (c, mr->cvars[x]) = sqrt (variance);
169             }
170           matrix_reader_set_string (c, mr->rowtype, ss_cstr ("STDDEV"));
171           matrix_reader_set_string (c, mr->varname, ss_empty ());
172           casewriter_write (cw, c);
173         }
174
175       if (add_cov)
176         {
177           assert (mm.corr->size1 == mr->n_cvars);
178           assert (mm.corr->size2 == mr->n_cvars);
179
180           for (size_t y = 0; y < mr->n_cvars; y++)
181             {
182               struct ccase *c = case_clone (model);
183               for (size_t x = 0; x < mr->n_cvars; x++)
184                 {
185                   double d1 = gsl_matrix_get (mm.var_matrix, x, x);
186                   double d2 = gsl_matrix_get (mm.var_matrix, y, y);
187                   double corr = gsl_matrix_get (mm.corr, y, x);
188                   *case_num_rw (c, mr->cvars[x]) = corr * sqrt (d1 * d2);
189                 }
190               matrix_reader_set_string (c, mr->rowtype, ss_cstr ("COV"));
191               matrix_reader_set_string (c, mr->varname,
192                                         ss_cstr (var_get_name (mr->cvars[y])));
193               casewriter_write (cw, c);
194             }
195         }
196
197       case_unref (model);
198     }
199
200   matrix_reader_destroy (mr);
201   casewriter_destroy (cw);
202   fh_unref (in);
203   fh_unref (out);
204   dict_unref (d);
205   return CMD_SUCCESS;
206
207 error:
208   fh_unref (in);
209   fh_unref (out);
210   dict_unref (d);
211   return CMD_FAILURE;
212 }
213