New module src/math/correlation
[pspp-builds.git] / src / language / stats / correlations.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <libpspp/assertion.h>
20 #include <math/covariance.h>
21 #include <math/correlation.h>
22 #include <math/design-matrix.h>
23 #include <gsl/gsl_matrix.h>
24 #include <data/casegrouper.h>
25 #include <data/casereader.h>
26 #include <data/dictionary.h>
27 #include <data/procedure.h>
28 #include <data/variable.h>
29 #include <language/command.h>
30 #include <language/dictionary/split-file.h>
31 #include <language/lexer/lexer.h>
32 #include <language/lexer/variable-parser.h>
33 #include <output/manager.h>
34 #include <output/table.h>
35 #include <libpspp/message.h>
36 #include <data/format.h>
37 #include <math/moments.h>
38
39 #include <math.h>
40 #include "xalloc.h"
41 #include "minmax.h"
42 #include <libpspp/misc.h>
43 #include <gsl/gsl_cdf.h>
44
45 #include "gettext.h"
46 #define _(msgid) gettext (msgid)
47 #define N_(msgid) msgid
48
49
50 struct corr
51 {
52   size_t n_vars_total;
53   size_t n_vars1;
54
55   const struct variable **vars;
56 };
57
58
59 /* Handling of missing values. */
60 enum corr_missing_type
61   {
62     CORR_PAIRWISE,       /* Handle missing values on a per-variable-pair basis. */
63     CORR_LISTWISE        /* Discard entire case if any variable is missing. */
64   };
65
66 enum stats_opts
67   {
68     STATS_DESCRIPTIVES = 0x01,
69     STATS_XPROD = 0x02,
70     STATS_ALL = STATS_XPROD | STATS_DESCRIPTIVES
71   };
72
73 struct corr_opts
74 {
75   enum corr_missing_type missing_type;
76   enum mv_class exclude;      /* Classes of missing values to exclude. */
77
78   bool sig;   /* Flag significant values or not */
79   int tails;  /* Report significance with how many tails ? */
80   enum stats_opts statistics;
81
82   const struct variable *wv;  /* The weight variable (if any) */
83 };
84
85
86 static void
87 output_descriptives (const struct corr *corr, const gsl_matrix *means,
88                      const gsl_matrix *vars, const gsl_matrix *ns)
89 {
90   const int nr = corr->n_vars_total + 1;
91   const int nc = 4;
92   int c, r;
93
94   const int heading_columns = 1;
95   const int heading_rows = 1;
96
97   struct tab_table *t = tab_create (nc, nr, 0);
98   tab_title (t, _("Descriptive Statistics"));
99   tab_dim (t, tab_natural_dimensions, NULL);
100
101   tab_headers (t, heading_columns, 0, heading_rows, 0);
102
103   /* Outline the box */
104   tab_box (t,
105            TAL_2, TAL_2,
106            -1, -1,
107            0, 0,
108            nc - 1, nr - 1);
109
110   /* Vertical lines */
111   tab_box (t,
112            -1, -1,
113            -1, TAL_1,
114            heading_columns, 0,
115            nc - 1, nr - 1);
116
117   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
118   tab_hline (t, TAL_1, 0, nc - 1, heading_rows);
119
120   tab_text (t, 1, 0, TAB_CENTER | TAT_TITLE, _("Mean"));
121   tab_text (t, 2, 0, TAB_CENTER | TAT_TITLE, _("Std. Deviation"));
122   tab_text (t, 3, 0, TAB_CENTER | TAT_TITLE, _("N"));
123
124   for (r = 0 ; r < corr->n_vars_total ; ++r)
125     {
126       const struct variable *v = corr->vars[r];
127       tab_text (t, 0, r + heading_rows, TAB_LEFT | TAT_TITLE, var_to_string (v));
128
129       for (c = 1 ; c < nc ; ++c)
130         {
131           double x ;
132           double n;
133           switch (c)
134             {
135             case 1:
136               x = gsl_matrix_get (means, r, 0);
137               break;
138             case 2:
139               x = gsl_matrix_get (vars, r, 0);
140
141               /* Here we want to display the non-biased estimator */
142               n = gsl_matrix_get (ns, r, 0);
143               x *= n / (n -1);
144
145               x = sqrt (x);
146               break;
147             case 3:
148               x = gsl_matrix_get (ns, r, 0);
149               break;
150             default: 
151               NOT_REACHED ();
152             };
153           
154           tab_double (t, c, r + heading_rows, 0, x, NULL);
155         }
156     }
157
158   tab_submit (t);
159 }
160
161 static void
162 output_correlation (const struct corr *corr, const struct corr_opts *opts,
163                     const gsl_matrix *cm, const gsl_matrix *samples,
164                     const gsl_matrix *cv)
165 {
166   int r, c;
167   struct tab_table *t;
168   int matrix_cols;
169   int nr = corr->n_vars1;
170   int nc = matrix_cols = corr->n_vars_total > corr->n_vars1 ?
171     corr->n_vars_total - corr->n_vars1 : corr->n_vars1;
172
173   const struct fmt_spec *wfmt = opts->wv ? var_get_print_format (opts->wv) : & F_8_0;
174
175   const int heading_columns = 2;
176   const int heading_rows = 1;
177
178   int rows_per_variable = opts->missing_type == CORR_LISTWISE ? 2 : 3;
179
180   if (opts->statistics & STATS_XPROD)
181     rows_per_variable += 2;
182
183   /* Two header columns */
184   nc += heading_columns;
185
186   /* Three data per variable */
187   nr *= rows_per_variable;
188
189   /* One header row */
190   nr += heading_rows;
191
192   t = tab_create (nc, nr, 0);
193   tab_title (t, _("Correlations"));
194   tab_dim (t, tab_natural_dimensions, NULL);
195
196   tab_headers (t, heading_columns, 0, heading_rows, 0);
197
198   /* Outline the box */
199   tab_box (t,
200            TAL_2, TAL_2,
201            -1, -1,
202            0, 0,
203            nc - 1, nr - 1);
204
205   /* Vertical lines */
206   tab_box (t,
207            -1, -1,
208            -1, TAL_1,
209            heading_columns, 0,
210            nc - 1, nr - 1);
211
212   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
213   tab_vline (t, TAL_1, 1, heading_rows, nr - 1);
214
215   for (r = 0 ; r < corr->n_vars1 ; ++r)
216     {
217       tab_text (t, 0, 1 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, 
218                 var_to_string (corr->vars[r]));
219
220       tab_text (t, 1, 1 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, _("Pearson Correlation"));
221       tab_text (t, 1, 2 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, 
222                 (opts->tails == 2) ? _("Sig. (2-tailed)") : _("Sig. (1-tailed)"));
223
224       if (opts->statistics & STATS_XPROD)
225         {
226           tab_text (t, 1, 3 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, _("Cross-products"));
227           tab_text (t, 1, 4 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, _("Covariance"));
228         }
229
230       if ( opts->missing_type != CORR_LISTWISE )
231         tab_text (t, 1, rows_per_variable + r * rows_per_variable, TAB_LEFT | TAT_TITLE, _("N"));
232
233       tab_hline (t, TAL_1, 0, nc - 1, r * rows_per_variable + 1);
234     }
235
236   for (c = 0 ; c < matrix_cols ; ++c)
237     {
238       const struct variable *v = corr->n_vars_total > corr->n_vars1 ? corr->vars[corr->n_vars_total - corr->n_vars1 + c] : corr->vars[c];
239       tab_text (t, heading_columns + c, 0, TAB_LEFT | TAT_TITLE, var_to_string (v));      
240     }
241
242   for (r = 0 ; r < corr->n_vars1 ; ++r)
243     {
244       const int row = r * rows_per_variable + heading_rows;
245       for (c = 0 ; c < matrix_cols ; ++c)
246         {
247           unsigned char flags = 0; 
248           const int col_index = corr->n_vars_total - corr->n_vars1 + c;
249           double pearson = gsl_matrix_get (cm, r, col_index);
250           double w = gsl_matrix_get (samples, r, col_index);
251           double sig = opts->tails * significance_of_correlation (pearson, w);
252
253           if ( opts->missing_type != CORR_LISTWISE )
254             tab_double (t, c + heading_columns, row + rows_per_variable - 1, 0, w, wfmt);
255
256           if ( c != r)
257             tab_double (t, c + heading_columns, row + 1, 0,  sig, NULL);
258
259           if ( opts->sig && c != r && sig < 0.05)
260             flags = TAB_EMPH;
261           
262           tab_double (t, c + heading_columns, row, flags, pearson, NULL);
263
264           if (opts->statistics & STATS_XPROD)
265             {
266               double cov = gsl_matrix_get (cv, r, col_index);
267               const double xprod_dev = cov * w;
268               cov *= w / (w - 1.0);
269
270               tab_double (t, c + heading_columns, row + 2, 0, xprod_dev, NULL);
271               tab_double (t, c + heading_columns, row + 3, 0, cov, NULL);
272             }
273         }
274     }
275
276   tab_submit (t);
277 }
278
279
280 static void
281 run_corr (struct casereader *r, const struct corr_opts *opts, const struct corr *corr)
282 {
283   struct ccase *c;
284   const gsl_matrix *var_matrix,  *samples_matrix, *mean_matrix;
285   const gsl_matrix *cov_matrix;
286   gsl_matrix *corr_matrix;
287   struct covariance *cov = covariance_create (corr->n_vars_total, corr->vars,
288                                               opts->wv, opts->exclude);
289
290   for ( ; (c = casereader_read (r) ); case_unref (c))
291     {
292       covariance_accumulate (cov, c);
293     }
294
295   cov_matrix = covariance_calculate (cov);
296
297   samples_matrix = covariance_moments (cov, MOMENT_NONE);
298   var_matrix = covariance_moments (cov, MOMENT_VARIANCE);
299   mean_matrix = covariance_moments (cov, MOMENT_MEAN);
300
301   corr_matrix = correlation_from_covariance (cov_matrix, var_matrix);
302
303   if ( opts->statistics & STATS_DESCRIPTIVES) 
304     output_descriptives (corr, mean_matrix, var_matrix, samples_matrix);
305
306   output_correlation (corr, opts,
307                       corr_matrix,
308                       samples_matrix,
309                       cov_matrix);
310
311   covariance_destroy (cov);
312   gsl_matrix_free (corr_matrix);
313 }
314
315 int
316 cmd_correlation (struct lexer *lexer, struct dataset *ds)
317 {
318   int i;
319   int n_all_vars = 0; /* Total number of variables involved in this command */
320   const struct variable **all_vars ;
321   const struct dictionary *dict = dataset_dict (ds);
322   bool ok = true;
323
324   struct casegrouper *grouper;
325   struct casereader *group;
326
327   struct corr *corr = NULL;
328   size_t n_corrs = 0;
329
330   struct corr_opts opts;
331   opts.missing_type = CORR_PAIRWISE;
332   opts.wv = dict_get_weight (dict);
333   opts.tails = 2;
334   opts.sig = false;
335   opts.exclude = MV_ANY;
336   opts.statistics = 0;
337
338   /* Parse CORRELATIONS. */
339   while (lex_token (lexer) != '.')
340     {
341       lex_match (lexer, '/');
342       if (lex_match_id (lexer, "MISSING"))
343         {
344           lex_match (lexer, '=');
345           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
346             {
347               if (lex_match_id (lexer, "PAIRWISE"))
348                 opts.missing_type = CORR_PAIRWISE;
349               else if (lex_match_id (lexer, "LISTWISE"))
350                 opts.missing_type = CORR_LISTWISE;
351
352               else if (lex_match_id (lexer, "INCLUDE"))
353                 opts.exclude = MV_SYSTEM;
354               else if (lex_match_id (lexer, "EXCLUDE"))
355                 opts.exclude = MV_ANY;
356               else
357                 {
358                   lex_error (lexer, NULL);
359                   goto error;
360                 }
361               lex_match (lexer, ',');
362             }
363         }
364       else if (lex_match_id (lexer, "PRINT"))
365         {
366           lex_match (lexer, '=');
367           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
368             {
369               if ( lex_match_id (lexer, "TWOTAIL"))
370                 opts.tails = 2;
371               else if (lex_match_id (lexer, "ONETAIL"))
372                 opts.tails = 1;
373               else if (lex_match_id (lexer, "SIG"))
374                 opts.sig = false;
375               else if (lex_match_id (lexer, "NOSIG"))
376                 opts.sig = true;
377               else
378                 {
379                   lex_error (lexer, NULL);
380                   goto error;
381                 }
382
383               lex_match (lexer, ',');
384             }
385         }
386       else if (lex_match_id (lexer, "STATISTICS"))
387         {
388           lex_match (lexer, '=');
389           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
390             {
391               if ( lex_match_id (lexer, "DESCRIPTIVES"))
392                 opts.statistics = STATS_DESCRIPTIVES;
393               else if (lex_match_id (lexer, "XPROD"))
394                 opts.statistics = STATS_XPROD;
395               else if (lex_token (lexer) == T_ALL)
396                 {
397                   opts.statistics = STATS_ALL;
398                   lex_get (lexer);
399                 }
400               else 
401                 {
402                   lex_error (lexer, NULL);
403                   goto error;
404                 }
405
406               lex_match (lexer, ',');
407             }
408         }
409       else
410         {
411           if (lex_match_id (lexer, "VARIABLES"))
412             {
413               lex_match (lexer, '=');
414             }
415
416           corr = xrealloc (corr, sizeof (*corr) * (n_corrs + 1));
417           corr[n_corrs].n_vars_total = corr[n_corrs].n_vars1 = 0;
418       
419           if ( ! parse_variables_const (lexer, dict, &corr[n_corrs].vars, 
420                                         &corr[n_corrs].n_vars_total,
421                                         PV_NUMERIC))
422             {
423               ok = false;
424               break;
425             }
426
427
428           corr[n_corrs].n_vars1 = corr[n_corrs].n_vars_total;
429
430           if ( lex_match (lexer, T_WITH))
431             {
432               if ( ! parse_variables_const (lexer, dict,
433                                             &corr[n_corrs].vars, &corr[n_corrs].n_vars_total,
434                                             PV_NUMERIC | PV_APPEND))
435                 {
436                   ok = false;
437                   break;
438                 }
439             }
440
441           n_all_vars += corr[n_corrs].n_vars_total;
442
443           n_corrs++;
444         }
445     }
446
447   if (n_corrs == 0)
448     {
449       msg (SE, _("No variables specified."));
450       goto error;
451     }
452
453
454   all_vars = xmalloc (sizeof (*all_vars) * n_all_vars);
455
456   {
457     /* FIXME:  Using a hash here would make more sense */
458     const struct variable **vv = all_vars;
459
460     for (i = 0 ; i < n_corrs; ++i)
461       {
462         int v;
463         const struct corr *c = &corr[i];
464         for (v = 0 ; v < c->n_vars_total; ++v)
465           *vv++ = c->vars[v];
466       }
467   }
468
469   grouper = casegrouper_create_splits (proc_open (ds), dict);
470
471   while (casegrouper_get_next_group (grouper, &group))
472     {
473       for (i = 0 ; i < n_corrs; ++i)
474         {
475           /* FIXME: No need to iterate the data multiple times */
476           struct casereader *r = casereader_clone (group);
477
478           if ( opts.missing_type == CORR_LISTWISE)
479             r = casereader_create_filter_missing (r, all_vars, n_all_vars,
480                                                   opts.exclude, NULL, NULL);
481
482
483           run_corr (r, &opts,  &corr[i]);
484           casereader_destroy (r);
485         }
486       casereader_destroy (group);
487     }
488
489   ok = casegrouper_destroy (grouper);
490   ok = proc_commit (ds) && ok;
491
492   free (all_vars);
493
494
495   /* Done. */
496   free (corr);
497   return ok ? CMD_SUCCESS : CMD_CASCADING_FAILURE;
498
499  error:
500   free (corr);
501   return CMD_FAILURE;
502 }