logistic.c: minor refactoring
[pspp] / src / language / stats / logistic.c
1 /* pspp - a program for statistical analysis.
2    Copyright (C) 2012 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17
18 /* 
19    References: 
20    1. "Coding Logistic Regression with Newton-Raphson", James McCaffrey
21    http://msdn.microsoft.com/en-us/magazine/jj618304.aspx
22
23    2. "SPSS Statistical Algorithms" Chapter LOGISTIC REGRESSION Algorithms
24
25
26    The Newton Raphson method finds successive approximations to $\bf b$ where 
27    approximation ${\bf b}_t$ is (hopefully) better than the previous ${\bf b}_{t-1}$.
28
29    $ {\bf b}_t = {\bf b}_{t -1} + ({\bf X}^T{\bf W}_{t-1}{\bf X})^{-1}{\bf X}^T({\bf y} - {\bf \pi}_{t-1})$
30    where:
31
32    $\bf X$ is the $n \times p$ design matrix, $n$ being the number of cases, 
33    $p$ the number of parameters, \par
34    $\bf W$ is the diagonal matrix whose diagonal elements are
35    $\hat{\pi}_0(1 - \hat{\pi}_0), \, \hat{\pi}_1(1 - \hat{\pi}_2)\dots \hat{\pi}_{n-1}(1 - \hat{\pi}_{n-1})$
36    \par
37
38 */
39
40 #include <config.h>
41
42 #include <gsl/gsl_blas.h> 
43
44 #include <gsl/gsl_linalg.h>
45 #include <gsl/gsl_cdf.h>
46 #include <gsl/gsl_matrix.h>
47 #include <gsl/gsl_vector.h>
48 #include <math.h>
49
50 #include "data/case.h"
51 #include "data/casegrouper.h"
52 #include "data/casereader.h"
53 #include "data/dataset.h"
54 #include "data/dictionary.h"
55 #include "data/format.h"
56 #include "data/value.h"
57 #include "language/command.h"
58 #include "language/dictionary/split-file.h"
59 #include "language/lexer/lexer.h"
60 #include "language/lexer/value-parser.h"
61 #include "language/lexer/variable-parser.h"
62 #include "libpspp/assertion.h"
63 #include "libpspp/ll.h"
64 #include "libpspp/message.h"
65 #include "libpspp/misc.h"
66 #include "math/categoricals.h"
67 #include "math/interaction.h"
68 #include "libpspp/hmap.h"
69 #include "libpspp/hash-functions.h"
70
71 #include "output/tab.h"
72
73 #include "gettext.h"
74 #define _(msgid) gettext (msgid)
75
76
77
78
79 #define   PRINT_EACH_STEP  0x01
80 #define   PRINT_SUMMARY    0x02
81 #define   PRINT_CORR       0x04
82 #define   PRINT_ITER       0x08
83 #define   PRINT_GOODFIT    0x10
84 #define   PRINT_CI         0x20
85
86
87 #define PRINT_DEFAULT (PRINT_SUMMARY | PRINT_EACH_STEP)
88
89 /*
90   The constant parameters of the procedure.
91   That is, those which are set by the user.
92 */
93 struct lr_spec
94 {
95   /* The dependent variable */
96   const struct variable *dep_var;
97
98   /* The predictor variables (excluding categorical ones) */
99   const struct variable **predictor_vars;
100   size_t n_predictor_vars;
101
102   /* The categorical predictors */
103   struct interaction **cat_predictors;
104   size_t n_cat_predictors;
105
106
107   /* The union of the categorical and non-categorical variables */
108   const struct variable **indep_vars;
109   size_t n_indep_vars;
110
111
112   /* Which classes of missing vars are to be excluded */
113   enum mv_class exclude;
114
115   /* The weight variable */
116   const struct variable *wv;
117
118   /* The dictionary of the dataset */
119   const struct dictionary *dict;
120
121   /* True iff the constant (intercept) is to be included in the model */
122   bool constant;
123
124   /* Ths maximum number of iterations */
125   int max_iter;
126
127   /* Other iteration limiting conditions */
128   double bcon;
129   double min_epsilon;
130   double lcon;
131
132   /* The confidence interval (in percent) */
133   int confidence;
134
135   /* What results should be presented */
136   unsigned int print;
137
138   double cut_point;
139 };
140
141
142 /* The results and intermediate result of the procedure.
143    These are mutated as the procedure runs. Used for
144    temporary variables etc.
145 */
146 struct lr_result
147 {
148   /* Used to indicate if a pass should flag a warning when 
149      invalid (ie negative or missing) weight values are encountered */
150   bool warn_bad_weight;
151
152   /* The two values of the dependent variable. */
153   union value y0;
154   union value y1;
155
156
157   /* The sum of caseweights */
158   double cc;
159
160   /* The number of missing and nonmissing cases */
161   casenumber n_missing;
162   casenumber n_nonmissing;
163
164
165   gsl_matrix *hessian;
166
167   /* The categoricals and their payload. Null if  the analysis has no
168    categorical predictors */
169   struct categoricals *cats;
170   struct payload cp;
171
172
173   /* The estimates of the predictor coefficients */
174   gsl_vector *beta_hat;
175 };
176
177
178 /*
179   Convert INPUT into a dichotomous scalar, according to how the dependent variable's
180   values are mapped.
181   For simple cases, this is a 1:1 mapping
182   The return value is always either 0 or 1
183 */
184 static double
185 map_dependent_var (const struct lr_spec *cmd, const struct lr_result *res, const union value *input)
186 {
187   const int width = var_get_width (cmd->dep_var);
188   if (value_equal (input, &res->y0, width))
189     return 0;
190
191   if (value_equal (input, &res->y1, width))
192     return 1;
193
194   /* This should never happen.  If it does,  then y0 and/or y1 have probably not been set */
195   NOT_REACHED ();
196
197   return SYSMIS;
198 }
199
200
201 static void output_categories (const struct lr_spec *cmd, const struct lr_result *res);
202
203 static void output_depvarmap (const struct lr_spec *cmd, const struct lr_result *);
204
205 static void output_variables (const struct lr_spec *cmd, 
206                               const struct lr_result *);
207
208 static void output_model_summary (const struct lr_result *,
209                                   double initial_likelihood, double likelihood);
210
211 static void case_processing_summary (const struct lr_result *);
212
213
214 /* Return the value of case C corresponding to the INDEX'th entry in the
215    model */
216 static double
217 predictor_value (const struct ccase *c, 
218                     const struct variable **x, size_t n_x, 
219                     const struct categoricals *cats,
220                     size_t index)
221 {
222   /* Values of the scalar predictor variables */
223   if (index < n_x) 
224     return case_data (c, x[index])->f;
225
226   /* Coded values of categorical predictor variables (or interactions) */
227   if (cats && index - n_x  < categoricals_df_total (cats))
228     {
229       double x = categoricals_get_dummy_code_for_case (cats, index - n_x, c);
230       return x;
231     }
232
233   /* The constant term */
234   return 1.0;
235 }
236
237
238 /*
239   Return the probability beta_hat (that is the estimator logit(y) )
240   corresponding to the coefficient estimator for case C
241 */
242 static double 
243 pi_hat (const struct lr_spec *cmd, 
244         const struct lr_result *res,
245         const struct variable **x, size_t n_x,
246         const struct ccase *c)
247 {
248   int v0;
249   double pi = 0;
250   size_t n_coeffs = res->beta_hat->size;
251
252   if (cmd->constant)
253     {
254       pi += gsl_vector_get (res->beta_hat, res->beta_hat->size - 1);
255       n_coeffs--;
256     }
257   
258   for (v0 = 0; v0 < n_coeffs; ++v0)
259     {
260       pi += gsl_vector_get (res->beta_hat, v0) * 
261         predictor_value (c, x, n_x, res->cats, v0);
262     }
263
264   pi = 1.0 / (1.0 + exp(-pi));
265
266   return pi;
267 }
268
269
270 /*
271   Calculates the Hessian matrix X' V  X,
272   where: X is the n by N_X matrix comprising the n cases in INPUT
273   V is a diagonal matrix { (pi_hat_0)(1 - pi_hat_0), (pi_hat_1)(1 - pi_hat_1), ... (pi_hat_{N-1})(1 - pi_hat_{N-1})} 
274   (the partial derivative of the predicted values)
275
276   If ALL predicted values derivatives are close to zero or one, then CONVERGED
277   will be set to true.
278 */
279 static void
280 hessian (const struct lr_spec *cmd, 
281          struct lr_result *res,
282          struct casereader *input,
283          const struct variable **x, size_t n_x,
284          bool *converged)
285 {
286   struct casereader *reader;
287   struct ccase *c;
288
289   double max_w = -DBL_MAX;
290
291   gsl_matrix_set_zero (res->hessian);
292
293   for (reader = casereader_clone (input);
294        (c = casereader_read (reader)) != NULL; case_unref (c))
295     {
296       int v0, v1;
297       double pi = pi_hat (cmd, res, x, n_x, c);
298
299       double weight = dict_get_case_weight (cmd->dict, c, &res->warn_bad_weight);
300       double w = pi * (1 - pi);
301       if (w > max_w)
302         max_w = w;
303       w *= weight;
304
305       for (v0 = 0; v0 < res->beta_hat->size; ++v0)
306         {
307           double in0 = predictor_value (c, x, n_x, res->cats, v0);
308           for (v1 = 0; v1 < res->beta_hat->size; ++v1)
309             {
310               double in1 = predictor_value (c, x, n_x, res->cats, v1);
311               double *o = gsl_matrix_ptr (res->hessian, v0, v1);
312               *o += in0 * w * in1;
313             }
314         }
315     }
316   casereader_destroy (reader);
317
318   if ( max_w < cmd->min_epsilon)
319     {
320       *converged = true;
321       msg (MN, _("All predicted values are either 1 or 0"));
322     }
323 }
324
325
326 /* Calculates the value  X' (y - pi)
327    where X is the design model, 
328    y is the vector of observed independent variables
329    pi is the vector of estimates for y
330
331    As a side effect, the likelihood is stored in LIKELIHOOD
332 */
333 static gsl_vector *
334 xt_times_y_pi (const struct lr_spec *cmd,
335                struct lr_result *res,
336                struct casereader *input,
337                const struct variable **x, size_t n_x,
338                const struct variable *y_var,
339                double *likelihood)
340 {
341   struct casereader *reader;
342   struct ccase *c;
343   gsl_vector *output = gsl_vector_calloc (res->beta_hat->size);
344
345   *likelihood = 1.0;
346   for (reader = casereader_clone (input);
347        (c = casereader_read (reader)) != NULL; case_unref (c))
348     {
349       int v0;
350       double pi = pi_hat (cmd, res, x, n_x, c);
351       double weight = dict_get_case_weight (cmd->dict, c, &res->warn_bad_weight);
352
353
354       double y = map_dependent_var (cmd, res, case_data (c, y_var));
355
356       *likelihood *= pow (pi, weight * y) * pow (1 - pi, weight * (1 - y));
357
358       for (v0 = 0; v0 < res->beta_hat->size; ++v0)
359         {
360           double in0 = predictor_value (c, x, n_x, res->cats, v0);
361           double *o = gsl_vector_ptr (output, v0);
362           *o += in0 * (y - pi) * weight;
363         }
364     }
365
366   casereader_destroy (reader);
367
368   return output;
369 }
370
371 \f
372
373 /* "payload" functions for the categoricals.
374    The only function is to accumulate the frequency of each
375    category.
376  */
377
378 static void *
379 frq_create  (const void *aux1 UNUSED, void *aux2 UNUSED)
380 {
381   return xzalloc (sizeof (double));
382 }
383
384 static void
385 frq_update  (const void *aux1 UNUSED, void *aux2 UNUSED,
386              void *ud, const struct ccase *c UNUSED , double weight)
387 {
388   double *freq = ud;
389   *freq += weight;
390 }
391
392 static void 
393 frq_destroy (const void *aux1 UNUSED, void *aux2 UNUSED, void *user_data UNUSED)
394 {
395   free (user_data);
396 }
397
398 \f
399
400 /* 
401    Makes an initial pass though the data, doing the following:
402
403    * Checks that the dependent variable is  dichotomous,
404    * Creates and initialises the categoricals,
405    * Accumulates summary results,
406    * Calculates necessary initial values.
407    * Creates an initial value for \hat\beta the vector of beta_hats of \beta
408
409    Returns true if successful
410 */
411 static bool
412 initial_pass (const struct lr_spec *cmd, struct lr_result *res, struct casereader *input)
413 {
414   const int width = var_get_width (cmd->dep_var);
415
416   struct ccase *c;
417   struct casereader *reader;
418
419   double sum;
420   double sumA = 0.0;
421   double sumB = 0.0;
422
423   bool v0set = false;
424   bool v1set = false;
425
426   size_t n_coefficients = cmd->n_predictor_vars;
427   if (cmd->constant)
428     n_coefficients++;
429
430   /* Create categoricals if appropriate */
431   if (cmd->n_cat_predictors > 0)
432     {
433       res->cp.create = frq_create;
434       res->cp.update = frq_update;
435       res->cp.calculate = NULL;
436       res->cp.destroy = frq_destroy;
437
438       res->cats = categoricals_create (cmd->cat_predictors, cmd->n_cat_predictors,
439                                        cmd->wv, cmd->exclude, MV_ANY);
440
441       categoricals_set_payload (res->cats, &res->cp, cmd, res);
442     }
443
444   res->cc = 0;
445   for (reader = casereader_clone (input);
446        (c = casereader_read (reader)) != NULL; case_unref (c))
447     {
448       int v;
449       bool missing = false;
450       double weight = dict_get_case_weight (cmd->dict, c, &res->warn_bad_weight);
451       const union value *depval = case_data (c, cmd->dep_var);
452
453       for (v = 0; v < cmd->n_indep_vars; ++v)
454         {
455           const union value *val = case_data (c, cmd->indep_vars[v]);
456           if (var_is_value_missing (cmd->indep_vars[v], val, cmd->exclude))
457             {
458               missing = true;
459               break;
460             }
461         }
462
463       /* Accumulate the missing and non-missing counts */
464       if (missing)
465         {
466           res->n_missing++;
467           continue;
468         }
469       res->n_nonmissing++;
470
471       /* Find the values of the dependent variable */
472       if (!v0set)
473         {
474           value_clone (&res->y0, depval, width);
475           v0set = true;
476         }
477       else if (!v1set)
478         {
479           if ( !value_equal (&res->y0, depval, width))
480             {
481               value_clone (&res->y1, depval, width);
482               v1set = true;
483             }
484         }
485       else
486         {
487           if (! value_equal (&res->y0, depval, width)
488               &&
489               ! value_equal (&res->y1, depval, width)
490               )
491             {
492               msg (ME, _("Dependent variable's values are not dichotomous."));
493               goto error;
494             }
495         }
496
497       if (v0set && value_equal (&res->y0, depval, width))
498           sumA += weight;
499
500       if (v1set && value_equal (&res->y1, depval, width))
501           sumB += weight;
502
503
504       res->cc += weight;
505
506       categoricals_update (res->cats, c);
507     }
508   casereader_destroy (reader);
509
510   categoricals_done (res->cats);
511
512   sum = sumB;
513
514   /* Ensure that Y0 is less than Y1.  Otherwise the mapping gets
515      inverted, which is confusing to users */
516   if (var_is_numeric (cmd->dep_var) && value_compare_3way (&res->y0, &res->y1, width) > 0)
517     {
518       union value tmp;
519       value_clone (&tmp, &res->y0, width);
520       value_copy (&res->y0, &res->y1, width);
521       value_copy (&res->y1, &tmp, width);
522       value_destroy (&tmp, width);
523       sum = sumA;
524     }
525
526   n_coefficients += categoricals_df_total (res->cats);
527   res->beta_hat = gsl_vector_calloc (n_coefficients);
528
529   if (cmd->constant)
530     {
531       double mean = sum / res->cc;
532       gsl_vector_set (res->beta_hat, res->beta_hat->size - 1, log (mean / (1 - mean)));
533     }
534
535   return true;
536
537  error:
538   casereader_destroy (reader);
539   return false;
540 }
541
542
543
544 /* Start of the logistic regression routine proper */
545 static bool
546 run_lr (const struct lr_spec *cmd, struct casereader *input,
547         const struct dataset *ds UNUSED)
548 {
549   int i;
550
551   bool converged = false;
552
553   /* Set the likelihoods to a negative sentinel value */
554   double likelihood = -1;
555   double prev_likelihood = -1;
556   double initial_likelihood = -1;
557
558   struct lr_result work;
559   work.n_missing = 0;
560   work.n_nonmissing = 0;
561   work.warn_bad_weight = true;
562   work.cats = NULL;
563   work.beta_hat = NULL;
564
565   /* Get the initial estimates of \beta and their standard errors.
566      And perform other auxilliary initialisation.  */
567   if (! initial_pass (cmd, &work, input))
568     return false;
569   
570   for (i = 0; i < cmd->n_cat_predictors; ++i)
571     {
572       if (1 >= categoricals_n_count (work.cats, i))
573         {
574           struct string str;
575           ds_init_empty (&str);
576           
577           interaction_to_string (cmd->cat_predictors[i], &str);
578
579           msg (ME, _("Category %s does not have at least two distinct values. Logistic regression will not be run."),
580                ds_cstr(&str));
581           ds_destroy (&str);
582           return false;
583         }
584     }
585
586   output_depvarmap (cmd, &work);
587
588   case_processing_summary (&work);
589
590
591   input = casereader_create_filter_missing (input,
592                                             cmd->indep_vars,
593                                             cmd->n_indep_vars,
594                                             cmd->exclude,
595                                             NULL,
596                                             NULL);
597
598
599   work.hessian = gsl_matrix_calloc (work.beta_hat->size, work.beta_hat->size);
600
601   /* Start the Newton Raphson iteration process... */
602   for( i = 0 ; i < cmd->max_iter ; ++i)
603     {
604       double min, max;
605       gsl_vector *v ;
606
607       
608       hessian (cmd, &work, input,
609                cmd->predictor_vars, cmd->n_predictor_vars,
610                &converged);
611
612       gsl_linalg_cholesky_decomp (work.hessian);
613       gsl_linalg_cholesky_invert (work.hessian);
614
615       v = xt_times_y_pi (cmd, &work, input,
616                          cmd->predictor_vars, cmd->n_predictor_vars,
617                          cmd->dep_var,
618                          &likelihood);
619
620       {
621         /* delta = M.v */
622         gsl_vector *delta = gsl_vector_alloc (v->size);
623         gsl_blas_dgemv (CblasNoTrans, 1.0, work.hessian, v, 0, delta);
624         gsl_vector_free (v);
625
626
627         gsl_vector_add (work.beta_hat, delta);
628
629         gsl_vector_minmax (delta, &min, &max);
630
631         if ( fabs (min) < cmd->bcon && fabs (max) < cmd->bcon)
632           {
633             msg (MN, _("Estimation terminated at iteration number %d because parameter estimates changed by less than %g"),
634                  i + 1, cmd->bcon);
635             converged = true;
636           }
637
638         gsl_vector_free (delta);
639       }
640
641       if ( prev_likelihood >= 0)
642         {
643           if (-log (likelihood) > -(1.0 - cmd->lcon) * log (prev_likelihood))
644             {
645               msg (MN, _("Estimation terminated at iteration number %d because Log Likelihood decreased by less than %g%%"), i + 1, 100 * cmd->lcon);
646               converged = true;
647             }
648         }
649       if (i == 0)
650         initial_likelihood = likelihood;
651       prev_likelihood = likelihood;
652
653       if (converged)
654         break;
655     }
656   casereader_destroy (input);
657   assert (initial_likelihood >= 0);
658
659   if ( ! converged) 
660     msg (MW, _("Estimation terminated at iteration number %d because maximum iterations has been reached"), i );
661
662
663   output_model_summary (&work, initial_likelihood, likelihood);
664
665   if (work.cats)
666     output_categories (cmd, &work);
667
668   output_variables (cmd, &work);
669
670   gsl_matrix_free (work.hessian);
671   gsl_vector_free (work.beta_hat); 
672   
673   categoricals_destroy (work.cats);
674
675   return true;
676 }
677
678 struct variable_node
679 {
680   struct hmap_node node;      /* Node in hash map. */
681   const struct variable *var; /* The variable */
682 };
683
684 static struct variable_node *
685 lookup_variable (const struct hmap *map, const struct variable *var, unsigned int hash)
686 {
687   struct variable_node *vn = NULL;
688   HMAP_FOR_EACH_WITH_HASH (vn, struct variable_node, node, hash, map)
689     {
690       if (vn->var == var)
691         break;
692       
693       fprintf (stderr, "Warning: Hash table collision\n");
694     }
695   
696   return vn;
697 }
698
699
700 /* Parse the LOGISTIC REGRESSION command syntax */
701 int
702 cmd_logistic (struct lexer *lexer, struct dataset *ds)
703 {
704   /* Temporary location for the predictor variables.
705      These may or may not include the categorical predictors */
706   const struct variable **pred_vars;
707   size_t n_pred_vars;
708
709   int v, x;
710   struct lr_spec lr;
711   lr.dict = dataset_dict (ds);
712   lr.n_predictor_vars = 0;
713   lr.predictor_vars = NULL;
714   lr.exclude = MV_ANY;
715   lr.wv = dict_get_weight (lr.dict);
716   lr.max_iter = 20;
717   lr.lcon = 0.0000;
718   lr.bcon = 0.001;
719   lr.min_epsilon = 0.00000001;
720   lr.cut_point = 0.5;
721   lr.constant = true;
722   lr.confidence = 95;
723   lr.print = PRINT_DEFAULT;
724   lr.cat_predictors = NULL;
725   lr.n_cat_predictors = 0;
726   lr.indep_vars = NULL;
727
728
729   if (lex_match_id (lexer, "VARIABLES"))
730     lex_match (lexer, T_EQUALS);
731
732   if (! (lr.dep_var = parse_variable_const (lexer, lr.dict)))
733     goto error;
734
735   lex_force_match (lexer, T_WITH);
736
737   if (!parse_variables_const (lexer, lr.dict,
738                               &pred_vars, &n_pred_vars,
739                               PV_NO_DUPLICATE))
740     goto error;
741
742
743   while (lex_token (lexer) != T_ENDCMD)
744     {
745       lex_match (lexer, T_SLASH);
746
747       if (lex_match_id (lexer, "MISSING"))
748         {
749           lex_match (lexer, T_EQUALS);
750           while (lex_token (lexer) != T_ENDCMD
751                  && lex_token (lexer) != T_SLASH)
752             {
753               if (lex_match_id (lexer, "INCLUDE"))
754                 {
755                   lr.exclude = MV_SYSTEM;
756                 }
757               else if (lex_match_id (lexer, "EXCLUDE"))
758                 {
759                   lr.exclude = MV_ANY;
760                 }
761               else
762                 {
763                   lex_error (lexer, NULL);
764                   goto error;
765                 }
766             }
767         }
768       else if (lex_match_id (lexer, "ORIGIN"))
769         {
770           lr.constant = false;
771         }
772       else if (lex_match_id (lexer, "NOORIGIN"))
773         {
774           lr.constant = true;
775         }
776       else if (lex_match_id (lexer, "NOCONST"))
777         {
778           lr.constant = false;
779         }
780       else if (lex_match_id (lexer, "EXTERNAL"))
781         {
782           /* This is for compatibility.  It does nothing */
783         }
784       else if (lex_match_id (lexer, "CATEGORICAL"))
785         {
786           lex_match (lexer, T_EQUALS);
787           do
788             {
789               lr.cat_predictors = xrealloc (lr.cat_predictors,
790                                   sizeof (*lr.cat_predictors) * ++lr.n_cat_predictors);
791               lr.cat_predictors[lr.n_cat_predictors - 1] = 0;
792             }
793           while (parse_design_interaction (lexer, lr.dict, 
794                                            lr.cat_predictors + lr.n_cat_predictors - 1));
795           lr.n_cat_predictors--;
796         }
797       else if (lex_match_id (lexer, "PRINT"))
798         {
799           lex_match (lexer, T_EQUALS);
800           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
801             {
802               if (lex_match_id (lexer, "DEFAULT"))
803                 {
804                   lr.print |= PRINT_DEFAULT;
805                 }
806               else if (lex_match_id (lexer, "SUMMARY"))
807                 {
808                   lr.print |= PRINT_SUMMARY;
809                 }
810 #if 0
811               else if (lex_match_id (lexer, "CORR"))
812                 {
813                   lr.print |= PRINT_CORR;
814                 }
815               else if (lex_match_id (lexer, "ITER"))
816                 {
817                   lr.print |= PRINT_ITER;
818                 }
819               else if (lex_match_id (lexer, "GOODFIT"))
820                 {
821                   lr.print |= PRINT_GOODFIT;
822                 }
823 #endif
824               else if (lex_match_id (lexer, "CI"))
825                 {
826                   lr.print |= PRINT_CI;
827                   if (lex_force_match (lexer, T_LPAREN))
828                     {
829                       if (! lex_force_int (lexer))
830                         {
831                           lex_error (lexer, NULL);
832                           goto error;
833                         }
834                       lr.confidence = lex_integer (lexer);
835                       lex_get (lexer);
836                       if ( ! lex_force_match (lexer, T_RPAREN))
837                         {
838                           lex_error (lexer, NULL);
839                           goto error;
840                         }
841                     }
842                 }
843               else if (lex_match_id (lexer, "ALL"))
844                 {
845                   lr.print = ~0x0000;
846                 }
847               else
848                 {
849                   lex_error (lexer, NULL);
850                   goto error;
851                 }
852             }
853         }
854       else if (lex_match_id (lexer, "CRITERIA"))
855         {
856           lex_match (lexer, T_EQUALS);
857           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
858             {
859               if (lex_match_id (lexer, "BCON"))
860                 {
861                   if (lex_force_match (lexer, T_LPAREN))
862                     {
863                       if (! lex_force_num (lexer))
864                         {
865                           lex_error (lexer, NULL);
866                           goto error;
867                         }
868                       lr.bcon = lex_number (lexer);
869                       lex_get (lexer);
870                       if ( ! lex_force_match (lexer, T_RPAREN))
871                         {
872                           lex_error (lexer, NULL);
873                           goto error;
874                         }
875                     }
876                 }
877               else if (lex_match_id (lexer, "ITERATE"))
878                 {
879                   if (lex_force_match (lexer, T_LPAREN))
880                     {
881                       if (! lex_force_int (lexer))
882                         {
883                           lex_error (lexer, NULL);
884                           goto error;
885                         }
886                       lr.max_iter = lex_integer (lexer);
887                       lex_get (lexer);
888                       if ( ! lex_force_match (lexer, T_RPAREN))
889                         {
890                           lex_error (lexer, NULL);
891                           goto error;
892                         }
893                     }
894                 }
895               else if (lex_match_id (lexer, "LCON"))
896                 {
897                   if (lex_force_match (lexer, T_LPAREN))
898                     {
899                       if (! lex_force_num (lexer))
900                         {
901                           lex_error (lexer, NULL);
902                           goto error;
903                         }
904                       lr.lcon = lex_number (lexer);
905                       lex_get (lexer);
906                       if ( ! lex_force_match (lexer, T_RPAREN))
907                         {
908                           lex_error (lexer, NULL);
909                           goto error;
910                         }
911                     }
912                 }
913               else if (lex_match_id (lexer, "EPS"))
914                 {
915                   if (lex_force_match (lexer, T_LPAREN))
916                     {
917                       if (! lex_force_num (lexer))
918                         {
919                           lex_error (lexer, NULL);
920                           goto error;
921                         }
922                       lr.min_epsilon = lex_number (lexer);
923                       lex_get (lexer);
924                       if ( ! lex_force_match (lexer, T_RPAREN))
925                         {
926                           lex_error (lexer, NULL);
927                           goto error;
928                         }
929                     }
930                 }
931               else
932                 {
933                   lex_error (lexer, NULL);
934                   goto error;
935                 }
936             }
937         }
938       else
939         {
940           lex_error (lexer, NULL);
941           goto error;
942         }
943     }
944
945   /* Copy the predictor variables from the temporary location into the 
946      final one, dropping any categorical variables which appear there.
947      FIXME: This is O(NxM).
948   */
949
950   {
951   struct variable_node *vn, *next;
952   struct hmap allvars;
953   hmap_init (&allvars);
954   for (v = x = 0; v < n_pred_vars; ++v)
955     {
956       bool drop = false;
957       const struct variable *var = pred_vars[v];
958       int cv = 0;
959
960       unsigned int hash = hash_pointer (var, 0);
961       struct variable_node *vn = lookup_variable (&allvars, var, hash);
962       if (vn == NULL)
963         {
964           vn = xmalloc (sizeof *vn);
965           vn->var = var;
966           hmap_insert (&allvars, &vn->node,  hash);
967         }
968
969       for (cv = 0; cv < lr.n_cat_predictors ; ++cv)
970         {
971           int iv;
972           const struct interaction *iact = lr.cat_predictors[cv];
973           for (iv = 0 ; iv < iact->n_vars ; ++iv)
974             {
975               const struct variable *ivar = iact->vars[iv];
976               unsigned int hash = hash_pointer (ivar, 0);
977               struct variable_node *vn = lookup_variable (&allvars, ivar, hash);
978               if (vn == NULL)
979                 {
980                   vn = xmalloc (sizeof *vn);
981                   vn->var = ivar;
982                   
983                   hmap_insert (&allvars, &vn->node,  hash);
984                 }
985
986               if (var == ivar)
987                 {
988                   drop = true;
989                 }
990             }
991         }
992
993       if (drop)
994         continue;
995
996       lr.predictor_vars = xrealloc (lr.predictor_vars, sizeof *lr.predictor_vars * (x + 1));
997       lr.predictor_vars[x++] = var;
998       lr.n_predictor_vars++;
999     }
1000   free (pred_vars);
1001
1002   lr.n_indep_vars = hmap_count (&allvars);
1003   lr.indep_vars = xmalloc (lr.n_indep_vars * sizeof *lr.indep_vars);
1004
1005   /* Interate over each variable and push it into the array */
1006   x = 0;
1007   HMAP_FOR_EACH_SAFE (vn, next, struct variable_node, node, &allvars)
1008     {
1009       lr.indep_vars[x++] = vn->var;
1010       free (vn);
1011     }
1012   hmap_destroy (&allvars);
1013   }  
1014
1015
1016   /* logistical regression for each split group */
1017   {
1018     struct casegrouper *grouper;
1019     struct casereader *group;
1020     bool ok;
1021
1022     grouper = casegrouper_create_splits (proc_open (ds), lr.dict);
1023     while (casegrouper_get_next_group (grouper, &group))
1024       ok = run_lr (&lr, group, ds);
1025     ok = casegrouper_destroy (grouper);
1026     ok = proc_commit (ds) && ok;
1027   }
1028
1029   free (lr.predictor_vars);
1030   free (lr.cat_predictors);
1031   free (lr.indep_vars);
1032
1033   return CMD_SUCCESS;
1034
1035  error:
1036
1037   free (lr.predictor_vars);
1038   free (lr.cat_predictors);
1039   free (lr.indep_vars);
1040
1041   return CMD_FAILURE;
1042 }
1043
1044
1045 \f
1046
1047 /* Show the Dependent Variable Encoding box.
1048    This indicates how the dependent variable
1049    is mapped to the internal zero/one values.
1050 */
1051 static void
1052 output_depvarmap (const struct lr_spec *cmd, const struct lr_result *res)
1053 {
1054   const int heading_columns = 0;
1055   const int heading_rows = 1;
1056   struct tab_table *t;
1057   struct string str;
1058
1059   const int nc = 2;
1060   int nr = heading_rows + 2;
1061
1062   t = tab_create (nc, nr);
1063   tab_title (t, _("Dependent Variable Encoding"));
1064
1065   tab_headers (t, heading_columns, 0, heading_rows, 0);
1066
1067   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, nc - 1, nr - 1);
1068
1069   tab_hline (t, TAL_2, 0, nc - 1, heading_rows);
1070   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
1071
1072   tab_text (t,  0, 0, TAB_CENTER | TAT_TITLE, _("Original Value"));
1073   tab_text (t,  1, 0, TAB_CENTER | TAT_TITLE, _("Internal Value"));
1074
1075
1076
1077   ds_init_empty (&str);
1078   var_append_value_name (cmd->dep_var, &res->y0, &str);
1079   tab_text (t,  0, 0 + heading_rows,  0, ds_cstr (&str));
1080
1081   ds_clear (&str);
1082   var_append_value_name (cmd->dep_var, &res->y1, &str);
1083   tab_text (t,  0, 1 + heading_rows,  0, ds_cstr (&str));
1084
1085
1086   tab_double (t, 1, 0 + heading_rows, 0, map_dependent_var (cmd, res, &res->y0), &F_8_0);
1087   tab_double (t, 1, 1 + heading_rows, 0, map_dependent_var (cmd, res, &res->y1), &F_8_0);
1088   ds_destroy (&str);
1089
1090   tab_submit (t);
1091 }
1092
1093
1094 /* Show the Variables in the Equation box */
1095 static void
1096 output_variables (const struct lr_spec *cmd, 
1097                   const struct lr_result *res)
1098 {
1099   int row = 0;
1100   const int heading_columns = 1;
1101   int heading_rows = 1;
1102   struct tab_table *t;
1103
1104   int nc = 8;
1105   int nr ;
1106   int i = 0;
1107   int ivar = 0;
1108   int idx_correction = 0;
1109
1110   if (cmd->print & PRINT_CI)
1111     {
1112       nc += 2;
1113       heading_rows += 1;
1114       row++;
1115     }
1116   nr = heading_rows + cmd->n_predictor_vars;
1117   if (cmd->constant)
1118     nr++;
1119
1120   if (res->cats)
1121     nr += categoricals_df_total (res->cats) + cmd->n_cat_predictors;
1122
1123   t = tab_create (nc, nr);
1124   tab_title (t, _("Variables in the Equation"));
1125
1126   tab_headers (t, heading_columns, 0, heading_rows, 0);
1127
1128   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, nc - 1, nr - 1);
1129
1130   tab_hline (t, TAL_2, 0, nc - 1, heading_rows);
1131   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
1132
1133   tab_text (t,  0, row + 1, TAB_CENTER | TAT_TITLE, _("Step 1"));
1134
1135   tab_text (t,  2, row, TAB_CENTER | TAT_TITLE, _("B"));
1136   tab_text (t,  3, row, TAB_CENTER | TAT_TITLE, _("S.E."));
1137   tab_text (t,  4, row, TAB_CENTER | TAT_TITLE, _("Wald"));
1138   tab_text (t,  5, row, TAB_CENTER | TAT_TITLE, _("df"));
1139   tab_text (t,  6, row, TAB_CENTER | TAT_TITLE, _("Sig."));
1140   tab_text (t,  7, row, TAB_CENTER | TAT_TITLE, _("Exp(B)"));
1141
1142   if (cmd->print & PRINT_CI)
1143     {
1144       tab_joint_text_format (t, 8, 0, 9, 0,
1145                              TAB_CENTER | TAT_TITLE, _("%d%% CI for Exp(B)"), cmd->confidence);
1146
1147       tab_text (t,  8, row, TAB_CENTER | TAT_TITLE, _("Lower"));
1148       tab_text (t,  9, row, TAB_CENTER | TAT_TITLE, _("Upper"));
1149     }
1150  
1151   for (row = heading_rows ; row < nr; ++row)
1152     {
1153       const int idx = row - heading_rows - idx_correction;
1154
1155       const double b = gsl_vector_get (res->beta_hat, idx);
1156       const double sigma2 = gsl_matrix_get (res->hessian, idx, idx);
1157       const double wald = pow2 (b) / sigma2;
1158       const double df = 1;
1159
1160       if (idx < cmd->n_predictor_vars)
1161         {
1162           tab_text (t, 1, row, TAB_LEFT | TAT_TITLE, 
1163                     var_to_string (cmd->predictor_vars[idx]));
1164         }
1165       else if (i < cmd->n_cat_predictors)
1166         {
1167           double wald;
1168           bool summary = false;
1169           struct string str;
1170           const struct interaction *cat_predictors = cmd->cat_predictors[i];
1171           const int df = categoricals_df (res->cats, i);
1172
1173           ds_init_empty (&str);
1174           interaction_to_string (cat_predictors, &str);
1175
1176           if (ivar == 0)
1177             {
1178               /* Calculate the Wald statistic,
1179                  which is \beta' C^-1 \beta .
1180                  where \beta is the vector of the coefficient estimates comprising this
1181                  categorial variable. and C is the corresponding submatrix of the 
1182                  hessian matrix.
1183               */
1184               gsl_matrix_const_view mv =
1185                 gsl_matrix_const_submatrix (res->hessian, idx, idx, df, df);
1186               gsl_matrix *subhessian = gsl_matrix_alloc (mv.matrix.size1, mv.matrix.size2);
1187               gsl_vector_const_view vv = gsl_vector_const_subvector (res->beta_hat, idx, df);
1188               gsl_vector *temp = gsl_vector_alloc (df);
1189
1190               gsl_matrix_memcpy (subhessian, &mv.matrix);
1191               gsl_linalg_cholesky_decomp (subhessian);
1192               gsl_linalg_cholesky_invert (subhessian);
1193
1194               gsl_blas_dgemv (CblasTrans, 1.0, subhessian, &vv.vector, 0, temp);
1195               gsl_blas_ddot (temp, &vv.vector, &wald);
1196
1197               tab_double (t, 4, row, 0, wald, 0);
1198               tab_double (t, 5, row, 0, df, &F_8_0);
1199               tab_double (t, 6, row, 0, gsl_cdf_chisq_Q (wald, df), 0);
1200
1201               idx_correction ++;
1202               summary = true;
1203               gsl_matrix_free (subhessian);
1204               gsl_vector_free (temp);
1205             }
1206           else
1207             {
1208               ds_put_format (&str, "(%d)", ivar);
1209             }
1210
1211           tab_text (t, 1, row, TAB_LEFT | TAT_TITLE, ds_cstr (&str));
1212           if (ivar++ == df)
1213             {
1214               ++i; /* next interaction */
1215               ivar = 0;
1216             }
1217
1218           ds_destroy (&str);
1219
1220           if (summary)
1221             continue;
1222         }
1223       else
1224         {
1225           tab_text (t, 1, row, TAB_LEFT | TAT_TITLE, _("Constant"));
1226         }
1227
1228       tab_double (t, 2, row, 0, b, 0);
1229       tab_double (t, 3, row, 0, sqrt (sigma2), 0);
1230       tab_double (t, 4, row, 0, wald, 0);
1231       tab_double (t, 5, row, 0, df, &F_8_0);
1232       tab_double (t, 6, row, 0, gsl_cdf_chisq_Q (wald, df), 0);
1233       tab_double (t, 7, row, 0, exp (b), 0);
1234
1235       if (cmd->print & PRINT_CI)
1236         {
1237           double wc = gsl_cdf_ugaussian_Pinv (0.5 + cmd->confidence / 200.0);
1238           wc *= sqrt (sigma2);
1239
1240           if (idx < cmd->n_predictor_vars)
1241             {
1242               tab_double (t, 8, row, 0, exp (b - wc), 0);
1243               tab_double (t, 9, row, 0, exp (b + wc), 0);
1244             }
1245         }
1246     }
1247
1248   tab_submit (t);
1249 }
1250
1251
1252 /* Show the model summary box */
1253 static void
1254 output_model_summary (const struct lr_result *res,
1255                       double initial_likelihood, double likelihood)
1256 {
1257   const int heading_columns = 0;
1258   const int heading_rows = 1;
1259   struct tab_table *t;
1260
1261   const int nc = 4;
1262   int nr = heading_rows + 1;
1263   double cox;
1264
1265   t = tab_create (nc, nr);
1266   tab_title (t, _("Model Summary"));
1267
1268   tab_headers (t, heading_columns, 0, heading_rows, 0);
1269
1270   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, nc - 1, nr - 1);
1271
1272   tab_hline (t, TAL_2, 0, nc - 1, heading_rows);
1273   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
1274
1275   tab_text (t,  0, 0, TAB_LEFT | TAT_TITLE, _("Step 1"));
1276   tab_text (t,  1, 0, TAB_CENTER | TAT_TITLE, _("-2 Log likelihood"));
1277   tab_double (t,  1, 1, 0, -2 * log (likelihood), 0);
1278
1279
1280   tab_text (t,  2, 0, TAB_CENTER | TAT_TITLE, _("Cox & Snell R Square"));
1281   cox =  1.0 - pow (initial_likelihood /likelihood, 2 / res->cc);
1282   tab_double (t,  2, 1, 0, cox, 0);
1283
1284   tab_text (t,  3, 0, TAB_CENTER | TAT_TITLE, _("Nagelkerke R Square"));
1285   tab_double (t,  3, 1, 0, cox / ( 1.0 - pow (initial_likelihood, 2 / res->cc)), 0);
1286
1287
1288   tab_submit (t);
1289 }
1290
1291 /* Show the case processing summary box */
1292 static void
1293 case_processing_summary (const struct lr_result *res)
1294 {
1295   const int heading_columns = 1;
1296   const int heading_rows = 1;
1297   struct tab_table *t;
1298
1299   const int nc = 3;
1300   const int nr = heading_rows + 3;
1301   casenumber total;
1302
1303   t = tab_create (nc, nr);
1304   tab_title (t, _("Case Processing Summary"));
1305
1306   tab_headers (t, heading_columns, 0, heading_rows, 0);
1307
1308   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, nc - 1, nr - 1);
1309
1310   tab_hline (t, TAL_2, 0, nc - 1, heading_rows);
1311   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
1312
1313   tab_text (t,  0, 0, TAB_LEFT | TAT_TITLE, _("Unweighted Cases"));
1314   tab_text (t,  1, 0, TAB_CENTER | TAT_TITLE, _("N"));
1315   tab_text (t,  2, 0, TAB_CENTER | TAT_TITLE, _("Percent"));
1316
1317
1318   tab_text (t,  0, 1, TAB_LEFT | TAT_TITLE, _("Included in Analysis"));
1319   tab_text (t,  0, 2, TAB_LEFT | TAT_TITLE, _("Missing Cases"));
1320   tab_text (t,  0, 3, TAB_LEFT | TAT_TITLE, _("Total"));
1321
1322   tab_double (t,  1, 1, 0, res->n_nonmissing, &F_8_0);
1323   tab_double (t,  1, 2, 0, res->n_missing, &F_8_0);
1324
1325   total = res->n_nonmissing + res->n_missing;
1326   tab_double (t,  1, 3, 0, total , &F_8_0);
1327
1328   tab_double (t,  2, 1, 0, 100 * res->n_nonmissing / (double) total, 0);
1329   tab_double (t,  2, 2, 0, 100 * res->n_missing / (double) total, 0);
1330   tab_double (t,  2, 3, 0, 100 * total / (double) total, 0);
1331
1332   tab_submit (t);
1333 }
1334
1335 static void
1336 output_categories (const struct lr_spec *cmd, const struct lr_result *res)
1337 {
1338   const struct fmt_spec *wfmt =
1339     cmd->wv ? var_get_print_format (cmd->wv) : &F_8_0;
1340
1341   int cumulative_df;
1342   int i = 0;
1343   const int heading_columns = 2;
1344   const int heading_rows = 2;
1345   struct tab_table *t;
1346
1347   int nc ;
1348   int nr ;
1349
1350   int v;
1351   int r = 0;
1352
1353   int max_df = 0;
1354   int total_cats = 0;
1355   for (i = 0; i < cmd->n_cat_predictors; ++i)
1356     {
1357       size_t n = categoricals_n_count (res->cats, i);
1358       size_t df = categoricals_df (res->cats, i);
1359       if (max_df < df)
1360         max_df = df;
1361       total_cats += n;
1362     }
1363
1364   nc = heading_columns + 1 + max_df;
1365   nr = heading_rows + total_cats;
1366
1367   t = tab_create (nc, nr);
1368   tab_title (t, _("Categorical Variables' Codings"));
1369
1370   tab_headers (t, heading_columns, 0, heading_rows, 0);
1371
1372   tab_box (t, TAL_2, TAL_2, -1, TAL_1, 0, 0, nc - 1, nr - 1);
1373
1374   tab_hline (t, TAL_2, 0, nc - 1, heading_rows);
1375   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
1376
1377
1378   tab_text (t, heading_columns, 1, TAB_CENTER | TAT_TITLE, _("Frequency"));
1379
1380   tab_joint_text_format (t, heading_columns + 1, 0, nc - 1, 0,
1381                          TAB_CENTER | TAT_TITLE, _("Parameter coding"));
1382
1383
1384   for (i = 0; i < max_df; ++i)
1385     {
1386       int c = heading_columns + 1 + i;
1387       tab_text_format (t,  c, 1, TAB_CENTER | TAT_TITLE, _("(%d)"), i + 1);
1388     }
1389
1390   cumulative_df = 0;
1391   for (v = 0; v < cmd->n_cat_predictors; ++v)
1392     {
1393       int cat;
1394       const struct interaction *cat_predictors = cmd->cat_predictors[v];
1395       int df =  categoricals_df (res->cats, v);
1396       struct string str;
1397       ds_init_empty (&str);
1398
1399       interaction_to_string (cat_predictors, &str);
1400
1401       tab_text (t, 0, heading_rows + r, TAB_LEFT | TAT_TITLE, ds_cstr (&str) );
1402
1403       ds_destroy (&str);
1404
1405       for (cat = 0; cat < categoricals_n_count (res->cats, v) ; ++cat)
1406         {
1407           struct string str;
1408           const struct ccase *c = categoricals_get_case_by_category_real (res->cats, v, cat);
1409           const double *freq = categoricals_get_user_data_by_category_real (res->cats, v, cat);
1410           
1411           int x;
1412           ds_init_empty (&str);
1413
1414           for (x = 0; x < cat_predictors->n_vars; ++x)
1415             {
1416               const union value *val = case_data (c, cat_predictors->vars[x]);
1417               var_append_value_name (cat_predictors->vars[x], val, &str);
1418
1419               if (x < cat_predictors->n_vars - 1)
1420                 ds_put_cstr (&str, " ");
1421             }
1422           
1423           tab_text   (t, 1, heading_rows + r, 0, ds_cstr (&str));
1424           ds_destroy (&str);
1425           tab_double (t, 2, heading_rows + r, 0, *freq, wfmt);
1426
1427           for (x = 0; x < df; ++x)
1428             {
1429               tab_double (t, heading_columns + 1 + x, heading_rows + r, 0, (cat == x), &F_8_0);
1430             }
1431           ++r;
1432         }
1433       cumulative_df += df;
1434     }
1435
1436   tab_submit (t);
1437
1438 }