New functions glm_custom_design and parse_interactions.
[pspp-builds.git] / src / language / stats / glm.q
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2007, 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_cdf.h>
20 #include <gsl/gsl_matrix.h>
21 #include <gsl/gsl_vector.h>
22 #include <math.h>
23 #include <stdlib.h>
24
25 #include <data/case.h>
26 #include <data/category.h>
27 #include <data/casegrouper.h>
28 #include <data/casereader.h>
29 #include <data/dictionary.h>
30 #include <data/missing-values.h>
31 #include <data/procedure.h>
32 #include <data/transformations.h>
33 #include <data/value-labels.h>
34 #include <data/variable.h>
35 #include <language/command.h>
36 #include <language/dictionary/split-file.h>
37 #include <language/data-io/file-handle.h>
38 #include <language/lexer/lexer.h>
39 #include <libpspp/compiler.h>
40 #include <libpspp/hash.h>
41 #include <libpspp/message.h>
42 #include <math/covariance-matrix.h>
43 #include <math/coefficient.h>
44 #include <math/linreg.h>
45 #include <math/moments.h>
46 #include <output/table.h>
47
48 #include "xalloc.h"
49 #include "gettext.h"
50
51 /* (headers) */
52
53 /* (specification)
54    "GLM" (glm_):
55    *dependent=custom;
56    design=custom;
57    by=varlist;
58    with=varlist.
59 */
60 /* (declarations) */
61 /* (functions) */
62 static struct cmd_glm cmd;
63
64
65 /*
66   Moments for each of the variables used.
67  */
68 struct moments_var
69 {
70   struct moments1 *m;
71   double *weight;
72   double *mean;
73   double *variance;
74   const struct variable *v;
75 };
76
77
78 /*
79   Dependent variable used.
80  */
81 static const struct variable **v_dependent;
82
83 /*
84   Number of dependent variables.
85  */
86 static size_t n_dependent;
87
88 size_t n_inter; /* Number of interactions. */
89 size_t n_members; /* Number of memebr variables in an interaction. */ 
90
91 struct interaction_variable **interactions;
92
93 int cmd_glm (struct lexer *lexer, struct dataset *ds);
94
95 static bool run_glm (struct casereader *,
96                      struct cmd_glm *,
97                      const struct dataset *);
98
99 int
100 cmd_glm (struct lexer *lexer, struct dataset *ds)
101 {
102   struct casegrouper *grouper;
103   struct casereader *group;
104
105   bool ok;
106
107   if (!parse_glm (lexer, ds, &cmd, NULL))
108     return CMD_FAILURE;
109
110   /* Data pass. */
111   grouper = casegrouper_create_splits (proc_open (ds), dataset_dict (ds));
112   while (casegrouper_get_next_group (grouper, &group))
113     {
114       run_glm (group, &cmd, ds);
115     }
116   ok = casegrouper_destroy (grouper);
117   ok = proc_commit (ds) && ok;
118
119   free (v_dependent);
120   return ok ? CMD_SUCCESS : CMD_FAILURE;
121 }
122 static int
123 parse_interactions (struct lexer *lexer, const struct variable **interaction_vars, int n_members,
124                     int max_members, struct dataset *ds)
125 {
126   if (lex_match (lexer, '*'))
127     {
128       if (n_members > max_members)
129         {
130           max_members *= 2;
131           xnrealloc (interaction_vars, max_members, sizeof (*interaction_vars));
132         }
133       interaction_vars[n_members] = parse_variable (lexer, dataset_dict (ds));
134       parse_interactions (lexer, interaction_vars, n_members++, max_members, ds);
135     }
136   return n_members;
137 }
138 /* Parser for the design subcommand. */
139 static int
140 glm_custom_design (struct lexer *lexer, struct dataset *ds,
141                    struct cmd_glm *cmd UNUSED, void *aux UNUSED)
142 {
143   size_t n_inter = 0;
144   size_t n_allocated = 2;
145   size_t n_members;
146   struct variable **interaction_vars;
147   struct variable *this_var;
148
149   interactions = xnmalloc (n_allocated, sizeof (*interactions));
150
151   while (lex_token (lexer) != T_STOP && lex_token (lexer) != '.')
152     {
153       this_var = parse_variable (lexer, dataset_dict (ds));
154       if (lex_match (lexer, '('))
155         {
156           lex_force_match (lexer, ')');
157         }
158       else if (lex_match (lexer, '*'))
159         {
160           n_members = 1;
161           interaction_vars = xnmalloc (2 * n_inter, sizeof (*interaction_vars));
162           n_members = parse_interactions (lexer, interaction_vars, 1, 2 * n_inter, ds);
163           if (n_allocated < n_inter)
164             {
165               n_allocated *= 2;
166               xnrealloc (interactions, n_allocated, sizeof (*interactions));
167             }
168           interactions [n_inter - 1] = 
169             interaction_variable_create (interaction_vars, n_members);
170           n_inter++;
171           free (interaction_vars);
172         }
173     }
174   return 1;
175 }
176 /* Parser for the dependent sub command */
177 static int
178 glm_custom_dependent (struct lexer *lexer, struct dataset *ds,
179                       struct cmd_glm *cmd UNUSED, void *aux UNUSED)
180 {
181   const struct dictionary *dict = dataset_dict (ds);
182
183   if ((lex_token (lexer) != T_ID
184        || dict_lookup_var (dict, lex_tokid (lexer)) == NULL)
185       && lex_token (lexer) != T_ALL)
186     return 2;
187
188   if (!parse_variables_const
189       (lexer, dict, &v_dependent, &n_dependent, PV_NONE))
190     {
191       free (v_dependent);
192       return 0;
193     }
194   assert (n_dependent);
195   if (n_dependent > 1)
196     msg (SE, _("Multivariate GLM not yet supported"));
197   n_dependent = 1;              /* Drop this line after adding support for multivariate GLM. */
198
199   return 1;
200 }
201
202 /*
203   COV is the covariance matrix for variables included in the
204   model. That means the dependent variable is in there, too.
205  */
206 static void
207 coeff_init (pspp_linreg_cache * c, const struct design_matrix *cov)
208 {
209   c->coeff = xnmalloc (cov->m->size2, sizeof (*c->coeff));
210   c->n_coeffs = cov->m->size2 - 1;
211   pspp_coeff_init (c->coeff, cov);
212 }
213
214
215 static pspp_linreg_cache *
216 fit_model (const struct covariance_matrix *cov,
217            const struct variable *dep_var, 
218            const struct variable ** indep_vars, 
219            size_t n_data, size_t n_indep)
220 {
221   pspp_linreg_cache *result = NULL;
222   result = pspp_linreg_cache_alloc (dep_var, indep_vars, n_data, n_indep);
223   coeff_init (result, covariance_to_design (cov));
224   pspp_linreg_with_cov (cov, result);  
225   
226   return result;
227 }
228
229 static bool
230 run_glm (struct casereader *input,
231          struct cmd_glm *cmd,
232          const struct dataset *ds)
233 {
234   casenumber row;
235   const struct variable **indep_vars;
236   const struct variable **all_vars;
237   int n_indep = 0;
238   pspp_linreg_cache *model = NULL; 
239   pspp_linreg_opts lopts;
240   struct ccase *c;
241   size_t i;
242   size_t n_all_vars;
243   size_t n_data;                /* Number of valid cases. */
244   struct casereader *reader;
245   struct covariance_matrix *cov;
246
247   c = casereader_peek (input, 0);
248   if (c == NULL)
249     {
250       casereader_destroy (input);
251       return true;
252     }
253   output_split_file_values (ds, c);
254   case_unref (c);
255
256   if (!v_dependent)
257     {
258       dict_get_vars (dataset_dict (ds), &v_dependent, &n_dependent,
259                      1u << DC_SYSTEM);
260     }
261
262   lopts.get_depvar_mean_std = 1;
263
264   lopts.get_indep_mean_std = xnmalloc (n_dependent, sizeof (int));
265   indep_vars = xnmalloc (cmd->n_by, sizeof *indep_vars);
266   n_all_vars = cmd->n_by + n_dependent;
267   all_vars = xnmalloc (n_all_vars, sizeof *all_vars);
268
269   for (i = 0; i < n_dependent; i++)
270     {
271       all_vars[i] = v_dependent[i];
272     }
273   for (i = 0; i < cmd->n_by; i++)
274     {
275       indep_vars[i] = cmd->v_by[i];
276       all_vars[i + n_dependent] = cmd->v_by[i];
277     }
278   n_indep = cmd->n_by;
279
280   reader = casereader_clone (input);
281   reader = casereader_create_filter_missing (reader, indep_vars, n_indep,
282                                              MV_ANY, NULL, NULL);
283   reader = casereader_create_filter_missing (reader, v_dependent, 1,
284                                              MV_ANY, NULL, NULL);
285
286   if (n_indep > 0)
287     {
288       for (i = 0; i < n_all_vars; i++)
289         if (var_is_alpha (all_vars[i]))
290           cat_stored_values_create (all_vars[i]);
291       
292       cov = covariance_matrix_init (n_all_vars, all_vars, ONE_PASS, PAIRWISE, MV_ANY);
293
294       reader = casereader_create_counter (reader, &row, -1);
295
296       for (i = 0; i < n_inter; i++)
297         if (var_is_alpha (interaction_get_variable (interactions[i])))
298           cat_stored_values_create (interaction_get_variable (interactions[i]));
299       covariance_interaction_set (cov, interactions, 1);
300       for (; (c = casereader_read (reader)) != NULL; case_unref (c))
301         {
302           /* 
303              Accumulate the covariance matrix.
304           */
305           covariance_matrix_accumulate (cov, c, interactions, 1);
306           n_data++;
307         }
308       covariance_matrix_compute (cov);
309       for (i = 0; i < n_dependent; i++)
310         {
311           model = fit_model (cov, v_dependent[i], indep_vars, n_data, n_indep);
312           pspp_linreg_cache_free (model);
313         }
314
315       casereader_destroy (reader);
316       covariance_matrix_destroy (cov);
317     }
318   else
319     {
320       msg (SE, gettext ("No valid data found. This command was skipped."));
321     }
322   free (indep_vars);
323   free (lopts.get_indep_mean_std);
324   casereader_destroy (input);
325
326   return true;
327 }
328
329 /*
330   Local Variables:   
331   mode: c
332   End:
333 */