output: Introduce pivot tables.
[pspp] / src / language / stats / correlations.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <gsl/gsl_cdf.h>
20 #include <gsl/gsl_matrix.h>
21 #include <math.h>
22
23 #include "data/casegrouper.h"
24 #include "data/casereader.h"
25 #include "data/dataset.h"
26 #include "data/dictionary.h"
27 #include "data/format.h"
28 #include "data/variable.h"
29 #include "language/command.h"
30 #include "language/dictionary/split-file.h"
31 #include "language/lexer/lexer.h"
32 #include "language/lexer/variable-parser.h"
33 #include "libpspp/assertion.h"
34 #include "libpspp/message.h"
35 #include "libpspp/misc.h"
36 #include "math/correlation.h"
37 #include "math/covariance.h"
38 #include "math/moments.h"
39 #include "output/pivot-table.h"
40
41 #include "gl/xalloc.h"
42 #include "gl/minmax.h"
43
44 #include "gettext.h"
45 #define _(msgid) gettext (msgid)
46 #define N_(msgid) msgid
47
48
49 struct corr
50 {
51   size_t n_vars_total;
52   size_t n_vars1;
53
54   const struct variable **vars;
55 };
56
57
58 /* Handling of missing values. */
59 enum corr_missing_type
60   {
61     CORR_PAIRWISE,       /* Handle missing values on a per-variable-pair basis. */
62     CORR_LISTWISE        /* Discard entire case if any variable is missing. */
63   };
64
65 enum stats_opts
66   {
67     STATS_DESCRIPTIVES = 0x01,
68     STATS_XPROD = 0x02,
69     STATS_ALL = STATS_XPROD | STATS_DESCRIPTIVES
70   };
71
72 struct corr_opts
73 {
74   enum corr_missing_type missing_type;
75   enum mv_class exclude;      /* Classes of missing values to exclude. */
76
77   bool sig;   /* Flag significant values or not */
78   int tails;  /* Report significance with how many tails ? */
79   enum stats_opts statistics;
80
81   const struct variable *wv;  /* The weight variable (if any) */
82 };
83
84
85 static void
86 output_descriptives (const struct corr *corr, const struct corr_opts *opts,
87                      const gsl_matrix *means,
88                      const gsl_matrix *vars, const gsl_matrix *ns)
89 {
90   struct pivot_table *table = pivot_table_create (
91     N_("Descriptive Statistics"));
92   pivot_table_set_weight_var (table, opts->wv);
93
94   pivot_dimension_create (table, PIVOT_AXIS_COLUMN, N_("Statistics"),
95                           N_("Mean"), PIVOT_RC_OTHER,
96                           N_("Std. Deviation"), PIVOT_RC_OTHER,
97                           N_("N"), PIVOT_RC_COUNT);
98
99   struct pivot_dimension *variables = pivot_dimension_create (
100     table, PIVOT_AXIS_ROW, N_("Variable"));
101
102   for (size_t r = 0 ; r < corr->n_vars_total ; ++r)
103     {
104       const struct variable *v = corr->vars[r];
105
106       int row = pivot_category_create_leaf (variables->root,
107                                             pivot_value_new_variable (v));
108
109       double mean = gsl_matrix_get (means, r, 0);
110       /* Here we want to display the non-biased estimator */
111       double n = gsl_matrix_get (ns, r, 0);
112       double stddev = sqrt (gsl_matrix_get (vars, r, 0) * n / (n - 1));
113       double entries[] = { mean, stddev, n };
114       for (size_t i = 0; i < sizeof entries / sizeof *entries; i++)
115         pivot_table_put2 (table, i, row, pivot_value_new_number (entries[i]));
116     }
117
118   pivot_table_submit (table);
119 }
120
121 static void
122 output_correlation (const struct corr *corr, const struct corr_opts *opts,
123                     const gsl_matrix *cm, const gsl_matrix *samples,
124                     const gsl_matrix *cv)
125 {
126   struct pivot_table *table = pivot_table_create (N_("Correlations"));
127   pivot_table_set_weight_var (table, opts->wv);
128
129   /* Column variable dimension. */
130   struct pivot_dimension *columns = pivot_dimension_create (
131     table, PIVOT_AXIS_COLUMN, N_("Variables"));
132
133   int matrix_cols = (corr->n_vars_total > corr->n_vars1
134                      ? corr->n_vars_total - corr->n_vars1
135                      : corr->n_vars1);
136   for (int c = 0; c < matrix_cols; c++)
137     {
138       const struct variable *v = corr->n_vars_total > corr->n_vars1 ?
139         corr->vars[corr->n_vars1 + c] : corr->vars[c];
140       pivot_category_create_leaf (columns->root, pivot_value_new_variable (v));
141     }
142
143   /* Statistics dimension. */
144   struct pivot_dimension *statistics = pivot_dimension_create (
145     table, PIVOT_AXIS_ROW, N_("Statistics"),
146     N_("Pearson Correlation"), PIVOT_RC_CORRELATION,
147     opts->tails == 2 ? N_("Sig. (2-tailed)") : N_("Sig. (1-tailed)"),
148     PIVOT_RC_SIGNIFICANCE);
149
150   if (opts->statistics & STATS_XPROD)
151     pivot_category_create_leaves (statistics->root, N_("Cross-products"),
152                                   N_("Covariance"));
153
154   if (opts->missing_type != CORR_LISTWISE)
155     pivot_category_create_leaves (statistics->root, N_("N"), PIVOT_RC_COUNT);
156
157   /* Row variable dimension. */
158   struct pivot_dimension *rows = pivot_dimension_create (
159     table, PIVOT_AXIS_ROW, N_("Variables"));
160   for (size_t r = 0; r < corr->n_vars1; r++)
161     pivot_category_create_leaf (rows->root,
162                                 pivot_value_new_variable (corr->vars[r]));
163
164   struct pivot_footnote *sig_footnote = pivot_table_create_footnote (
165     table, pivot_value_new_text (N_("Significant at .05 level")));
166
167   for (int r = 0; r < corr->n_vars1; r++)
168     for (int c = 0; c < matrix_cols; c++)
169       {
170         const int col_index = (corr->n_vars_total > corr->n_vars1
171                                ? corr->n_vars1 + c
172                                : c);
173         double pearson = gsl_matrix_get (cm, r, col_index);
174         double w = gsl_matrix_get (samples, r, col_index);
175         double sig = opts->tails * significance_of_correlation (pearson, w);
176
177         double entries[5];
178         int n = 0;
179         entries[n++] = pearson;
180         entries[n++] = col_index != r ? sig : SYSMIS;
181         if (opts->statistics & STATS_XPROD)
182           {
183             double cov = gsl_matrix_get (cv, r, col_index);
184             const double xprod_dev = cov * w;
185             cov *= w / (w - 1.0);
186
187             entries[n++] = xprod_dev;
188             entries[n++] = cov;
189           }
190         if (opts->missing_type != CORR_LISTWISE)
191           entries[n++] = w;
192
193         for (int i = 0; i < n; i++)
194           if (entries[i] != SYSMIS)
195             {
196               struct pivot_value *v = pivot_value_new_number (entries[i]);
197               if (!i && opts->sig && col_index != r && sig < 0.05)
198                 pivot_value_add_footnote (v, sig_footnote);
199               pivot_table_put3 (table, c, i, r, v);
200             }
201       }
202
203   pivot_table_submit (table);
204 }
205
206
207 static void
208 run_corr (struct casereader *r, const struct corr_opts *opts, const struct corr *corr)
209 {
210   struct ccase *c;
211   const gsl_matrix *var_matrix,  *samples_matrix, *mean_matrix;
212   gsl_matrix *cov_matrix = NULL;
213   gsl_matrix *corr_matrix = NULL;
214   struct covariance *cov = covariance_2pass_create (corr->n_vars_total, corr->vars,
215                                                     NULL,
216                                                     opts->wv, opts->exclude,
217                                                     true);
218
219   struct casereader *rc = casereader_clone (r);
220   for ( ; (c = casereader_read (r) ); case_unref (c))
221     {
222       covariance_accumulate_pass1 (cov, c);
223     }
224
225   for ( ; (c = casereader_read (rc) ); case_unref (c))
226     {
227       covariance_accumulate_pass2 (cov, c);
228     }
229   casereader_destroy (rc);
230
231   cov_matrix = covariance_calculate (cov);
232   if (! cov_matrix)
233     {
234       msg (SE, _("The data for the chosen variables are all missing or empty."));
235       goto error;
236     }
237
238   samples_matrix = covariance_moments (cov, MOMENT_NONE);
239   var_matrix = covariance_moments (cov, MOMENT_VARIANCE);
240   mean_matrix = covariance_moments (cov, MOMENT_MEAN);
241
242   corr_matrix = correlation_from_covariance (cov_matrix, var_matrix);
243
244   if ( opts->statistics & STATS_DESCRIPTIVES)
245     output_descriptives (corr, opts, mean_matrix, var_matrix, samples_matrix);
246
247   output_correlation (corr, opts, corr_matrix,
248                       samples_matrix, cov_matrix);
249
250  error:
251   covariance_destroy (cov);
252   gsl_matrix_free (corr_matrix);
253   gsl_matrix_free (cov_matrix);
254 }
255
256 int
257 cmd_correlation (struct lexer *lexer, struct dataset *ds)
258 {
259   int i;
260   int n_all_vars = 0; /* Total number of variables involved in this command */
261   const struct variable **all_vars ;
262   const struct dictionary *dict = dataset_dict (ds);
263   bool ok = true;
264
265   struct casegrouper *grouper;
266   struct casereader *group;
267
268   struct corr *corr = NULL;
269   size_t n_corrs = 0;
270
271   struct corr_opts opts;
272   opts.missing_type = CORR_PAIRWISE;
273   opts.wv = dict_get_weight (dict);
274   opts.tails = 2;
275   opts.sig = false;
276   opts.exclude = MV_ANY;
277   opts.statistics = 0;
278
279   /* Parse CORRELATIONS. */
280   while (lex_token (lexer) != T_ENDCMD)
281     {
282       lex_match (lexer, T_SLASH);
283       if (lex_match_id (lexer, "MISSING"))
284         {
285           lex_match (lexer, T_EQUALS);
286           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
287             {
288               if (lex_match_id (lexer, "PAIRWISE"))
289                 opts.missing_type = CORR_PAIRWISE;
290               else if (lex_match_id (lexer, "LISTWISE"))
291                 opts.missing_type = CORR_LISTWISE;
292
293               else if (lex_match_id (lexer, "INCLUDE"))
294                 opts.exclude = MV_SYSTEM;
295               else if (lex_match_id (lexer, "EXCLUDE"))
296                 opts.exclude = MV_ANY;
297               else
298                 {
299                   lex_error (lexer, NULL);
300                   goto error;
301                 }
302               lex_match (lexer, T_COMMA);
303             }
304         }
305       else if (lex_match_id (lexer, "PRINT"))
306         {
307           lex_match (lexer, T_EQUALS);
308           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
309             {
310               if ( lex_match_id (lexer, "TWOTAIL"))
311                 opts.tails = 2;
312               else if (lex_match_id (lexer, "ONETAIL"))
313                 opts.tails = 1;
314               else if (lex_match_id (lexer, "SIG"))
315                 opts.sig = false;
316               else if (lex_match_id (lexer, "NOSIG"))
317                 opts.sig = true;
318               else
319                 {
320                   lex_error (lexer, NULL);
321                   goto error;
322                 }
323
324               lex_match (lexer, T_COMMA);
325             }
326         }
327       else if (lex_match_id (lexer, "STATISTICS"))
328         {
329           lex_match (lexer, T_EQUALS);
330           while (lex_token (lexer) != T_ENDCMD && lex_token (lexer) != T_SLASH)
331             {
332               if ( lex_match_id (lexer, "DESCRIPTIVES"))
333                 opts.statistics = STATS_DESCRIPTIVES;
334               else if (lex_match_id (lexer, "XPROD"))
335                 opts.statistics = STATS_XPROD;
336               else if (lex_token (lexer) == T_ALL)
337                 {
338                   opts.statistics = STATS_ALL;
339                   lex_get (lexer);
340                 }
341               else
342                 {
343                   lex_error (lexer, NULL);
344                   goto error;
345                 }
346
347               lex_match (lexer, T_COMMA);
348             }
349         }
350       else
351         {
352           if (lex_match_id (lexer, "VARIABLES"))
353             {
354               lex_match (lexer, T_EQUALS);
355             }
356
357           corr = xrealloc (corr, sizeof (*corr) * (n_corrs + 1));
358           corr[n_corrs].n_vars_total = corr[n_corrs].n_vars1 = 0;
359
360           if ( ! parse_variables_const (lexer, dict, &corr[n_corrs].vars,
361                                         &corr[n_corrs].n_vars_total,
362                                         PV_NUMERIC))
363             {
364               ok = false;
365               break;
366             }
367
368
369           corr[n_corrs].n_vars1 = corr[n_corrs].n_vars_total;
370
371           if ( lex_match (lexer, T_WITH))
372             {
373               if ( ! parse_variables_const (lexer, dict,
374                                             &corr[n_corrs].vars, &corr[n_corrs].n_vars_total,
375                                             PV_NUMERIC | PV_APPEND))
376                 {
377                   ok = false;
378                   break;
379                 }
380             }
381
382           n_all_vars += corr[n_corrs].n_vars_total;
383
384           n_corrs++;
385         }
386     }
387
388   if (n_corrs == 0)
389     {
390       msg (SE, _("No variables specified."));
391       goto error;
392     }
393
394
395   all_vars = xmalloc (sizeof (*all_vars) * n_all_vars);
396
397   {
398     /* FIXME:  Using a hash here would make more sense */
399     const struct variable **vv = all_vars;
400
401     for (i = 0 ; i < n_corrs; ++i)
402       {
403         int v;
404         const struct corr *c = &corr[i];
405         for (v = 0 ; v < c->n_vars_total; ++v)
406           *vv++ = c->vars[v];
407       }
408   }
409
410   grouper = casegrouper_create_splits (proc_open (ds), dict);
411
412   while (casegrouper_get_next_group (grouper, &group))
413     {
414       for (i = 0 ; i < n_corrs; ++i)
415         {
416           /* FIXME: No need to iterate the data multiple times */
417           struct casereader *r = casereader_clone (group);
418
419           if ( opts.missing_type == CORR_LISTWISE)
420             r = casereader_create_filter_missing (r, all_vars, n_all_vars,
421                                                   opts.exclude, NULL, NULL);
422
423
424           run_corr (r, &opts,  &corr[i]);
425           casereader_destroy (r);
426         }
427       casereader_destroy (group);
428     }
429
430   ok = casegrouper_destroy (grouper);
431   ok = proc_commit (ds) && ok;
432
433   free (all_vars);
434
435
436   /* Done. */
437   free (corr->vars);
438   free (corr);
439
440   return ok ? CMD_SUCCESS : CMD_CASCADING_FAILURE;
441
442  error:
443   if (corr)
444     free (corr->vars);
445   free (corr);
446   return CMD_FAILURE;
447 }