Make cases simpler, faster, and easier to understand.
[pspp-builds.git] / src / language / stats / regression.q
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2005, 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_cdf.h>
20 #include <gsl/gsl_matrix.h>
21 #include <gsl/gsl_vector.h>
22 #include <math.h>
23 #include <stdlib.h>
24
25 #include <data/case.h>
26 #include <data/casegrouper.h>
27 #include <data/casereader.h>
28 #include <data/category.h>
29 #include <data/dictionary.h>
30 #include <data/missing-values.h>
31 #include <data/procedure.h>
32 #include <data/transformations.h>
33 #include <data/value-labels.h>
34 #include <data/variable.h>
35 #include <language/command.h>
36 #include <language/dictionary/split-file.h>
37 #include <language/data-io/file-handle.h>
38 #include <language/lexer/lexer.h>
39 #include <libpspp/compiler.h>
40 #include <libpspp/message.h>
41 #include <libpspp/taint.h>
42 #include <math/design-matrix.h>
43 #include <math/coefficient.h>
44 #include <math/linreg.h>
45 #include <math/moments.h>
46 #include <output/table.h>
47
48 #include "xalloc.h"
49
50 #include "gettext.h"
51 #define _(msgid) gettext (msgid)
52
53 #define REG_LARGE_DATA 1000
54
55 /* (headers) */
56
57 /* (specification)
58    "REGRESSION" (regression_):
59    *variables=custom;
60    +statistics[st_]=r,
61                     coeff,
62                     anova,
63                     outs,
64                     zpp,
65                     label,
66                     sha,
67                     ci,
68                     bcov,
69                     ses,
70                     xtx,
71                     collin,
72                     tol,
73                     selection,
74                     f,
75                     defaults,
76                     all;
77    ^dependent=varlist;
78    +save[sv_]=resid,pred;
79    +method=enter.
80 */
81 /* (declarations) */
82 /* (functions) */
83 static struct cmd_regression cmd;
84
85 /*
86   Moments for each of the variables used.
87  */
88 struct moments_var
89 {
90   struct moments1 *m;
91   const struct variable *v;
92 };
93
94 /*
95   Transformations for saving predicted values
96   and residuals, etc.
97  */
98 struct reg_trns
99 {
100   int n_trns;                   /* Number of transformations. */
101   int trns_id;                  /* Which trns is this one? */
102   pspp_linreg_cache *c;         /* Linear model for this trns. */
103 };
104 /*
105   Variables used (both explanatory and response).
106  */
107 static const struct variable **v_variables;
108
109 /*
110   Number of variables.
111  */
112 static size_t n_variables;
113
114 static bool run_regression (struct casereader *, struct cmd_regression *,
115                             struct dataset *, pspp_linreg_cache **);
116
117 /*
118    STATISTICS subcommand output functions.
119  */
120 static void reg_stats_r (pspp_linreg_cache *);
121 static void reg_stats_coeff (pspp_linreg_cache *);
122 static void reg_stats_anova (pspp_linreg_cache *);
123 static void reg_stats_outs (pspp_linreg_cache *);
124 static void reg_stats_zpp (pspp_linreg_cache *);
125 static void reg_stats_label (pspp_linreg_cache *);
126 static void reg_stats_sha (pspp_linreg_cache *);
127 static void reg_stats_ci (pspp_linreg_cache *);
128 static void reg_stats_f (pspp_linreg_cache *);
129 static void reg_stats_bcov (pspp_linreg_cache *);
130 static void reg_stats_ses (pspp_linreg_cache *);
131 static void reg_stats_xtx (pspp_linreg_cache *);
132 static void reg_stats_collin (pspp_linreg_cache *);
133 static void reg_stats_tol (pspp_linreg_cache *);
134 static void reg_stats_selection (pspp_linreg_cache *);
135 static void statistics_keyword_output (void (*)(pspp_linreg_cache *),
136                                        int, pspp_linreg_cache *);
137
138 static void
139 reg_stats_r (pspp_linreg_cache * c)
140 {
141   struct tab_table *t;
142   int n_rows = 2;
143   int n_cols = 5;
144   double rsq;
145   double adjrsq;
146   double std_error;
147
148   assert (c != NULL);
149   rsq = c->ssm / c->sst;
150   adjrsq = 1.0 - (1.0 - rsq) * (c->n_obs - 1.0) / (c->n_obs - c->n_indeps);
151   std_error = sqrt ((c->n_indeps - 1.0) / (c->n_obs - 1.0));
152   t = tab_create (n_cols, n_rows, 0);
153   tab_dim (t, tab_natural_dimensions);
154   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, n_cols - 1, n_rows - 1);
155   tab_hline (t, TAL_2, 0, n_cols - 1, 1);
156   tab_vline (t, TAL_2, 2, 0, n_rows - 1);
157   tab_vline (t, TAL_0, 1, 0, 0);
158
159   tab_text (t, 1, 0, TAB_CENTER | TAT_TITLE, _("R"));
160   tab_text (t, 2, 0, TAB_CENTER | TAT_TITLE, _("R Square"));
161   tab_text (t, 3, 0, TAB_CENTER | TAT_TITLE, _("Adjusted R Square"));
162   tab_text (t, 4, 0, TAB_CENTER | TAT_TITLE, _("Std. Error of the Estimate"));
163   tab_float (t, 1, 1, TAB_RIGHT, sqrt (rsq), 10, 2);
164   tab_float (t, 2, 1, TAB_RIGHT, rsq, 10, 2);
165   tab_float (t, 3, 1, TAB_RIGHT, adjrsq, 10, 2);
166   tab_float (t, 4, 1, TAB_RIGHT, std_error, 10, 2);
167   tab_title (t, _("Model Summary"));
168   tab_submit (t);
169 }
170
171 /*
172   Table showing estimated regression coefficients.
173  */
174 static void
175 reg_stats_coeff (pspp_linreg_cache * c)
176 {
177   size_t j;
178   int n_cols = 7;
179   int n_rows;
180   int this_row;
181   double t_stat;
182   double pval;
183   double std_err;
184   double beta;
185   const char *label;
186
187   const struct variable *v;
188   const union value *val;
189   struct tab_table *t;
190
191   assert (c != NULL);
192   n_rows = c->n_coeffs + 3;
193
194   t = tab_create (n_cols, n_rows, 0);
195   tab_headers (t, 2, 0, 1, 0);
196   tab_dim (t, tab_natural_dimensions);
197   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, n_cols - 1, n_rows - 1);
198   tab_hline (t, TAL_2, 0, n_cols - 1, 1);
199   tab_vline (t, TAL_2, 2, 0, n_rows - 1);
200   tab_vline (t, TAL_0, 1, 0, 0);
201
202   tab_text (t, 2, 0, TAB_CENTER | TAT_TITLE, _("B"));
203   tab_text (t, 3, 0, TAB_CENTER | TAT_TITLE, _("Std. Error"));
204   tab_text (t, 4, 0, TAB_CENTER | TAT_TITLE, _("Beta"));
205   tab_text (t, 5, 0, TAB_CENTER | TAT_TITLE, _("t"));
206   tab_text (t, 6, 0, TAB_CENTER | TAT_TITLE, _("Significance"));
207   tab_text (t, 1, 1, TAB_LEFT | TAT_TITLE, _("(Constant)"));
208   tab_float (t, 2, 1, 0, c->intercept, 10, 2);
209   std_err = sqrt (gsl_matrix_get (c->cov, 0, 0));
210   tab_float (t, 3, 1, 0, std_err, 10, 2);
211   tab_float (t, 4, 1, 0, 0.0, 10, 2);
212   t_stat = c->intercept / std_err;
213   tab_float (t, 5, 1, 0, t_stat, 10, 2);
214   pval = 2 * gsl_cdf_tdist_Q (fabs (t_stat), 1.0);
215   tab_float (t, 6, 1, 0, pval, 10, 2);
216   for (j = 0; j < c->n_coeffs; j++)
217     {
218       struct string tstr;
219       ds_init_empty (&tstr);
220       this_row = j + 2;
221
222       v = pspp_coeff_get_var (c->coeff[j], 0);
223       label = var_to_string (v);
224       /* Do not overwrite the variable's name. */
225       ds_put_cstr (&tstr, label);
226       if (var_is_alpha (v))
227         {
228           /*
229              Append the value associated with this coefficient.
230              This makes sense only if we us the usual binary encoding
231              for that value.
232            */
233
234           val = pspp_coeff_get_value (c->coeff[j], v);
235
236           var_append_value_name (v, val, &tstr);
237         }
238
239       tab_text (t, 1, this_row, TAB_CENTER, ds_cstr (&tstr));
240       /*
241          Regression coefficients.
242        */
243       tab_float (t, 2, this_row, 0, c->coeff[j]->estimate, 10, 2);
244       /*
245          Standard error of the coefficients.
246        */
247       std_err = sqrt (gsl_matrix_get (c->cov, j + 1, j + 1));
248       tab_float (t, 3, this_row, 0, std_err, 10, 2);
249       /*
250          Standardized coefficient, i.e., regression coefficient
251          if all variables had unit variance.
252        */
253       beta = pspp_coeff_get_sd (c->coeff[j]);
254       beta *= c->coeff[j]->estimate / c->depvar_std;
255       tab_float (t, 4, this_row, 0, beta, 10, 2);
256
257       /*
258          Test statistic for H0: coefficient is 0.
259        */
260       t_stat = c->coeff[j]->estimate / std_err;
261       tab_float (t, 5, this_row, 0, t_stat, 10, 2);
262       /*
263          P values for the test statistic above.
264        */
265       pval =
266         2 * gsl_cdf_tdist_Q (fabs (t_stat),
267                              (double) (c->n_obs - c->n_coeffs));
268       tab_float (t, 6, this_row, 0, pval, 10, 2);
269       ds_destroy (&tstr);
270     }
271   tab_title (t, _("Coefficients"));
272   tab_submit (t);
273 }
274
275 /*
276   Display the ANOVA table.
277  */
278 static void
279 reg_stats_anova (pspp_linreg_cache * c)
280 {
281   int n_cols = 7;
282   int n_rows = 4;
283   const double msm = c->ssm / c->dfm;
284   const double mse = c->sse / c->dfe;
285   const double F = msm / mse;
286   const double pval = gsl_cdf_fdist_Q (F, c->dfm, c->dfe);
287
288   struct tab_table *t;
289
290   assert (c != NULL);
291   t = tab_create (n_cols, n_rows, 0);
292   tab_headers (t, 2, 0, 1, 0);
293   tab_dim (t, tab_natural_dimensions);
294
295   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, n_cols - 1, n_rows - 1);
296
297   tab_hline (t, TAL_2, 0, n_cols - 1, 1);
298   tab_vline (t, TAL_2, 2, 0, n_rows - 1);
299   tab_vline (t, TAL_0, 1, 0, 0);
300
301   tab_text (t, 2, 0, TAB_CENTER | TAT_TITLE, _("Sum of Squares"));
302   tab_text (t, 3, 0, TAB_CENTER | TAT_TITLE, _("df"));
303   tab_text (t, 4, 0, TAB_CENTER | TAT_TITLE, _("Mean Square"));
304   tab_text (t, 5, 0, TAB_CENTER | TAT_TITLE, _("F"));
305   tab_text (t, 6, 0, TAB_CENTER | TAT_TITLE, _("Significance"));
306
307   tab_text (t, 1, 1, TAB_LEFT | TAT_TITLE, _("Regression"));
308   tab_text (t, 1, 2, TAB_LEFT | TAT_TITLE, _("Residual"));
309   tab_text (t, 1, 3, TAB_LEFT | TAT_TITLE, _("Total"));
310
311   /* Sums of Squares */
312   tab_float (t, 2, 1, 0, c->ssm, 10, 2);
313   tab_float (t, 2, 3, 0, c->sst, 10, 2);
314   tab_float (t, 2, 2, 0, c->sse, 10, 2);
315
316
317   /* Degrees of freedom */
318   tab_text (t, 3, 1, TAB_RIGHT | TAT_PRINTF, "%g", c->dfm);
319   tab_text (t, 3, 2, TAB_RIGHT | TAT_PRINTF, "%g", c->dfe);
320   tab_text (t, 3, 3, TAB_RIGHT | TAT_PRINTF, "%g", c->dft);
321
322   /* Mean Squares */
323   tab_float (t, 4, 1, TAB_RIGHT, msm, 8, 3);
324   tab_float (t, 4, 2, TAB_RIGHT, mse, 8, 3);
325
326   tab_float (t, 5, 1, 0, F, 8, 3);
327
328   tab_float (t, 6, 1, 0, pval, 8, 3);
329
330   tab_title (t, _("ANOVA"));
331   tab_submit (t);
332 }
333
334 static void
335 reg_stats_outs (pspp_linreg_cache * c)
336 {
337   assert (c != NULL);
338 }
339
340 static void
341 reg_stats_zpp (pspp_linreg_cache * c)
342 {
343   assert (c != NULL);
344 }
345
346 static void
347 reg_stats_label (pspp_linreg_cache * c)
348 {
349   assert (c != NULL);
350 }
351
352 static void
353 reg_stats_sha (pspp_linreg_cache * c)
354 {
355   assert (c != NULL);
356 }
357 static void
358 reg_stats_ci (pspp_linreg_cache * c)
359 {
360   assert (c != NULL);
361 }
362 static void
363 reg_stats_f (pspp_linreg_cache * c)
364 {
365   assert (c != NULL);
366 }
367 static void
368 reg_stats_bcov (pspp_linreg_cache * c)
369 {
370   int n_cols;
371   int n_rows;
372   int i;
373   int k;
374   int row;
375   int col;
376   const char *label;
377   struct tab_table *t;
378
379   assert (c != NULL);
380   n_cols = c->n_indeps + 1 + 2;
381   n_rows = 2 * (c->n_indeps + 1);
382   t = tab_create (n_cols, n_rows, 0);
383   tab_headers (t, 2, 0, 1, 0);
384   tab_dim (t, tab_natural_dimensions);
385   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, n_cols - 1, n_rows - 1);
386   tab_hline (t, TAL_2, 0, n_cols - 1, 1);
387   tab_vline (t, TAL_2, 2, 0, n_rows - 1);
388   tab_vline (t, TAL_0, 1, 0, 0);
389   tab_text (t, 0, 0, TAB_CENTER | TAT_TITLE, _("Model"));
390   tab_text (t, 1, 1, TAB_CENTER | TAT_TITLE, _("Covariances"));
391   for (i = 0; i < c->n_coeffs; i++)
392     {
393       const struct variable *v = pspp_coeff_get_var (c->coeff[i], 0);
394       label = var_to_string (v);
395       tab_text (t, 2, i, TAB_CENTER, label);
396       tab_text (t, i + 2, 0, TAB_CENTER, label);
397       for (k = 1; k < c->n_coeffs; k++)
398         {
399           col = (i <= k) ? k : i;
400           row = (i <= k) ? i : k;
401           tab_float (t, k + 2, i, TAB_CENTER,
402                      gsl_matrix_get (c->cov, row, col), 8, 3);
403         }
404     }
405   tab_title (t, _("Coefficient Correlations"));
406   tab_submit (t);
407 }
408 static void
409 reg_stats_ses (pspp_linreg_cache * c)
410 {
411   assert (c != NULL);
412 }
413 static void
414 reg_stats_xtx (pspp_linreg_cache * c)
415 {
416   assert (c != NULL);
417 }
418 static void
419 reg_stats_collin (pspp_linreg_cache * c)
420 {
421   assert (c != NULL);
422 }
423 static void
424 reg_stats_tol (pspp_linreg_cache * c)
425 {
426   assert (c != NULL);
427 }
428 static void
429 reg_stats_selection (pspp_linreg_cache * c)
430 {
431   assert (c != NULL);
432 }
433
434 static void
435 statistics_keyword_output (void (*function) (pspp_linreg_cache *),
436                            int keyword, pspp_linreg_cache * c)
437 {
438   if (keyword)
439     {
440       (*function) (c);
441     }
442 }
443
444 static void
445 subcommand_statistics (int *keywords, pspp_linreg_cache * c)
446 {
447   /*
448      The order here must match the order in which the STATISTICS
449      keywords appear in the specification section above.
450    */
451   enum
452   { r,
453     coeff,
454     anova,
455     outs,
456     zpp,
457     label,
458     sha,
459     ci,
460     bcov,
461     ses,
462     xtx,
463     collin,
464     tol,
465     selection,
466     f,
467     defaults,
468     all
469   };
470   int i;
471   int d = 1;
472
473   if (keywords[all])
474     {
475       /*
476          Set everything but F.
477        */
478       for (i = 0; i < f; i++)
479         {
480           keywords[i] = 1;
481         }
482     }
483   else
484     {
485       for (i = 0; i < all; i++)
486         {
487           if (keywords[i])
488             {
489               d = 0;
490             }
491         }
492       /*
493          Default output: ANOVA table, parameter estimates,
494          and statistics for variables not entered into model,
495          if appropriate.
496        */
497       if (keywords[defaults] | d)
498         {
499           keywords[anova] = 1;
500           keywords[outs] = 1;
501           keywords[coeff] = 1;
502           keywords[r] = 1;
503         }
504     }
505   statistics_keyword_output (reg_stats_r, keywords[r], c);
506   statistics_keyword_output (reg_stats_anova, keywords[anova], c);
507   statistics_keyword_output (reg_stats_coeff, keywords[coeff], c);
508   statistics_keyword_output (reg_stats_outs, keywords[outs], c);
509   statistics_keyword_output (reg_stats_zpp, keywords[zpp], c);
510   statistics_keyword_output (reg_stats_label, keywords[label], c);
511   statistics_keyword_output (reg_stats_sha, keywords[sha], c);
512   statistics_keyword_output (reg_stats_ci, keywords[ci], c);
513   statistics_keyword_output (reg_stats_f, keywords[f], c);
514   statistics_keyword_output (reg_stats_bcov, keywords[bcov], c);
515   statistics_keyword_output (reg_stats_ses, keywords[ses], c);
516   statistics_keyword_output (reg_stats_xtx, keywords[xtx], c);
517   statistics_keyword_output (reg_stats_collin, keywords[collin], c);
518   statistics_keyword_output (reg_stats_tol, keywords[tol], c);
519   statistics_keyword_output (reg_stats_selection, keywords[selection], c);
520 }
521
522 /*
523   Free the transformation. Free its linear model if this
524   transformation is the last one.
525  */
526 static bool
527 regression_trns_free (void *t_)
528 {
529   bool result = true;
530   struct reg_trns *t = t_;
531
532   if (t->trns_id == t->n_trns)
533     {
534       result = pspp_linreg_cache_free (t->c);
535     }
536   free (t);
537
538   return result;
539 }
540
541 /*
542   Gets the predicted values.
543  */
544 static int
545 regression_trns_pred_proc (void *t_, struct ccase **c,
546                            casenumber case_idx UNUSED)
547 {
548   size_t i;
549   size_t n_vals;
550   struct reg_trns *trns = t_;
551   pspp_linreg_cache *model;
552   union value *output = NULL;
553   const union value **vals = NULL;
554   const struct variable **vars = NULL;
555
556   assert (trns != NULL);
557   model = trns->c;
558   assert (model != NULL);
559   assert (model->depvar != NULL);
560   assert (model->pred != NULL);
561
562   vars = xnmalloc (model->n_coeffs, sizeof (*vars));
563   n_vals = (*model->get_vars) (model, vars);
564
565   vals = xnmalloc (n_vals, sizeof (*vals));
566   *c = case_unshare (*c);
567   output = case_data_rw (*c, model->pred);
568
569   for (i = 0; i < n_vals; i++)
570     {
571       vals[i] = case_data (*c, vars[i]);
572     }
573   output->f = (*model->predict) ((const struct variable **) vars,
574                                  vals, model, n_vals);
575   free (vals);
576   free (vars);
577   return TRNS_CONTINUE;
578 }
579
580 /*
581   Gets the residuals.
582  */
583 static int
584 regression_trns_resid_proc (void *t_, struct ccase **c,
585                             casenumber case_idx UNUSED)
586 {
587   size_t i;
588   size_t n_vals;
589   struct reg_trns *trns = t_;
590   pspp_linreg_cache *model;
591   union value *output = NULL;
592   const union value **vals = NULL;
593   const union value *obs = NULL;
594   const struct variable **vars = NULL;
595
596   assert (trns != NULL);
597   model = trns->c;
598   assert (model != NULL);
599   assert (model->depvar != NULL);
600   assert (model->resid != NULL);
601
602   vars = xnmalloc (model->n_coeffs, sizeof (*vars));
603   n_vals = (*model->get_vars) (model, vars);
604
605   vals = xnmalloc (n_vals, sizeof (*vals));
606   *c = case_unshare (*c);
607   output = case_data_rw (*c, model->resid);
608   assert (output != NULL);
609
610   for (i = 0; i < n_vals; i++)
611     {
612       vals[i] = case_data (*c, vars[i]);
613     }
614   obs = case_data (*c, model->depvar);
615   output->f = (*model->residual) ((const struct variable **) vars,
616                                   vals, obs, model, n_vals);
617   free (vals);
618   free (vars);
619   return TRNS_CONTINUE;
620 }
621
622 /*
623    Returns false if NAME is a duplicate of any existing variable name.
624 */
625 static bool
626 try_name (const struct dictionary *dict, const char *name)
627 {
628   if (dict_lookup_var (dict, name) != NULL)
629     return false;
630
631   return true;
632 }
633
634 static void
635 reg_get_name (const struct dictionary *dict, char name[VAR_NAME_LEN],
636               const char prefix[VAR_NAME_LEN])
637 {
638   int i = 1;
639
640   snprintf (name, VAR_NAME_LEN, "%s%d", prefix, i);
641   while (!try_name (dict, name))
642     {
643       i++;
644       snprintf (name, VAR_NAME_LEN, "%s%d", prefix, i);
645     }
646 }
647
648 static void
649 reg_save_var (struct dataset *ds, const char *prefix, trns_proc_func * f,
650               pspp_linreg_cache * c, struct variable **v, int n_trns)
651 {
652   struct dictionary *dict = dataset_dict (ds);
653   static int trns_index = 1;
654   char name[VAR_NAME_LEN];
655   struct variable *new_var;
656   struct reg_trns *t = NULL;
657
658   t = xmalloc (sizeof (*t));
659   t->trns_id = trns_index;
660   t->n_trns = n_trns;
661   t->c = c;
662   reg_get_name (dict, name, prefix);
663   new_var = dict_create_var (dict, name, 0);
664   assert (new_var != NULL);
665   *v = new_var;
666   add_transformation (ds, f, regression_trns_free, t);
667   trns_index++;
668 }
669 static void
670 subcommand_save (struct dataset *ds, int save, pspp_linreg_cache ** models)
671 {
672   pspp_linreg_cache **lc;
673   int n_trns = 0;
674   int i;
675
676   assert (models != NULL);
677
678   if (save)
679     {
680       /* Count the number of transformations we will need. */
681       for (i = 0; i < REGRESSION_SV_count; i++)
682         {
683           if (cmd.a_save[i])
684             {
685               n_trns++;
686             }
687         }
688       n_trns *= cmd.n_dependent;
689
690       for (lc = models; lc < models + cmd.n_dependent; lc++)
691         {
692           assert (*lc != NULL);
693           assert ((*lc)->depvar != NULL);
694           if (cmd.a_save[REGRESSION_SV_RESID])
695             {
696               reg_save_var (ds, "RES", regression_trns_resid_proc, *lc,
697                             &(*lc)->resid, n_trns);
698             }
699           if (cmd.a_save[REGRESSION_SV_PRED])
700             {
701               reg_save_var (ds, "PRED", regression_trns_pred_proc, *lc,
702                             &(*lc)->pred, n_trns);
703             }
704         }
705     }
706   else
707     {
708       for (lc = models; lc < models + cmd.n_dependent; lc++)
709         {
710           if (*lc != NULL)
711             {
712               pspp_linreg_cache_free (*lc);
713             }
714         }
715     }
716 }
717
718 int
719 cmd_regression (struct lexer *lexer, struct dataset *ds)
720 {
721   struct casegrouper *grouper;
722   struct casereader *group;
723   pspp_linreg_cache **models;
724   bool ok;
725   size_t i;
726
727   if (!parse_regression (lexer, ds, &cmd, NULL))
728     {
729       return CMD_FAILURE;
730     }
731
732   models = xnmalloc (cmd.n_dependent, sizeof *models);
733   for (i = 0; i < cmd.n_dependent; i++)
734     {
735       models[i] = NULL;
736     }
737
738   /* Data pass. */
739   grouper = casegrouper_create_splits (proc_open (ds), dataset_dict (ds));
740   while (casegrouper_get_next_group (grouper, &group))
741     run_regression (group, &cmd, ds, models);
742   ok = casegrouper_destroy (grouper);
743   ok = proc_commit (ds) && ok;
744
745   subcommand_save (ds, cmd.sbc_save, models);
746   free (v_variables);
747   free (models);
748   free_regression (&cmd);
749
750   return ok ? CMD_SUCCESS : CMD_FAILURE;
751 }
752
753 /*
754   Is variable k the dependent variable?
755  */
756 static bool
757 is_depvar (size_t k, const struct variable *v)
758 {
759   return v == v_variables[k];
760 }
761
762 /* Parser for the variables sub command */
763 static int
764 regression_custom_variables (struct lexer *lexer, struct dataset *ds,
765                              struct cmd_regression *cmd UNUSED,
766                              void *aux UNUSED)
767 {
768   const struct dictionary *dict = dataset_dict (ds);
769
770   lex_match (lexer, '=');
771
772   if ((lex_token (lexer) != T_ID
773        || dict_lookup_var (dict, lex_tokid (lexer)) == NULL)
774       && lex_token (lexer) != T_ALL)
775     return 2;
776
777
778   if (!parse_variables_const
779       (lexer, dict, &v_variables, &n_variables, PV_NONE))
780     {
781       free (v_variables);
782       return 0;
783     }
784   assert (n_variables);
785
786   return 1;
787 }
788
789 /* Identify the explanatory variables in v_variables.  Returns
790    the number of independent variables. */
791 static int
792 identify_indep_vars (const struct variable **indep_vars,
793                      const struct variable *depvar)
794 {
795   int n_indep_vars = 0;
796   int i;
797
798   for (i = 0; i < n_variables; i++)
799     if (!is_depvar (i, depvar))
800       indep_vars[n_indep_vars++] = v_variables[i];
801   if ((n_indep_vars < 1) && is_depvar (0, depvar))
802     {
803       /*
804         There is only one independent variable, and it is the same
805         as the dependent variable. Print a warning and continue.
806        */
807       msg (SE,
808            gettext ("The dependent variable is equal to the independent variable." 
809                     "The least squares line is therefore Y=X." 
810                     "Standard errors and related statistics may be meaningless."));
811       n_indep_vars = 1;
812       indep_vars[0] = v_variables[0];
813     }
814   return n_indep_vars;
815 }
816
817 /* Encode categorical variables.
818    Returns number of valid cases. */
819 static int
820 prepare_categories (struct casereader *input,
821                     const struct variable **vars, size_t n_vars,
822                     struct moments_var *mom)
823 {
824   int n_data;
825   struct ccase *c;
826   size_t i;
827
828   assert (vars != NULL);
829   assert (mom != NULL);
830
831   for (i = 0; i < n_vars; i++)
832     if (var_is_alpha (vars[i]))
833       cat_stored_values_create (vars[i]);
834
835   n_data = 0;
836   for (; (c = casereader_read (input)) != NULL; case_unref (c))
837     {
838       /*
839          The second condition ensures the program will run even if
840          there is only one variable to act as both explanatory and
841          response.
842        */
843       for (i = 0; i < n_vars; i++)
844         {
845           const union value *val = case_data (c, vars[i]);
846           if (var_is_alpha (vars[i]))
847             cat_value_update (vars[i], val);
848           else
849             moments1_add (mom[i].m, val->f, 1.0);
850         }
851       n_data++;
852     }
853   casereader_destroy (input);
854
855   return n_data;
856 }
857
858 static void
859 coeff_init (pspp_linreg_cache * c, struct design_matrix *dm)
860 {
861   c->coeff = xnmalloc (dm->m->size2, sizeof (*c->coeff));
862   pspp_coeff_init (c->coeff, dm);
863 }
864
865 static bool
866 run_regression (struct casereader *input, struct cmd_regression *cmd,
867                 struct dataset *ds, pspp_linreg_cache **models)
868 {
869   size_t i;
870   int n_indep = 0;
871   int k;
872   struct ccase *c;
873   const struct variable **indep_vars;
874   struct design_matrix *X;
875   struct moments_var *mom;
876   gsl_vector *Y;
877
878   pspp_linreg_opts lopts;
879
880   assert (models != NULL);
881
882   c = casereader_peek (input, 0);
883   if (c == NULL)
884     {
885       casereader_destroy (input);
886       return true;
887     }
888   output_split_file_values (ds, c);
889   case_unref (c);
890
891   if (!v_variables)
892     {
893       dict_get_vars (dataset_dict (ds), &v_variables, &n_variables, 0);
894     }
895
896   for (i = 0; i < cmd->n_dependent; i++)
897     {
898       if (!var_is_numeric (cmd->v_dependent[i]))
899         {
900           msg (SE, _("Dependent variable must be numeric."));
901           return false;
902         }
903     }
904
905   mom = xnmalloc (n_variables, sizeof (*mom));
906   for (i = 0; i < n_variables; i++)
907     {
908       (mom + i)->m = moments1_create (MOMENT_VARIANCE);
909       (mom + i)->v = v_variables[i];
910     }
911   lopts.get_depvar_mean_std = 1;
912
913   lopts.get_indep_mean_std = xnmalloc (n_variables, sizeof (int));
914   indep_vars = xnmalloc (n_variables, sizeof *indep_vars);
915
916   for (k = 0; k < cmd->n_dependent; k++)
917     {
918       const struct variable *dep_var;
919       struct casereader *reader;
920       casenumber row;
921       struct ccase *c;
922       size_t n_data;            /* Number of valid cases. */
923
924       dep_var = cmd->v_dependent[k];
925       n_indep = identify_indep_vars (indep_vars, dep_var);
926       reader = casereader_clone (input);
927       reader = casereader_create_filter_missing (reader, indep_vars, n_indep,
928                                                  MV_ANY, NULL, NULL);
929       reader = casereader_create_filter_missing (reader, &dep_var, 1,
930                                                  MV_ANY, NULL, NULL);
931       n_data = prepare_categories (casereader_clone (reader),
932                                    indep_vars, n_indep, mom);
933
934       if ((n_data > 0) && (n_indep > 0))
935         {
936           Y = gsl_vector_alloc (n_data);
937           X =
938             design_matrix_create (n_indep,
939                                   (const struct variable **) indep_vars,
940                                   n_data);
941           for (i = 0; i < X->m->size2; i++)
942             {
943               lopts.get_indep_mean_std[i] = 1;
944             }
945           models[k] = pspp_linreg_cache_alloc (dep_var, (const struct variable **) indep_vars,
946                                                X->m->size1, X->m->size2);
947           models[k]->depvar = dep_var;
948           /*
949              For large data sets, use QR decomposition.
950            */
951           if (n_data > sqrt (n_indep) && n_data > REG_LARGE_DATA)
952             {
953               models[k]->method = PSPP_LINREG_QR;
954             }
955
956           /*
957              The second pass fills the design matrix.
958            */
959           reader = casereader_create_counter (reader, &row, -1);
960           for (; (c = casereader_read (reader)) != NULL; case_unref (c))
961             {
962               for (i = 0; i < n_indep; ++i)
963                 {
964                   const struct variable *v = indep_vars[i];
965                   const union value *val = case_data (c, v);
966                   if (var_is_alpha (v))
967                     design_matrix_set_categorical (X, row, v, val);
968                   else
969                     design_matrix_set_numeric (X, row, v, val);
970                 }
971               gsl_vector_set (Y, row, case_num (c, dep_var));
972             }
973           /*
974              Now that we know the number of coefficients, allocate space
975              and store pointers to the variables that correspond to the
976              coefficients.
977            */
978           coeff_init (models[k], X);
979
980           /*
981              Find the least-squares estimates and other statistics.
982            */
983           pspp_linreg ((const gsl_vector *) Y, X, &lopts, models[k]);
984
985           if (!taint_has_tainted_successor (casereader_get_taint (input)))
986             {
987               subcommand_statistics (cmd->a_statistics, models[k]);
988             }
989
990           gsl_vector_free (Y);
991           design_matrix_destroy (X);
992         }
993       else
994         {
995           msg (SE,
996                gettext ("No valid data found. This command was skipped."));
997         }
998       casereader_destroy (reader);
999     }
1000   for (i = 0; i < n_variables; i++)
1001     {
1002       moments1_destroy ((mom + i)->m);
1003     }
1004   free (mom);
1005   free (indep_vars);
1006   free (lopts.get_indep_mean_std);
1007   casereader_destroy (input);
1008
1009   return true;
1010 }
1011
1012 /*
1013   Local Variables:
1014   mode: c
1015   End:
1016 */