ea8fe8ceebdf5c01c8de67a94dc7a1eb7fba5a3b
[pspp-builds.git] / src / language / stats / kruskal-wallis.c
1 /* Pspp - a program for statistical analysis.
2    Copyright (C) 2010 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17
18 #include <config.h>
19
20 #include "kruskal-wallis.h"
21
22 #include <gsl/gsl_cdf.h>
23 #include <math.h>
24
25 #include <data/casereader.h>
26 #include <data/casewriter.h>
27 #include <data/dictionary.h>
28 #include <data/format.h>
29 #include <data/procedure.h>
30 #include <data/subcase.h>
31 #include <data/variable.h>
32
33 #include <libpspp/assertion.h>
34 #include <libpspp/message.h>
35 #include <libpspp/misc.h>
36 #include <libpspp/hmap.h>
37 #include <math/sort.h>
38
39
40 #include "minmax.h"
41 #include "xalloc.h"
42
43
44 /* Returns true iff the independent variable lies in the range [nst->val1, nst->val2] */
45 static bool
46 include_func (const struct ccase *c, void *aux)
47 {
48   const struct n_sample_test *nst = aux;
49
50   if (0 < value_compare_3way (&nst->val1, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
51     return false;
52
53   if (0 > value_compare_3way (&nst->val2, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
54     return false;
55
56   return true;
57 }
58
59
60 struct rank_entry
61 {
62   struct hmap_node node;
63   union value group;
64
65   double sum_of_ranks;
66   double n;
67 };
68
69 /* Return the entry with the key GROUP or null if there is no such entry */
70 static struct rank_entry *
71 find_rank_entry (const struct hmap *map, const union value *group, size_t width)
72 {
73   struct rank_entry *re = NULL;
74   size_t hash  = value_hash (group, width, 0);
75
76   HMAP_FOR_EACH_WITH_HASH (re, struct rank_entry, node, hash, map)
77     {
78       if (0 == value_compare_3way (group, &re->group, width))
79         return re;
80     }
81   
82   return re;
83 }
84
85 /* Calculates the adjustment necessary for tie compensation */
86 static void
87 distinct_callback (double v UNUSED, casenumber t, double w UNUSED, void *aux)
88 {
89   double *tiebreaker = aux;
90
91   *tiebreaker += pow3 (t) - t;
92 }
93
94
95 struct kw
96 {
97   struct hmap map;
98   double h;
99 };
100
101 static void show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups);
102 static void show_sig_box (const struct n_sample_test *nst, const struct kw *kw);
103
104 void
105 kruskal_wallis_execute (const struct dataset *ds,
106                         struct casereader *input,
107                         enum mv_class exclude,
108                         const struct npar_test *test,
109                         bool exact UNUSED,
110                         double timer UNUSED)
111 {
112   int i;
113   struct ccase *c;
114   bool warn = true;
115   const struct dictionary *dict = dataset_dict (ds);
116   const struct n_sample_test *nst = UP_CAST (test, const struct n_sample_test, parent);
117   const struct caseproto *proto ;
118   size_t rank_idx ;
119
120   int total_n_groups = 0.0;
121
122   struct kw *kw = xcalloc (nst->n_vars, sizeof *kw);
123
124   /* If the independent variable is missing, then we ignore the case */
125   input = casereader_create_filter_missing (input, 
126                                             &nst->indep_var, 1,
127                                             exclude,
128                                             NULL, NULL);
129
130   input = casereader_create_filter_weight (input, dict, &warn, NULL);
131
132   /* Remove all those cases which are outside the range (val1, val2) */
133   input = casereader_create_filter_func (input, include_func, NULL, 
134         CONST_CAST (struct n_sample_test *, nst), NULL);
135
136   proto = casereader_get_proto (input);
137   rank_idx = caseproto_get_n_widths (proto);
138
139   /* Rank cases by the v value */
140   for (i = 0; i < nst->n_vars; ++i)
141     {
142       double tiebreaker = 0.0;
143       bool warn = true;
144       enum rank_error rerr = 0;
145       struct casereader *rr;
146       struct casereader *r = casereader_clone (input);
147
148       r = sort_execute_1var (r, nst->vars[i]);
149
150       /* Ignore missings in the test variable */
151       r = casereader_create_filter_missing (r, &nst->vars[i], 1,
152                                             exclude,
153                                             NULL, NULL);
154
155       rr = casereader_create_append_rank (r, 
156                                           nst->vars[i],
157                                           dict_get_weight (dict),
158                                           &rerr,
159                                           distinct_callback, &tiebreaker);
160
161       hmap_init (&kw[i].map);
162       for (; (c = casereader_read (rr)); case_unref (c))
163         {
164           const union value *group = case_data (c, nst->indep_var);
165           const size_t group_var_width = var_get_width (nst->indep_var);
166           struct rank_entry *rank = find_rank_entry (&kw[i].map, group, group_var_width); 
167
168           if ( NULL == rank)
169             {
170               rank = xzalloc (sizeof *rank);
171               value_clone (&rank->group, group, group_var_width);
172
173               hmap_insert (&kw[i].map, &rank->node,
174                            value_hash (&rank->group, group_var_width, 0));
175             }
176
177           rank->sum_of_ranks += case_data_idx (c, rank_idx)->f;
178           rank->n += dict_get_case_weight (dict, c, &warn);
179
180           /* If this assertion fires, then either the data wasn't sorted or some other
181              problem occured */
182           assert (rerr == 0);
183         }
184
185       casereader_destroy (rr);
186
187       /* Calculate the value of h */
188       {
189         struct rank_entry *mre;
190         double n = 0.0;
191
192         HMAP_FOR_EACH (mre, struct rank_entry, node, &kw[i].map)
193           {
194             kw[i].h += pow2 (mre->sum_of_ranks) / mre->n;
195             n += mre->n;
196
197             total_n_groups ++;
198           }
199         kw[i].h *= 12 / (n * ( n + 1));
200         kw[i].h -= 3 * (n + 1) ; 
201
202         kw[i].h /= 1 - tiebreaker/ (pow3 (n) - n);
203       }
204     }
205
206   casereader_destroy (input);
207   
208   show_ranks_box (nst, kw, total_n_groups);
209   show_sig_box (nst, kw);
210
211   /* Cleanup allocated memory */
212   for (i = 0 ; i < nst->n_vars; ++i)
213     {
214       struct rank_entry *mre, *next;
215       HMAP_FOR_EACH_SAFE (mre, next, struct rank_entry, node, &kw[i].map)
216         {
217           hmap_delete (&kw[i].map, &mre->node);
218           free (mre);
219         }
220       hmap_destroy (&kw[i].map);
221     }
222
223   free (kw);
224 }
225
226 \f
227 #include <output/tab.h>
228 #include "gettext.h"
229 #define _(msgid) gettext (msgid)
230
231
232 static void
233 show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups)
234 {
235   int row;
236   int i;
237   const int row_headers = 2;
238   const int column_headers = 1;
239   struct tab_table *table =
240     tab_create (row_headers + 2, column_headers + n_groups + nst->n_vars);
241
242   tab_headers (table, row_headers, 0, column_headers, 0);
243
244   tab_title (table, _("Ranks"));
245
246   /* Vertical lines inside the box */
247   tab_box (table, 1, 0, -1, TAL_1,
248            row_headers, 0, tab_nc (table) - 1, tab_nr (table) - 1 );
249
250   /* Box around the table */
251   tab_box (table, TAL_2, TAL_2, -1, -1,
252            0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
253
254   tab_text (table, 1, 0, TAT_TITLE, 
255             var_to_string (nst->indep_var)
256             );
257
258   tab_text (table, 3, 0, 0, _("Mean Rank"));
259   tab_text (table, 2, 0, 0, _("N"));
260
261   tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
262   tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
263
264
265   row = column_headers;
266   for (i = 0 ; i < nst->n_vars ; ++i)
267     {
268       int tot = 0;
269       const struct rank_entry *re;
270
271       if (i > 0)
272         tab_hline (table, TAL_1, 0, tab_nc (table) -1, row);
273       
274       tab_text (table,  0, row,
275                 TAT_TITLE, var_to_string (nst->vars[i]));
276
277       HMAP_FOR_EACH (re, const struct rank_entry, node, &kw[i].map)
278         {
279           struct string str;
280           ds_init_empty (&str);
281
282           var_append_value_name (nst->indep_var, &re->group, &str);
283
284           tab_text   (table, 1, row, TAB_LEFT, ds_cstr (&str));
285           tab_double (table, 2, row, TAB_LEFT, re->n, &F_8_0);
286           tab_double (table, 3, row, TAB_LEFT, re->sum_of_ranks / re->n, 0);
287
288           tot += re->n;
289           row++;
290           ds_destroy (&str);
291         }
292       tab_double (table, 2, row, TAB_LEFT,
293                   tot, &F_8_0);
294       tab_text (table, 1, row++, TAB_LEFT, _("Total"));
295     }
296
297   tab_submit (table);
298 }
299
300
301 static void
302 show_sig_box (const struct n_sample_test *nst, const struct kw *kw)
303 {
304   int i;
305   const int row_headers = 1;
306   const int column_headers = 1;
307   struct tab_table *table =
308     tab_create (row_headers + nst->n_vars * 2, column_headers + 3);
309
310   tab_headers (table, row_headers, 0, column_headers, 0);
311
312   tab_title (table, _("Test Statistics"));
313
314   tab_text (table,  0, column_headers,
315             TAT_TITLE | TAB_LEFT , _("Chi-Square"));
316
317   tab_text (table,  0, 1 + column_headers,
318             TAT_TITLE | TAB_LEFT, _("df"));
319
320   tab_text (table,  0, 2 + column_headers,
321             TAT_TITLE | TAB_LEFT, _("Asymp. Sig."));
322
323   /* Box around the table */
324   tab_box (table, TAL_2, TAL_2, -1, -1,
325            0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
326
327
328   tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
329   tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
330
331   for (i = 0 ; i < nst->n_vars; ++i)
332     {
333       const double df = hmap_count (&kw[i].map) - 1;
334       tab_text (table, column_headers + 1 + i, 0, TAT_TITLE, 
335                 var_to_string (nst->vars[i])
336                 );
337
338       tab_double (table, column_headers + 1 + i, 1, 0,
339                   kw[i].h, 0);
340
341       tab_double (table, column_headers + 1 + i, 2, 0,
342                   df, &F_8_0);
343
344       tab_double (table, column_headers + 1 + i, 3, 0,
345                   gsl_cdf_chisq_Q (kw[i].h, df),
346                   0);
347     }
348
349   tab_submit (table);
350 }