862f49cea24300a7f4f158e72933cf8e613b0165
[pspp-builds.git] / src / language / stats / glm.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2010 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/case.h>
20 #include <data/casegrouper.h>
21 #include <data/casereader.h>
22
23 #include <math/covariance.h>
24 #include <math/categoricals.h>
25 #include <math/moments.h>
26 #include <gsl/gsl_matrix.h>
27 #include <linreg/sweep.h>
28
29 #include <libpspp/ll.h>
30
31 #include <language/lexer/lexer.h>
32 #include <language/lexer/variable-parser.h>
33 #include <language/lexer/value-parser.h>
34 #include <language/command.h>
35
36 #include <data/procedure.h>
37 #include <data/value.h>
38 #include <data/dictionary.h>
39
40 #include <language/dictionary/split-file.h>
41 #include <libpspp/taint.h>
42 #include <libpspp/misc.h>
43
44 #include <gsl/gsl_cdf.h>
45 #include <math.h>
46 #include <data/format.h>
47
48 #include <libpspp/message.h>
49
50 #include <output/tab.h>
51
52 #include "gettext.h"
53 #define _(msgid) gettext (msgid)
54
55 struct glm_spec
56 {
57   size_t n_dep_vars;
58   const struct variable **dep_vars;
59
60   size_t n_factor_vars;
61   const struct variable **factor_vars;
62
63   enum mv_class exclude;
64
65   /* The weight variable */
66   const struct variable *wv;
67
68   bool intercept;
69 };
70
71 struct glm_workspace
72 {
73   double total_ssq;
74   struct moments *totals;
75 };
76
77 static void output_glm (const struct glm_spec *, const struct glm_workspace *ws);
78 static void run_glm (const struct glm_spec *cmd, struct casereader *input, const struct dataset *ds);
79
80 int
81 cmd_glm (struct lexer *lexer, struct dataset *ds)
82 {
83   const struct dictionary *dict = dataset_dict (ds);  
84   struct glm_spec glm ;
85   glm.n_dep_vars = 0;
86   glm.n_factor_vars = 0;
87   glm.dep_vars = NULL;
88   glm.factor_vars = NULL;
89   glm.exclude = MV_ANY;
90   glm.intercept = true;
91   glm.wv = dict_get_weight (dict);
92
93   
94   if (!parse_variables_const (lexer, dict,
95                               &glm.dep_vars, &glm.n_dep_vars,
96                               PV_NO_DUPLICATE | PV_NUMERIC))
97     goto error;
98
99   lex_force_match (lexer, T_BY);
100
101   if (!parse_variables_const (lexer, dict,
102                               &glm.factor_vars, &glm.n_factor_vars,
103                               PV_NO_DUPLICATE | PV_NUMERIC))
104     goto error;
105
106   if ( glm.n_dep_vars > 1)
107     {
108       msg (ME, _("Multivariate analysis is not yet implemented"));
109       return CMD_FAILURE;
110     }
111
112   struct const_var_set *factors = const_var_set_create_from_array (glm.factor_vars, glm.n_factor_vars);
113
114
115   while (lex_token (lexer) != T_ENDCMD)
116     {
117       lex_match (lexer, T_SLASH);
118
119       if (lex_match_id (lexer, "MISSING"))
120         {
121           lex_match (lexer, T_EQUALS);
122           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
123             {
124               if (lex_match_id (lexer, "INCLUDE"))
125                 {
126                   glm.exclude = MV_SYSTEM;
127                 }
128               else if (lex_match_id (lexer, "EXCLUDE"))
129                 {
130                   glm.exclude = MV_ANY;
131                 }
132               else
133                 {
134                   lex_error (lexer, NULL);
135                   goto error;
136                 }
137             }
138         }
139       else if (lex_match_id (lexer, "INTERCEPT"))
140         {
141           lex_match (lexer, T_EQUALS);
142           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
143             {
144               if (lex_match_id (lexer, "INCLUDE"))
145                 {
146                   glm.intercept = true;
147                 }
148               else if (lex_match_id (lexer, "EXCLUDE"))
149                 {
150                   glm.intercept = false;
151                 }
152               else
153                 {
154                   lex_error (lexer, NULL);
155                   goto error;
156                 }
157             }
158         }
159       else if (lex_match_id (lexer, "DESIGN"))
160         {
161           size_t n_des;
162           const struct variable **des;
163           lex_match (lexer, T_EQUALS);
164
165           parse_const_var_set_vars (lexer, factors, &des, &n_des, 0);
166         }
167       else
168         {
169           lex_error (lexer, NULL);
170           goto error;
171         }
172     }
173
174
175   {
176     struct casegrouper *grouper;
177     struct casereader *group;
178     bool ok;
179
180     grouper = casegrouper_create_splits (proc_open (ds), dict);
181     while (casegrouper_get_next_group (grouper, &group))
182       run_glm (&glm, group, ds);
183     ok = casegrouper_destroy (grouper);
184     ok = proc_commit (ds) && ok;
185   }
186
187   return CMD_SUCCESS;
188
189  error:
190   return CMD_FAILURE;
191 }
192
193 static  void dump_matrix (const gsl_matrix *m);
194
195 static void
196 run_glm (const struct glm_spec *cmd, struct casereader *input, const struct dataset *ds)
197 {
198   int v;
199   struct taint *taint;
200   struct dictionary *dict = dataset_dict (ds);
201   struct casereader *reader;
202   struct ccase *c;
203
204   struct glm_workspace ws;
205
206   struct categoricals *cats = categoricals_create (cmd->factor_vars, cmd->n_factor_vars,
207                                                    cmd->wv, cmd->exclude, 
208                                                    NULL, NULL,
209                                                    NULL, NULL);
210   
211   struct covariance *cov = covariance_2pass_create (cmd->n_dep_vars, cmd->dep_vars,
212                                                cats, 
213                                                cmd->wv, cmd->exclude);
214
215
216   c = casereader_peek (input, 0);
217   if (c == NULL)
218     {
219       casereader_destroy (input);
220       return;
221     }
222   output_split_file_values (ds, c);
223   case_unref (c);
224
225   taint = taint_clone (casereader_get_taint (input));
226
227   ws.totals = moments_create (MOMENT_VARIANCE);
228
229   bool warn_bad_weight = true;
230   for (reader = casereader_clone (input);
231        (c = casereader_read (reader)) != NULL; case_unref (c))
232     {
233       double weight = dict_get_case_weight (dict, c, &warn_bad_weight);
234
235       for ( v = 0; v < cmd->n_dep_vars; ++v)
236         moments_pass_one (ws.totals, case_data (c, cmd->dep_vars[v])->f, weight);
237
238       covariance_accumulate_pass1 (cov, c);
239     }
240   casereader_destroy (reader);
241
242   categoricals_done (cats);
243
244   for (reader = casereader_clone (input);
245        (c = casereader_read (reader)) != NULL; case_unref (c))
246     {
247       double weight = dict_get_case_weight (dict, c, &warn_bad_weight);
248
249       for ( v = 0; v < cmd->n_dep_vars; ++v)
250         moments_pass_two (ws.totals, case_data (c, cmd->dep_vars[v])->f, weight);
251
252       covariance_accumulate_pass2 (cov, c);
253     }
254   casereader_destroy (reader);
255
256   {
257     gsl_matrix *cm = covariance_calculate_unnormalized (cov);
258
259     dump_matrix (cm);
260
261     ws.total_ssq = gsl_matrix_get (cm, 0, 0);
262
263     reg_sweep (cm, 0);
264
265     dump_matrix (cm);
266
267     gsl_matrix_free (cm);
268   }
269
270   if (!taint_has_tainted_successor (taint))
271     output_glm (cmd, &ws);
272
273   taint_destroy (taint);
274 }
275
276 static void
277 output_glm (const struct glm_spec *cmd, const struct glm_workspace *ws)
278 {
279   const struct fmt_spec *wfmt = cmd->wv ? var_get_print_format (cmd->wv) : &F_8_0;
280
281   int f;
282   int r;
283   const int heading_columns = 1;
284   const int heading_rows = 1;
285   struct tab_table *t ;
286
287   const int nc = 6;
288   int nr = heading_rows + 4 + cmd->n_factor_vars;
289   if (cmd->intercept)
290     nr++;
291
292   t = tab_create (nc, nr);
293   tab_title (t, _("Tests of Between-Subjects Effects"));
294
295   tab_headers (t, heading_columns, 0, heading_rows, 0);
296
297   tab_box (t,
298            TAL_2, TAL_2,
299            -1, TAL_1,
300            0, 0,
301            nc - 1, nr - 1);
302
303   tab_hline (t, TAL_2, 0, nc - 1, heading_rows);
304   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
305
306   tab_text (t, 0, 0, TAB_CENTER | TAT_TITLE, _("Source"));
307
308   /* TRANSLATORS: The parameter is a roman numeral */
309   tab_text_format (t, 1, 0, TAB_CENTER | TAT_TITLE, _("Type %s Sum of Squares"), "III");
310   tab_text (t, 2, 0, TAB_CENTER | TAT_TITLE, _("df"));
311   tab_text (t, 3, 0, TAB_CENTER | TAT_TITLE, _("Mean Square"));
312   tab_text (t, 4, 0, TAB_CENTER | TAT_TITLE, _("F"));
313   tab_text (t, 5, 0, TAB_CENTER | TAT_TITLE, _("Sig."));
314
315   r = heading_rows;
316   tab_text (t, 0, r++, TAB_LEFT | TAT_TITLE, _("Corrected Model"));
317
318   double intercept, n_total;
319   if (cmd->intercept)
320     {
321       double mean;
322       moments_calculate (ws->totals, &n_total, &mean, NULL, NULL, NULL);
323       intercept = pow2 (mean * n_total) / n_total;
324
325       tab_text (t, 0, r, TAB_LEFT | TAT_TITLE, _("Intercept"));
326       tab_double (t, 1, r, 0, intercept, NULL);
327       tab_double (t, 2, r, 0, 1.00, wfmt);
328
329       tab_double (t, 3, r, 0, intercept / 1.0 , NULL);
330       r++;
331     }
332
333   for (f = 0; f < cmd->n_factor_vars; ++f)
334     {
335       tab_text (t, 0, r++, TAB_LEFT | TAT_TITLE,
336                 var_to_string (cmd->factor_vars[f]));
337     }
338
339   tab_text (t, 0, r++, TAB_LEFT | TAT_TITLE, _("Error"));
340
341   if (cmd->intercept)
342     {
343       double ssq = intercept + ws->total_ssq;
344       double mse = ssq / n_total;
345       tab_text (t, 0, r, TAB_LEFT | TAT_TITLE, _("Total"));
346       tab_double (t, 1, r, 0, ssq, NULL);
347       tab_double (t, 2, r, 0, n_total, wfmt);
348
349       r++;
350     }
351
352   tab_text (t, 0, r, TAB_LEFT | TAT_TITLE, _("Corrected Total"));
353
354   tab_double (t, 1, r, 0, ws->total_ssq, NULL);
355   tab_double (t, 2, r, 0, n_total - 1.0, wfmt);
356
357   tab_submit (t);
358 }
359
360 static 
361 void dump_matrix (const gsl_matrix *m)
362 {
363   size_t i, j;
364   for (i = 0; i < m->size1; ++i)
365     {
366       for (j = 0; j < m->size2; ++j)
367         {
368           double x = gsl_matrix_get (m, i, j);
369           printf ("%.3f ", x);
370         }
371       printf ("\n");
372     }
373   printf ("\n");
374 }