Merge commit 'origin/covariance'
[pspp-builds.git] / src / language / stats / glm.q
index 0804945fabaa1504eff3779906b77dce6d3b9bda..1a941152d9fa334d408eb10002314c17febd4049 100644 (file)
@@ -23,7 +23,6 @@
 #include <stdlib.h>
 
 #include <data/case.h>
-#include <data/category.h>
 #include <data/casegrouper.h>
 #include <data/casereader.h>
 #include <data/dictionary.h>
@@ -39,7 +38,7 @@
 #include <libpspp/compiler.h>
 #include <libpspp/hash.h>
 #include <libpspp/message.h>
-#include <math/covariance-matrix.h>
+#include <math/covariance.h>
 #include <math/coefficient.h>
 #include <math/linreg.h>
 #include <math/moments.h>
@@ -228,18 +227,22 @@ glm_custom_dependent (struct lexer *lexer, struct dataset *ds,
                      struct cmd_glm *cmd UNUSED, void *aux UNUSED)
 {
   const struct dictionary *dict = dataset_dict (ds);
+  size_t i;
 
   if ((lex_token (lexer) != T_ID
        || dict_lookup_var (dict, lex_tokid (lexer)) == NULL)
       && lex_token (lexer) != T_ALL)
     return 2;
 
-  if (!parse_variables_const
-      (lexer, dict, &v_dependent, &n_dependent, PV_NONE))
+  if (!parse_variables_const (lexer, dict, &v_dependent, &n_dependent, PV_NONE))
     {
       free (v_dependent);
       return 0;
     }
+  for (i = 0; i < n_dependent; i++)
+    {
+      assert (var_is_numeric (v_dependent[i]));
+    }
   assert (n_dependent);
   if (n_dependent > 1)
     msg (SE, _("Multivariate GLM not yet supported"));
@@ -248,29 +251,13 @@ glm_custom_dependent (struct lexer *lexer, struct dataset *ds,
   return 1;
 }
 
-/*
-  COV is the covariance matrix for variables included in the
-  model. That means the dependent variable is in there, too.
- */
-static void
-coeff_init (pspp_linreg_cache * c, const struct design_matrix *cov)
-{
-  c->coeff = xnmalloc (cov->m->size2, sizeof (*c->coeff));
-  c->n_coeffs = cov->m->size2 - 1;
-  pspp_coeff_init (c->coeff, cov);
-}
-
-
-static pspp_linreg_cache *
-fit_model (const struct covariance_matrix *cov,
+static linreg *
+fit_model (const struct covariance *cov,
           const struct variable *dep_var, 
           const struct variable ** indep_vars, 
           size_t n_data, size_t n_indep)
 {
-  pspp_linreg_cache *result = NULL;
-  result = pspp_linreg_cache_alloc (dep_var, indep_vars, n_data, n_indep);
-  coeff_init (result, covariance_to_design (cov));
-  pspp_linreg_with_cov (cov, result);  
+  linreg *result = NULL;
   
   return result;
 }
@@ -281,17 +268,18 @@ run_glm (struct casereader *input,
         const struct dataset *ds)
 {
   casenumber row;
-  const struct variable **indep_vars;
-  const struct variable **all_vars;
+  const struct variable **numerics = NULL;
+  const struct variable **categoricals = NULL;
   int n_indep = 0;
-  pspp_linreg_cache *model = NULL; 
+  linreg *model = NULL; 
   pspp_linreg_opts lopts;
   struct ccase *c;
   size_t i;
-  size_t n_all_vars;
   size_t n_data;               /* Number of valid cases. */
+  size_t n_categoricals = 0;
+  size_t n_numerics;
   struct casereader *reader;
-  struct covariance_matrix *cov;
+  struct covariance *cov;
 
   c = casereader_peek (input, 0);
   if (c == NULL)
@@ -311,50 +299,93 @@ run_glm (struct casereader *input,
   lopts.get_depvar_mean_std = 1;
 
   lopts.get_indep_mean_std = xnmalloc (n_dependent, sizeof (int));
-  indep_vars = xnmalloc (cmd->n_by, sizeof *indep_vars);
-  n_all_vars = cmd->n_by + n_dependent;
-  all_vars = xnmalloc (n_all_vars, sizeof *all_vars);
 
-  for (i = 0; i < n_dependent; i++)
+
+  n_numerics = n_dependent;
+  for (i = 0; i < cmd->n_with; i++)
+    {
+      if (var_is_alpha (cmd->v_with[i]))
+       {
+         n_categoricals++;
+       }
+      else
+       {
+         n_numerics++;
+       }
+    }
+  for (i = 0; i < cmd->n_by; i++)
     {
-      all_vars[i] = v_dependent[i];
+      if (var_is_alpha (cmd->v_by[i]))
+       {
+         n_categoricals++;
+       }
+      else
+       {
+         n_numerics++;
+       }
     }
+  numerics = xnmalloc (n_numerics, sizeof *numerics);
+  categoricals = xnmalloc (n_categoricals, sizeof (*categoricals));
+  size_t j = 0;
+  size_t k = 0;
   for (i = 0; i < cmd->n_by; i++)
     {
-      indep_vars[i] = cmd->v_by[i];
-      all_vars[i + n_dependent] = cmd->v_by[i];
+      if (var_is_alpha (cmd->v_by[i]))
+       {
+         categoricals[j] = cmd->v_by[i];
+         j++;
+       }
+      else
+       {
+         numerics[k] = cmd->v_by[i];
+         k++;
+       }
+    }
+  for (i = 0; i < cmd->n_with; i++)
+    {
+      if (var_is_alpha (cmd->v_with[i]))
+       {
+         categoricals[j] = cmd->v_with[i];
+         j++;
+       }
+      else
+       {
+         numerics[k] = cmd->v_with[i];
+         k++;
+       }
+    }
+  for (i = 0; i < n_dependent; i++)
+    {
+      numerics[k] = v_dependent[i];
+      k++;
     }
-  n_indep = cmd->n_by;
+
+  cov = covariance_2pass_create (n_numerics, numerics, n_categoricals, categoricals, NULL, MV_NEVER);
 
   reader = casereader_clone (input);
-  reader = casereader_create_filter_missing (reader, indep_vars, n_indep,
+  reader = casereader_create_filter_missing (reader, numerics, n_numerics,
                                             MV_ANY, NULL, NULL);
-  reader = casereader_create_filter_missing (reader, v_dependent, 1,
+  reader = casereader_create_filter_missing (reader, categoricals, n_categoricals,
                                             MV_ANY, NULL, NULL);
+  struct casereader *r = casereader_clone (reader);
 
-  if (n_indep > 0)
+  reader = casereader_create_counter (reader, &row, -1);
+  
+  for (; (c = casereader_read (reader)) != NULL; case_unref (c))
     {
-      for (i = 0; i < n_all_vars; i++)
-       if (var_is_alpha (all_vars[i]))
-         cat_stored_values_create (all_vars[i]);
-      
-      reader = casereader_create_counter (reader, &row, -1);
-
-      for (i = 0; i < n_inter; i++)
-      for (; (c = casereader_read (reader)) != NULL; case_unref (c))
-       {
-         /* 
-            Accumulate the covariance matrix.
-         */
-         n_data++;
-       }
-      casereader_destroy (reader);
+      covariance_accumulate_pass1 (cov, c);
     }
-  else
+  for (; (c = casereader_read (r)) != NULL; case_unref (c))
     {
-      msg (SE, gettext ("No valid data found. This command was skipped."));
+      covariance_accumulate_pass2 (cov, c);
     }
-  free (indep_vars);
+
+  covariance_destroy (cov);
+  casereader_destroy (reader);
+  casereader_destroy (r);
+  
+  free (numerics);
+  free (categoricals);
   free (lopts.get_indep_mean_std);
   casereader_destroy (input);