First working version of CORRELATIONS.
[pspp-builds.git] / src / language / stats / correlations.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <math/covariance.h>
20 #include <math/design-matrix.h>
21 #include <gsl/gsl_matrix.h>
22 #include <data/casegrouper.h>
23 #include <data/casereader.h>
24 #include <data/dictionary.h>
25 #include <data/procedure.h>
26 #include <data/variable.h>
27 #include <language/command.h>
28 #include <language/dictionary/split-file.h>
29 #include <language/lexer/lexer.h>
30 #include <language/lexer/variable-parser.h>
31 #include <output/manager.h>
32 #include <output/table.h>
33 #include <libpspp/message.h>
34 #include <data/format.h>
35 #include <math/moments.h>
36
37 #include <math.h>
38 #include "xalloc.h"
39 #include "minmax.h"
40 #include <libpspp/misc.h>
41 #include <gsl/gsl_cdf.h>
42
43 #include "gettext.h"
44 #define _(msgid) gettext (msgid)
45 #define N_(msgid) msgid
46
47
48 static double
49 significance_of_correlation (double rho, double w)
50 {
51   double t = w - 2;
52   t /= 1 - MIN (1, pow2 (rho));
53   t = sqrt (t);
54   t *= rho;
55   
56   if (t > 0)
57     return  gsl_cdf_tdist_Q (t, w - 2);
58   else
59     return  gsl_cdf_tdist_P (t, w - 2);
60 }
61
62
63 struct corr
64 {
65   size_t n_vars_total;
66   size_t n_vars1;
67
68   const struct variable **vars;
69 };
70
71
72 /* Handling of missing values. */
73 enum corr_missing_type
74   {
75     CORR_PAIRWISE,       /* Handle missing values on a per-variable-pair basis. */
76     CORR_LISTWISE        /* Discard entire case if any variable is missing. */
77   };
78
79 struct corr_opts
80 {
81   enum corr_missing_type missing_type;
82   enum mv_class exclude;      /* Classes of missing values to exclude. */
83
84   bool sig;   /* Flag significant values or not */
85   int tails;  /* Report significance with how many tails ? */
86
87   const struct variable *wv;  /* The weight variable (if any) */
88 };
89
90
91 static void
92 output_correlation (const struct corr *corr, const struct corr_opts *opts,
93                     const gsl_matrix *cm, const gsl_matrix *samples)
94 {
95   int r, c;
96   struct tab_table *t;
97   int matrix_cols;
98   int nr = corr->n_vars1;
99   int nc = matrix_cols = corr->n_vars_total > corr->n_vars1 ?
100     corr->n_vars_total - corr->n_vars1 : corr->n_vars1;
101
102   const struct fmt_spec *wfmt = opts->wv ? var_get_print_format (opts->wv) : & F_8_0;
103
104   const int heading_columns = 2;
105   const int heading_rows = 1;
106
107   const int rows_per_variable = opts->missing_type == CORR_LISTWISE ? 2 : 3;
108
109   /* Two header columns */
110   nc += heading_columns;
111
112   /* Three data per variable */
113   nr *= rows_per_variable;
114
115   /* One header row */
116   nr += heading_rows;
117
118   t = tab_create (nc, nr, 0);
119   tab_title (t, _("Correlations"));
120   tab_dim (t, tab_natural_dimensions, NULL);
121
122   tab_headers (t, heading_columns, 0, heading_rows, 0);
123
124   /* Outline the box */
125   tab_box (t,
126            TAL_2, TAL_2,
127            -1, -1,
128            0, 0,
129            nc - 1, nr - 1);
130
131   /* Vertical lines */
132   tab_box (t,
133            -1, -1,
134            -1, TAL_1,
135            heading_columns, 0,
136            nc - 1, nr - 1);
137
138   tab_vline (t, TAL_2, heading_columns, 0, nr - 1);
139   tab_vline (t, TAL_1, 1, heading_rows, nr - 1);
140
141   for (r = 0 ; r < corr->n_vars1 ; ++r)
142     {
143       tab_text (t, 0, 1 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, 
144                 var_to_string (corr->vars[r]));
145
146       tab_text (t, 1, 1 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, _("Pearson Correlation"));
147       tab_text (t, 1, 2 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, 
148                 (opts->tails == 2) ? _("Sig. (2-tailed)") : _("Sig. (1-tailed)"));
149       if ( opts->missing_type != CORR_LISTWISE )
150         tab_text (t, 1, 3 + r * rows_per_variable, TAB_LEFT | TAT_TITLE, _("N"));
151       tab_hline (t, TAL_1, 0, nc - 1, r * rows_per_variable + 1);
152     }
153
154   for (c = 0 ; c < matrix_cols ; ++c)
155     {
156       const struct variable *v = corr->n_vars_total > corr->n_vars1 ? corr->vars[corr->n_vars_total - corr->n_vars1 + c] : corr->vars[c];
157       tab_text (t, heading_columns + c, 0, TAB_LEFT | TAT_TITLE, var_to_string (v));      
158     }
159
160   for (r = 0 ; r < corr->n_vars1 ; ++r)
161     {
162       const int row = r * rows_per_variable + heading_rows;
163       for (c = 0 ; c < matrix_cols ; ++c)
164         {
165           unsigned char flags = 0; 
166           int col_index = corr->n_vars_total - corr->n_vars1 + c;
167           double pearson = gsl_matrix_get (cm, r, col_index);
168           double w = gsl_matrix_get (samples, r, col_index);
169           double sig = opts->tails * significance_of_correlation (pearson, w);
170
171           if ( opts->missing_type != CORR_LISTWISE )
172             tab_double (t, c + heading_columns, row + 2, 0, w, wfmt);
173
174           if ( c != r)
175             tab_double (t, c + heading_columns, row + 1, 0,  sig, NULL);
176
177           if ( opts->sig && c != r && sig < 0.05)
178             flags = TAB_EMPH;
179           
180           tab_double (t, c + heading_columns, row, flags, pearson, NULL);
181         }
182     }
183
184   tab_submit (t);
185 }
186
187
188 static gsl_matrix *
189 correlation_from_covariance (const gsl_matrix *cv, const gsl_matrix *v)
190 {
191   size_t i, j;
192   gsl_matrix *corr = gsl_matrix_calloc (cv->size1, cv->size2);
193   
194   for (i = 0 ; i < cv->size1; ++i)
195     {
196       for (j = 0 ; j < cv->size2; ++j)
197         {
198           double rho = gsl_matrix_get (cv, i, j);
199           
200           rho /= sqrt (gsl_matrix_get (v, i, j))
201             * 
202             sqrt (gsl_matrix_get (v, j, i));
203           
204           gsl_matrix_set (corr, i, j, rho);
205         }
206     }
207   
208   return corr;
209 }
210
211
212
213
214 static void
215 run_corr (struct casereader *r, const struct corr_opts *opts, const struct corr *corr)
216 {
217   struct ccase *c;
218   const gsl_matrix *var_matrix;
219   const gsl_matrix *samples_matrix;
220   const gsl_matrix *cov_matrix;
221   gsl_matrix *corr_matrix;
222   struct covariance *cov = covariance_create (corr->n_vars_total, corr->vars,
223                                               opts->wv, opts->exclude);
224
225   for ( ; (c = casereader_read (r) ); case_unref (c))
226     {
227       covariance_accumulate (cov, c);
228     }
229
230   cov_matrix = covariance_calculate (cov);
231
232   samples_matrix = covariance_moments (cov, MOMENT_NONE);
233   var_matrix = covariance_moments (cov, MOMENT_VARIANCE);
234
235   corr_matrix = correlation_from_covariance (cov_matrix, var_matrix);
236
237   output_correlation (corr, opts,
238                       corr_matrix,
239                       samples_matrix );
240
241   covariance_destroy (cov);
242   gsl_matrix_free (corr_matrix);
243 }
244
245 int
246 cmd_correlation (struct lexer *lexer, struct dataset *ds)
247 {
248   int i;
249   int n_all_vars = 0; /* Total number of variables involved in this command */
250   const struct variable **all_vars ;
251   const struct dictionary *dict = dataset_dict (ds);
252   bool ok = true;
253
254   struct casegrouper *grouper;
255   struct casereader *group;
256
257   struct corr *corr = NULL;
258   size_t n_corrs = 0;
259
260   struct corr_opts opts;
261   opts.missing_type = CORR_PAIRWISE;
262   opts.wv = dict_get_weight (dict);
263   opts.tails = 2;
264   opts.sig = false;
265   opts.exclude = MV_ANY;
266
267   /* Parse CORRELATIONS. */
268   while (lex_token (lexer) != '.')
269     {
270       lex_match (lexer, '/');
271       if (lex_match_id (lexer, "MISSING"))
272         {
273           lex_match (lexer, '=');
274           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
275             {
276               if (lex_match_id (lexer, "PAIRWISE"))
277                 opts.missing_type = CORR_PAIRWISE;
278               else if (lex_match_id (lexer, "LISTWISE"))
279                 opts.missing_type = CORR_LISTWISE;
280
281               else if (lex_match_id (lexer, "INCLUDE"))
282                 opts.exclude = MV_SYSTEM;
283               else if (lex_match_id (lexer, "EXCLUDE"))
284                 opts.exclude = MV_ANY;
285               else
286                 {
287                   lex_error (lexer, NULL);
288                   goto error;
289                 }
290               lex_match (lexer, ',');
291             }
292         }
293       else if (lex_match_id (lexer, "PRINT"))
294         {
295           lex_match (lexer, '=');
296           while (lex_token (lexer) != '.' && lex_token (lexer) != '/')
297             {
298               if ( lex_match_id (lexer, "TWOTAIL"))
299                 opts.tails = 2;
300               else if (lex_match_id (lexer, "ONETAIL"))
301                 opts.tails = 1;
302               else if (lex_match_id (lexer, "SIG"))
303                 opts.sig = false;
304               else if (lex_match_id (lexer, "NOSIG"))
305                 opts.sig = true;
306               else
307                 {
308                   lex_error (lexer, NULL);
309                   goto error;
310                 }
311
312               lex_match (lexer, ',');
313             }
314         }
315       else
316         {
317           if (lex_match_id (lexer, "VARIABLES"))
318             {
319               lex_match (lexer, '=');
320             }
321
322           corr = xrealloc (corr, sizeof (*corr) * (n_corrs + 1));
323           corr[n_corrs].n_vars_total = corr[n_corrs].n_vars1 = 0;
324       
325           if ( ! parse_variables_const (lexer, dict, &corr[n_corrs].vars, 
326                                         &corr[n_corrs].n_vars_total,
327                                         PV_NUMERIC))
328             {
329               ok = false;
330               break;
331             }
332
333
334           corr[n_corrs].n_vars1 = corr[n_corrs].n_vars_total;
335
336           if ( lex_match (lexer, T_WITH))
337             {
338               if ( ! parse_variables_const (lexer, dict,
339                                             &corr[n_corrs].vars, &corr[n_corrs].n_vars_total,
340                                             PV_NUMERIC | PV_APPEND))
341                 {
342                   ok = false;
343                   break;
344                 }
345             }
346
347           n_all_vars += corr[n_corrs].n_vars_total;
348
349           n_corrs++;
350         }
351     }
352
353   if (n_corrs == 0)
354     {
355       msg (SE, _("No variables specified."));
356       goto error;
357     }
358
359
360   all_vars = xmalloc (sizeof (*all_vars) * n_all_vars);
361
362   {
363     /* FIXME:  Using a hash here would make more sense */
364     const struct variable **vv = all_vars;
365
366     for (i = 0 ; i < n_corrs; ++i)
367       {
368         int v;
369         const struct corr *c = &corr[i];
370         for (v = 0 ; v < c->n_vars_total; ++v)
371           *vv++ = c->vars[v];
372       }
373   }
374
375   grouper = casegrouper_create_splits (proc_open (ds), dict);
376
377   while (casegrouper_get_next_group (grouper, &group))
378     {
379       for (i = 0 ; i < n_corrs; ++i)
380         {
381           /* FIXME: No need to iterate the data multiple times */
382           struct casereader *r = casereader_clone (group);
383
384           if ( opts.missing_type == CORR_LISTWISE)
385             r = casereader_create_filter_missing (r, all_vars, n_all_vars,
386                                                   opts.exclude, NULL, NULL);
387
388
389           run_corr (r, &opts,  &corr[i]);
390           casereader_destroy (r);
391         }
392       casereader_destroy (group);
393     }
394
395   ok = casegrouper_destroy (grouper);
396   ok = proc_commit (ds) && ok;
397
398   free (all_vars);
399
400
401   /* Done. */
402   free (corr);
403   return ok ? CMD_SUCCESS : CMD_CASCADING_FAILURE;
404
405  error:
406   free (corr);
407   return CMD_FAILURE;
408 }