session: Fix two memory leaks.
[pspp] / src / language / stats / kruskal-wallis.c
1 /* Pspp - a program for statistical analysis.
2    Copyright (C) 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17
18 #include <config.h>
19
20 #include "kruskal-wallis.h"
21
22 #include <gsl/gsl_cdf.h>
23 #include <math.h>
24
25 #include "data/casereader.h"
26 #include "data/casewriter.h"
27 #include "data/dataset.h"
28 #include "data/dictionary.h"
29 #include "data/format.h"
30 #include "data/subcase.h"
31 #include "data/variable.h"
32 #include "libpspp/assertion.h"
33 #include "libpspp/hmap.h"
34 #include "libpspp/bt.h"
35 #include "libpspp/message.h"
36 #include "libpspp/misc.h"
37 #include "math/sort.h"
38 #include "output/tab.h"
39
40 #include "gl/minmax.h"
41 #include "gl/xalloc.h"
42
43
44 /* Returns true iff the independent variable lies in the range [nst->val1, nst->val2] */
45 static bool
46 include_func (const struct ccase *c, void *aux)
47 {
48   const struct n_sample_test *nst = aux;
49
50   if (0 < value_compare_3way (&nst->val1, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
51     return false;
52
53   if (0 > value_compare_3way (&nst->val2, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
54     return false;
55
56   return true;
57 }
58
59
60 struct rank_entry
61 {
62   struct hmap_node node;
63   struct bt_node btn;
64   union value group;
65
66   double sum_of_ranks;
67   double n;
68 };
69
70
71 static int
72 compare_rank_entries_3way (const struct bt_node *a,
73                            const struct bt_node *b,
74                            const void *aux)
75 {
76   const struct variable *var = aux;
77   const struct rank_entry *rea = bt_data (a, struct rank_entry, btn);
78   const struct rank_entry *reb = bt_data (b, struct rank_entry, btn);
79
80   return value_compare_3way (&rea->group, &reb->group, var_get_width (var));
81 }
82
83
84 /* Return the entry with the key GROUP or null if there is no such entry */
85 static struct rank_entry *
86 find_rank_entry (const struct hmap *map, const union value *group, size_t width)
87 {
88   struct rank_entry *re = NULL;
89   size_t hash  = value_hash (group, width, 0);
90
91   HMAP_FOR_EACH_WITH_HASH (re, struct rank_entry, node, hash, map)
92     {
93       if (0 == value_compare_3way (group, &re->group, width))
94         return re;
95     }
96   
97   return re;
98 }
99
100 /* Calculates the adjustment necessary for tie compensation */
101 static void
102 distinct_callback (double v UNUSED, casenumber t, double w UNUSED, void *aux)
103 {
104   double *tiebreaker = aux;
105
106   *tiebreaker += pow3 (t) - t;
107 }
108
109
110 struct kw
111 {
112   struct hmap map;
113   double h;
114 };
115
116 static void show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups);
117 static void show_sig_box (const struct n_sample_test *nst, const struct kw *kw);
118
119 void
120 kruskal_wallis_execute (const struct dataset *ds,
121                         struct casereader *input,
122                         enum mv_class exclude,
123                         const struct npar_test *test,
124                         bool exact UNUSED,
125                         double timer UNUSED)
126 {
127   int i;
128   struct ccase *c;
129   bool warn = true;
130   const struct dictionary *dict = dataset_dict (ds);
131   const struct n_sample_test *nst = UP_CAST (test, const struct n_sample_test, parent);
132   const struct caseproto *proto ;
133   size_t rank_idx ;
134
135   int total_n_groups = 0.0;
136
137   struct kw *kw = xcalloc (nst->n_vars, sizeof *kw);
138
139   /* If the independent variable is missing, then we ignore the case */
140   input = casereader_create_filter_missing (input, 
141                                             &nst->indep_var, 1,
142                                             exclude,
143                                             NULL, NULL);
144
145   input = casereader_create_filter_weight (input, dict, &warn, NULL);
146
147   /* Remove all those cases which are outside the range (val1, val2) */
148   input = casereader_create_filter_func (input, include_func, NULL, 
149         CONST_CAST (struct n_sample_test *, nst), NULL);
150
151   proto = casereader_get_proto (input);
152   rank_idx = caseproto_get_n_widths (proto);
153
154   /* Rank cases by the v value */
155   for (i = 0; i < nst->n_vars; ++i)
156     {
157       double tiebreaker = 0.0;
158       bool warn = true;
159       enum rank_error rerr = 0;
160       struct casereader *rr;
161       struct casereader *r = casereader_clone (input);
162
163       r = sort_execute_1var (r, nst->vars[i]);
164
165       /* Ignore missings in the test variable */
166       r = casereader_create_filter_missing (r, &nst->vars[i], 1,
167                                             exclude,
168                                             NULL, NULL);
169
170       rr = casereader_create_append_rank (r, 
171                                           nst->vars[i],
172                                           dict_get_weight (dict),
173                                           &rerr,
174                                           distinct_callback, &tiebreaker);
175
176       hmap_init (&kw[i].map);
177       for (; (c = casereader_read (rr)); case_unref (c))
178         {
179           const union value *group = case_data (c, nst->indep_var);
180           const size_t group_var_width = var_get_width (nst->indep_var);
181           struct rank_entry *rank = find_rank_entry (&kw[i].map, group, group_var_width); 
182
183           if ( NULL == rank)
184             {
185               rank = xzalloc (sizeof *rank);
186               value_clone (&rank->group, group, group_var_width);
187
188               hmap_insert (&kw[i].map, &rank->node,
189                            value_hash (&rank->group, group_var_width, 0));
190             }
191
192           rank->sum_of_ranks += case_data_idx (c, rank_idx)->f;
193           rank->n += dict_get_case_weight (dict, c, &warn);
194
195           /* If this assertion fires, then either the data wasn't sorted or some other
196              problem occured */
197           assert (rerr == 0);
198         }
199
200       casereader_destroy (rr);
201
202       /* Calculate the value of h */
203       {
204         struct rank_entry *mre;
205         double n = 0.0;
206
207         HMAP_FOR_EACH (mre, struct rank_entry, node, &kw[i].map)
208           {
209             kw[i].h += pow2 (mre->sum_of_ranks) / mre->n;
210             n += mre->n;
211
212             total_n_groups ++;
213           }
214         kw[i].h *= 12 / (n * ( n + 1));
215         kw[i].h -= 3 * (n + 1) ; 
216
217         kw[i].h /= 1 - tiebreaker/ (pow3 (n) - n);
218       }
219     }
220
221   casereader_destroy (input);
222   
223   show_ranks_box (nst, kw, total_n_groups);
224   show_sig_box (nst, kw);
225
226   /* Cleanup allocated memory */
227   for (i = 0 ; i < nst->n_vars; ++i)
228     {
229       struct rank_entry *mre, *next;
230       HMAP_FOR_EACH_SAFE (mre, next, struct rank_entry, node, &kw[i].map)
231         {
232           hmap_delete (&kw[i].map, &mre->node);
233           free (mre);
234         }
235       hmap_destroy (&kw[i].map);
236     }
237
238   free (kw);
239 }
240
241 \f
242 #include "gettext.h"
243 #define _(msgid) gettext (msgid)
244
245
246 static void
247 show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups)
248 {
249   int row;
250   int i;
251   const int row_headers = 2;
252   const int column_headers = 1;
253   struct tab_table *table =
254     tab_create (row_headers + 2, column_headers + n_groups + nst->n_vars);
255
256   tab_headers (table, row_headers, 0, column_headers, 0);
257
258   tab_title (table, _("Ranks"));
259
260   /* Vertical lines inside the box */
261   tab_box (table, 1, 0, -1, TAL_1,
262            row_headers, 0, tab_nc (table) - 1, tab_nr (table) - 1 );
263
264   /* Box around the table */
265   tab_box (table, TAL_2, TAL_2, -1, -1,
266            0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
267
268   tab_text (table, 1, 0, TAT_TITLE, 
269             var_to_string (nst->indep_var)
270             );
271
272   tab_text (table, 3, 0, 0, _("Mean Rank"));
273   tab_text (table, 2, 0, 0, _("N"));
274
275   tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
276   tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
277
278
279   row = column_headers;
280   for (i = 0 ; i < nst->n_vars ; ++i)
281     {
282       int tot = 0;
283       struct rank_entry *re_x;
284       struct bt_node *bt_n = NULL;
285       struct bt bt;
286
287       if (i > 0)
288         tab_hline (table, TAL_1, 0, tab_nc (table) -1, row);
289       
290       tab_text (table,  0, row,
291                 TAT_TITLE, var_to_string (nst->vars[i]));
292
293       /* Sort the rank entries, by iteratin the hash and putting the entries
294          into a binary tree. */
295       bt_init (&bt, compare_rank_entries_3way, nst->vars[i]);
296       HMAP_FOR_EACH (re_x, struct rank_entry, node, &kw[i].map)
297         {
298           bt_insert (&bt, &re_x->btn);
299         }
300
301       /* Report the rank entries in sorted order. */
302       for (bt_n = bt_first (&bt);
303            bt_n != NULL;
304            bt_n = bt_next (&bt, bt_n) )
305         {
306           const struct rank_entry *re =
307             bt_data (bt_n, const struct rank_entry, btn);
308
309           struct string str;
310           ds_init_empty (&str);
311           
312           var_append_value_name (nst->indep_var, &re->group, &str);
313           
314           tab_text   (table, 1, row, TAB_LEFT, ds_cstr (&str));
315           tab_double (table, 2, row, TAB_LEFT, re->n, &F_8_0);
316           tab_double (table, 3, row, TAB_LEFT, re->sum_of_ranks / re->n, 0);
317           
318           tot += re->n;
319           row++;
320           ds_destroy (&str);
321         }
322
323       tab_double (table, 2, row, TAB_LEFT,
324                   tot, &F_8_0);
325       tab_text (table, 1, row++, TAB_LEFT, _("Total"));
326     }
327
328   tab_submit (table);
329 }
330
331
332 static void
333 show_sig_box (const struct n_sample_test *nst, const struct kw *kw)
334 {
335   int i;
336   const int row_headers = 1;
337   const int column_headers = 1;
338   struct tab_table *table =
339     tab_create (row_headers + nst->n_vars * 2, column_headers + 3);
340
341   tab_headers (table, row_headers, 0, column_headers, 0);
342
343   tab_title (table, _("Test Statistics"));
344
345   tab_text (table,  0, column_headers,
346             TAT_TITLE | TAB_LEFT , _("Chi-Square"));
347
348   tab_text (table,  0, 1 + column_headers,
349             TAT_TITLE | TAB_LEFT, _("df"));
350
351   tab_text (table,  0, 2 + column_headers,
352             TAT_TITLE | TAB_LEFT, _("Asymp. Sig."));
353
354   /* Box around the table */
355   tab_box (table, TAL_2, TAL_2, -1, -1,
356            0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
357
358
359   tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
360   tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
361
362   for (i = 0 ; i < nst->n_vars; ++i)
363     {
364       const double df = hmap_count (&kw[i].map) - 1;
365       tab_text (table, column_headers + 1 + i, 0, TAT_TITLE, 
366                 var_to_string (nst->vars[i])
367                 );
368
369       tab_double (table, column_headers + 1 + i, 1, 0,
370                   kw[i].h, 0);
371
372       tab_double (table, column_headers + 1 + i, 2, 0,
373                   df, &F_8_0);
374
375       tab_double (table, column_headers + 1 + i, 3, 0,
376                   gsl_cdf_chisq_Q (kw[i].h, df),
377                   0);
378     }
379
380   tab_submit (table);
381 }