Initial implementation of the Kruskal-Wallis test.
[pspp-builds.git] / src / language / stats / kruskal-wallis.c
1 /* Pspp - a program for statistical analysis.
2    Copyright (C) 2010 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17
18 #include <config.h>
19
20 #include "kruskal-wallis.h"
21
22 #include <gsl/gsl_cdf.h>
23 #include <math.h>
24
25 #include <data/casereader.h>
26 #include <data/casewriter.h>
27 #include <data/dictionary.h>
28 #include <data/format.h>
29 #include <data/procedure.h>
30 #include <data/subcase.h>
31 #include <data/variable.h>
32
33 #include <libpspp/assertion.h>
34 #include <libpspp/message.h>
35 #include <libpspp/misc.h>
36 #include <libpspp/hmap.h>
37 #include <math/sort.h>
38
39
40 #include "minmax.h"
41 #include "xalloc.h"
42
43
44 static bool
45 include_func (const struct ccase *c, void *aux)
46 {
47   const struct n_sample_test *nst = aux;
48
49   if (0 < value_compare_3way (&nst->val1, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
50     return false;
51
52   if (0 > value_compare_3way (&nst->val2, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
53     return false;
54
55   return true;
56 }
57
58
59 struct rank_entry
60 {
61   struct hmap_node node;
62   union value group;
63
64   double sum_of_ranks;
65   double n;
66 };
67
68 static struct rank_entry *
69 find_rank_entry (const struct hmap *map, const union value *group, size_t width)
70 {
71   struct rank_entry *re = NULL;
72   size_t hash  = value_hash (group, width, 0);
73
74   HMAP_FOR_EACH_WITH_HASH (re, struct rank_entry, node, hash, map)
75     {
76       if (0 == value_compare_3way (group, &re->group, width))
77         return re;
78     }
79   
80   return re;
81 }
82
83 static void
84 distinct_callback (double v UNUSED, casenumber t, double w UNUSED, void *aux)
85 {
86   double *tiebreaker = aux;
87
88   *tiebreaker += pow3 (t) - t;
89 }
90
91
92 struct kw
93 {
94   struct hmap map;
95   double h;
96 };
97
98 static void show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups);
99 static void show_sig_box (const struct n_sample_test *nst, const struct kw *kw);
100
101 void
102 kruskal_wallis_execute (const struct dataset *ds,
103                         struct casereader *input,
104                         enum mv_class exclude,
105                         const struct npar_test *test,
106                         bool exact UNUSED,
107                         double timer UNUSED)
108 {
109   int i;
110   struct ccase *c;
111   bool warn = true;
112   const struct dictionary *dict = dataset_dict (ds);
113   const struct n_sample_test *nst = UP_CAST (test, const struct n_sample_test, parent);
114   const struct caseproto *proto ;
115   size_t rank_idx ;
116
117   int total_n_groups = 0.0;
118
119   struct kw *kw = xcalloc (nst->n_vars, sizeof *kw);
120
121   /* If the independent variable is missing, then we ignore the case */
122   input = casereader_create_filter_missing (input, 
123                                             &nst->indep_var, 1,
124                                             exclude,
125                                             NULL, NULL);
126
127   input = casereader_create_filter_weight (input, dict, &warn, NULL);
128
129   /* Remove all those cases which are outside the range (val1, val2) */
130   input = casereader_create_filter_func (input, include_func, NULL, nst, NULL);
131
132   proto = casereader_get_proto (input);
133   rank_idx = caseproto_get_n_widths (proto);
134
135   /* Rank cases by the v value */
136   for (i = 0; i < nst->n_vars; ++i)
137     {
138       double tiebreaker = 0.0;
139       bool warn = true;
140       enum rank_error rerr = 0;
141       struct casereader *rr;
142       struct casereader *r = casereader_clone (input);
143
144       r = sort_execute_1var (r, nst->vars[i]);
145
146       /* Ignore missings in the test variable */
147       r = casereader_create_filter_missing (r, &nst->vars[i], 1,
148                                             exclude,
149                                             NULL, NULL);
150
151       rr = casereader_create_append_rank (r, 
152                                           nst->vars[i],
153                                           dict_get_weight (dict),
154                                           &rerr,
155                                           distinct_callback, &tiebreaker);
156
157       hmap_init (&kw[i].map);
158       for (; (c = casereader_read (rr)); case_unref (c))
159         {
160           const union value *group = case_data (c, nst->indep_var);
161           const size_t group_var_width = var_get_width (nst->indep_var);
162           struct rank_entry *rank = find_rank_entry (&kw[i].map, group, group_var_width); 
163
164           if ( NULL == rank)
165             {
166               rank = xzalloc (sizeof *rank);
167               value_clone (&rank->group, group, group_var_width);
168
169               hmap_insert (&kw[i].map, &rank->node,
170                            value_hash (&rank->group, group_var_width, 0));
171             }
172
173           rank->sum_of_ranks += case_data_idx (c, rank_idx)->f;
174           rank->n += dict_get_case_weight (dict, c, &warn);
175
176           /* If this assertion fires, then either the data wasn't sorted or some other
177              problem occured */
178           assert (rerr == 0);
179         }
180
181       casereader_destroy (rr);
182
183       {
184         struct rank_entry *mre;
185         double n = 0.0;
186
187         HMAP_FOR_EACH (mre, struct rank_entry, node, &kw[i].map)
188           {
189             kw[i].h += pow2 (mre->sum_of_ranks) / mre->n;
190             n += mre->n;
191
192             total_n_groups ++;
193           }
194         kw[i].h *= 12 / (n * ( n + 1));
195         kw[i].h -= 3 * (n + 1) ; 
196
197         kw[i].h /= 1 - tiebreaker/ (pow3 (n) - n);
198       }
199     }
200
201   casereader_destroy (input);
202   
203   show_ranks_box (nst, kw, total_n_groups);
204   show_sig_box (nst, kw);
205
206   free (kw);
207 }
208
209 \f
210 #include <output/tab.h>
211 #include "gettext.h"
212 #define _(msgid) gettext (msgid)
213
214
215 static void
216 show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups)
217 {
218   int i;
219   const int row_headers = 2;
220   const int column_headers = 1;
221   struct tab_table *table =
222     tab_create (row_headers + 2, column_headers + n_groups + nst->n_vars);
223
224   tab_headers (table, row_headers, 0, column_headers, 0);
225
226   tab_title (table, _("Ranks"));
227
228   /* Vertical lines inside the box */
229   tab_box (table, 1, 0, -1, TAL_1,
230            row_headers, 0, tab_nc (table) - 1, tab_nr (table) - 1 );
231
232   /* Box around the table */
233   tab_box (table, TAL_2, TAL_2, -1, -1,
234            0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
235
236   tab_text (table, 1, 0, TAT_TITLE, 
237             var_to_string (nst->indep_var)
238             );
239
240   tab_text (table, 3, 0, 0, _("Mean Rank"));
241   tab_text (table, 2, 0, 0, _("N"));
242
243   tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
244   tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
245
246
247   int x = column_headers;
248   for (i = 0 ; i < nst->n_vars ; ++i)
249     {
250       int tot = 0;
251       const struct rank_entry *re;
252
253       if (i > 0)
254         tab_hline (table, TAL_1, 0, tab_nc (table) -1, x);
255       
256       tab_text (table,  0, x,
257                 TAT_TITLE, var_to_string (nst->vars[i]));
258
259       HMAP_FOR_EACH (re, const struct rank_entry, node, &kw[i].map)
260         {
261           struct string str;
262           ds_init_empty (&str);
263
264           var_append_value_name (nst->indep_var, &re->group, &str);
265
266           tab_text   (table, 1, x, TAB_LEFT, ds_cstr (&str));
267           tab_double (table, 2, x, TAB_LEFT, re->n, &F_8_0);
268           tab_double (table, 3, x, TAB_LEFT, re->sum_of_ranks / re->n, 0);
269
270           tot += re->n;
271           x++;
272           ds_destroy (&str);
273         }
274       tab_double (table, 2, x, TAB_LEFT,
275                   tot, &F_8_0);
276       tab_text (table, 1, x++, TAB_LEFT, _("Total"));
277     }
278
279   tab_submit (table);
280 }
281
282
283 static void
284 show_sig_box (const struct n_sample_test *nst, const struct kw *kw)
285 {
286   int i;
287   const int row_headers = 1;
288   const int column_headers = 1;
289   struct tab_table *table =
290     tab_create (row_headers + nst->n_vars * 2, column_headers + 3);
291
292   tab_headers (table, row_headers, 0, column_headers, 0);
293
294   tab_title (table, _("Test Statistics"));
295
296   tab_text (table,  0, column_headers,
297             TAT_TITLE | TAB_LEFT , _("Chi-Square"));
298
299   tab_text (table,  0, 1 + column_headers,
300             TAT_TITLE | TAB_LEFT, _("df"));
301
302   tab_text (table,  0, 2 + column_headers,
303             TAT_TITLE | TAB_LEFT, _("Asymp. Sig."));
304
305   /* Box around the table */
306   tab_box (table, TAL_2, TAL_2, -1, -1,
307            0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
308
309
310   tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
311   tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
312
313   for (i = 0 ; i < nst->n_vars; ++i)
314     {
315       const double df = hmap_count (&kw[i].map) - 1;
316       tab_text (table, column_headers + 1 + i, 0, TAT_TITLE, 
317                 var_to_string (nst->vars[i])
318                 );
319
320       tab_double (table, column_headers + 1 + i, 1, 0,
321                   kw[i].h, 0);
322
323       tab_double (table, column_headers + 1 + i, 2, 0,
324                   df, &F_8_0);
325
326       tab_double (table, column_headers + 1 + i, 3, 0,
327                   gsl_cdf_chisq_Q (kw[i].h, df),
328                   0);
329     }
330
331   tab_submit (table);
332 }