Merge branch 'master' of ssh://jmd@git.sv.gnu.org/srv/git/pspp
[pspp-builds.git] / src / language / stats / glm.q
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2007 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_cdf.h>
20 #include <gsl/gsl_matrix.h>
21 #include <gsl/gsl_vector.h>
22 #include <math.h>
23 #include <stdlib.h>
24
25 #include <data/case.h>
26 #include <data/category.h>
27 #include <data/casegrouper.h>
28 #include <data/casereader.h>
29 #include <data/dictionary.h>
30 #include <data/missing-values.h>
31 #include <data/procedure.h>
32 #include <data/transformations.h>
33 #include <data/value-labels.h>
34 #include <data/variable.h>
35 #include <language/command.h>
36 #include <language/dictionary/split-file.h>
37 #include <language/data-io/file-handle.h>
38 #include <language/lexer/lexer.h>
39 #include <libpspp/compiler.h>
40 #include <libpspp/hash.h>
41 #include <libpspp/message.h>
42 #include <math/covariance-matrix.h>
43 #include <math/coefficient.h>
44 #include <math/linreg.h>
45 #include <math/moments.h>
46 #include <output/table.h>
47
48 #include "xalloc.h"
49 #include "gettext.h"
50
51 /* (headers) */
52
53 /* (specification)
54    "GLM" (glm_):
55    *dependent=custom;
56    by=varlist;
57    with=varlist.
58 */
59 /* (declarations) */
60 /* (functions) */
61 static struct cmd_glm cmd;
62
63 /*
64   Moments for each of the variables used.
65  */
66 struct moments_var
67 {
68   struct moments1 *m;
69   double *weight;
70   double *mean;
71   double *variance;
72   const struct variable *v;
73 };
74
75
76 /*
77   Dependent variable used.
78  */
79 static const struct variable **v_dependent;
80
81 /*
82   Number of dependent variables.
83  */
84 static size_t n_dependent;
85
86 #if 0
87 /*
88   Return value for the procedure.
89  */
90 static int pspp_glm_rc = CMD_SUCCESS;
91 #else
92 int cmd_glm (struct lexer *lexer, struct dataset *ds);
93 #endif
94
95 static bool run_glm (struct casereader *,
96                      struct cmd_glm *,
97                      const struct dataset *);
98
99 int
100 cmd_glm (struct lexer *lexer, struct dataset *ds)
101 {
102   struct casegrouper *grouper;
103   struct casereader *group;
104
105   bool ok;
106
107   if (!parse_glm (lexer, ds, &cmd, NULL))
108     return CMD_FAILURE;
109
110   /* Data pass. */
111   grouper = casegrouper_create_splits (proc_open (ds), dataset_dict (ds));
112   while (casegrouper_get_next_group (grouper, &group))
113     {
114       run_glm (group, &cmd, ds);
115     }
116   ok = casegrouper_destroy (grouper);
117   ok = proc_commit (ds) && ok;
118
119   free (v_dependent);
120   return ok ? CMD_SUCCESS : CMD_FAILURE;
121 }
122
123 /* Parser for the dependent sub command */
124 static int
125 glm_custom_dependent (struct lexer *lexer, struct dataset *ds,
126                       struct cmd_glm *cmd UNUSED, void *aux UNUSED)
127 {
128   const struct dictionary *dict = dataset_dict (ds);
129
130   if ((lex_token (lexer) != T_ID
131        || dict_lookup_var (dict, lex_tokid (lexer)) == NULL)
132       && lex_token (lexer) != T_ALL)
133     return 2;
134
135   if (!parse_variables_const
136       (lexer, dict, &v_dependent, &n_dependent, PV_NONE))
137     {
138       free (v_dependent);
139       return 0;
140     }
141   assert (n_dependent);
142   if (n_dependent > 1)
143     msg (SE, _("Multivariate GLM not yet supported"));
144   n_dependent = 1;              /* Drop this line after adding support for multivariate GLM. */
145
146   return 1;
147 }
148
149 /*
150   COV is the covariance matrix for variables included in the
151   model. That means the dependent variable is in there, too.
152  */
153 static void
154 coeff_init (pspp_linreg_cache * c, const struct design_matrix *cov)
155 {
156   c->coeff = xnmalloc (cov->m->size2, sizeof (*c->coeff));
157   c->n_coeffs = cov->m->size2 - 1;
158   pspp_coeff_init (c->coeff, cov);
159 }
160
161 /* Encode categorical variables.
162    Returns number of valid cases. */
163 static int
164 data_pass_one (struct casereader *input,
165                const struct variable **vars, size_t n_vars,
166                struct moments_var **mom)
167 {
168   int n_data;
169   struct ccase c;
170   size_t i;
171
172   for (i = 0; i < n_vars; i++)
173     {
174       mom[i] = xmalloc (sizeof (*mom[i]));
175       mom[i]->v = vars[i];
176       mom[i]->mean = xmalloc (sizeof (*mom[i]->mean));
177       mom[i]->variance = xmalloc (sizeof (*mom[i]->mean));
178       mom[i]->weight = xmalloc (sizeof (*mom[i]->weight));
179       mom[i]->m = moments1_create (MOMENT_VARIANCE);
180       if (var_is_alpha (vars[i]))
181         cat_stored_values_create (vars[i]);
182     }
183
184   n_data = 0;
185   for (; casereader_read (input, &c); case_destroy (&c))
186     {
187       /*
188          The second condition ensures the program will run even if
189          there is only one variable to act as both explanatory and
190          response.
191        */
192       for (i = 0; i < n_vars; i++)
193         {
194           const union value *val = case_data (&c, vars[i]);
195           if (var_is_alpha (vars[i]))
196             cat_value_update (vars[i], val);
197           else
198             moments1_add (mom[i]->m, val->f, 1.0);
199         }
200       n_data++;
201     }
202   casereader_destroy (input);
203   for (i = 0; i < n_vars; i++)
204     {
205       if (var_is_numeric (mom[i]->v))
206         {
207           moments1_calculate (mom[i]->m, mom[i]->weight, mom[i]->mean,
208                               mom[i]->variance, NULL, NULL);
209         }
210     }
211
212   return n_data;
213 }
214
215 static pspp_linreg_cache *
216 fit_model (const struct design_matrix *cov, const struct moments1 **mom, 
217            const struct variable *dep_var, 
218            const struct variable ** indep_vars, 
219            size_t n_data, size_t n_indep)
220 {
221   pspp_linreg_cache *result = NULL;
222   result = pspp_linreg_cache_alloc (dep_var, indep_vars, n_data, n_indep);
223   coeff_init (result, cov);
224   pspp_linreg_with_cov (cov, result);  
225   
226   return result;
227 }
228
229 static bool
230 run_glm (struct casereader *input,
231          struct cmd_glm *cmd,
232          const struct dataset *ds)
233 {
234   casenumber row;
235   const struct variable **indep_vars;
236   const struct variable **all_vars;
237   int n_indep = 0;
238   pspp_linreg_cache *model = NULL; 
239   pspp_linreg_opts lopts;
240   struct ccase c;
241   size_t i;
242   size_t n_all_vars;
243   size_t n_data;                /* Number of valid cases. */
244   struct casereader *reader;
245   struct design_matrix *cov;
246   struct hsh_table *cov_hash;
247   struct moments1 **mom;
248
249   if (!casereader_peek (input, 0, &c))
250     {
251       casereader_destroy (input);
252       return true;
253     }
254   output_split_file_values (ds, &c);
255   case_destroy (&c);
256
257   if (!v_dependent)
258     {
259       dict_get_vars (dataset_dict (ds), &v_dependent, &n_dependent,
260                      1u << DC_SYSTEM);
261     }
262
263   lopts.get_depvar_mean_std = 1;
264
265   lopts.get_indep_mean_std = xnmalloc (n_dependent, sizeof (int));
266   indep_vars = xnmalloc (cmd->n_by, sizeof *indep_vars);
267   n_all_vars = cmd->n_by + n_dependent;
268   all_vars = xnmalloc (n_all_vars, sizeof *all_vars);
269
270   for (i = 0; i < n_dependent; i++)
271     {
272       all_vars[i] = v_dependent[i];
273     }
274   for (i = 0; i < cmd->n_by; i++)
275     {
276       indep_vars[i] = cmd->v_by[i];
277       all_vars[i + n_dependent] = cmd->v_by[i];
278     }
279   n_indep = cmd->n_by;
280   mom = xnmalloc (n_all_vars, sizeof (*mom));
281   for (i = 0; i < n_all_vars; i++)
282     mom[i] = moments1_create (MOMENT_MEAN);
283
284   reader = casereader_clone (input);
285   reader = casereader_create_filter_missing (reader, indep_vars, n_indep,
286                                              MV_ANY, NULL, NULL);
287   reader = casereader_create_filter_missing (reader, v_dependent, 1,
288                                              MV_ANY, NULL, NULL);
289
290   if (n_indep > 0)
291     {
292       for (i = 0; i < n_all_vars; i++)
293         if (var_is_alpha (all_vars[i]))
294           cat_stored_values_create (all_vars[i]);
295       
296       cov_hash = covariance_hsh_create (n_all_vars);
297       reader = casereader_create_counter (reader, &row, -1);
298       for (; casereader_read (reader, &c); case_destroy (&c))
299         {
300           /* 
301              Accumulate the covariance matrix.
302           */
303           covariance_accumulate (cov_hash, mom, &c, all_vars, n_all_vars);
304           n_data++;
305         }
306       cov = covariance_accumulator_to_matrix (cov_hash, mom, all_vars, n_all_vars, n_data);
307
308       hsh_destroy (cov_hash);
309       for (i = 0; i < n_dependent; i++)
310         {
311           model = fit_model (cov, mom, v_dependent[i], indep_vars, n_data, n_indep);
312           pspp_linreg_cache_free (model);
313         }
314
315       casereader_destroy (reader);
316       for (i = 0; i < n_all_vars; i++)
317         {
318           moments1_destroy (mom[i]);
319         }
320       free (mom);
321       covariance_matrix_destroy (cov);
322     }
323   else
324     {
325       msg (SE, gettext ("No valid data found. This command was skipped."));
326     }
327   free (indep_vars);
328   free (lopts.get_indep_mean_std);
329   casereader_destroy (input);
330
331   return true;
332 }
333
334 /*
335   Local Variables:   
336   mode: c
337   End:
338 */