From c8b02c29026c095ce912faf5fdba7e29b42cb135 Mon Sep 17 00:00:00 2001
From: John Darrington <john@darrington.wattle.id.au>
Date: Mon, 11 Oct 2010 20:02:54 +0200
Subject: [PATCH] Initial implementation of the Kruskal-Wallis test.

---
 src/language/stats/automake.mk      |  10 +-
 src/language/stats/kruskal-wallis.c | 332 ++++++++++++++++++++++++++++
 src/language/stats/kruskal-wallis.h |  42 ++++
 src/language/stats/npar.c           |  20 +-
 src/language/stats/npar.h           |   5 +-
 5 files changed, 391 insertions(+), 18 deletions(-)
 create mode 100644 src/language/stats/kruskal-wallis.c
 create mode 100644 src/language/stats/kruskal-wallis.h

diff --git a/src/language/stats/automake.mk b/src/language/stats/automake.mk
index e11a8de5..fb191d0c 100644
--- a/src/language/stats/automake.mk
+++ b/src/language/stats/automake.mk
@@ -20,16 +20,15 @@ language_stats_sources = \
 	src/language/stats/chisquare.h \
 	src/language/stats/correlations.c \
 	src/language/stats/descriptives.c \
-	src/language/stats/npar.h \
-	src/language/stats/sort-cases.c \
-	src/language/stats/sort-criteria.c \
-	src/language/stats/sort-criteria.h \
 	src/language/stats/factor.c \
 	src/language/stats/flip.c \
 	src/language/stats/freq.c \
 	src/language/stats/freq.h \
 	src/language/stats/glm.c \
+	src/language/stats/kruskal-wallis.c \
+	src/language/stats/kruskal-wallis.h \
 	src/language/stats/npar.c \
+	src/language/stats/npar.h \
 	src/language/stats/npar-summary.c \
 	src/language/stats/npar-summary.h \
 	src/language/stats/oneway.c \
@@ -38,6 +37,9 @@ language_stats_sources = \
 	src/language/stats/roc.h \
 	src/language/stats/sign.c \
 	src/language/stats/sign.h \
+	src/language/stats/sort-cases.c \
+	src/language/stats/sort-criteria.c \
+	src/language/stats/sort-criteria.h \
 	src/language/stats/wilcoxon.c \
 	src/language/stats/wilcoxon.h
 
diff --git a/src/language/stats/kruskal-wallis.c b/src/language/stats/kruskal-wallis.c
new file mode 100644
index 00000000..420b1a03
--- /dev/null
+++ b/src/language/stats/kruskal-wallis.c
@@ -0,0 +1,332 @@
+/* Pspp - a program for statistical analysis.
+   Copyright (C) 2010 Free Software Foundation, Inc.
+
+   This program is free software: you can redistribute it and/or modify
+   it under the terms of the GNU General Public License as published by
+   the Free Software Foundation, either version 3 of the License, or
+   (at your option) any later version.
+
+   This program is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+   GNU General Public License for more details.
+
+   You should have received a copy of the GNU General Public License
+   along with this program.  If not, see <http://www.gnu.org/licenses/>. */
+
+
+#include <config.h>
+
+#include "kruskal-wallis.h"
+
+#include <gsl/gsl_cdf.h>
+#include <math.h>
+
+#include <data/casereader.h>
+#include <data/casewriter.h>
+#include <data/dictionary.h>
+#include <data/format.h>
+#include <data/procedure.h>
+#include <data/subcase.h>
+#include <data/variable.h>
+
+#include <libpspp/assertion.h>
+#include <libpspp/message.h>
+#include <libpspp/misc.h>
+#include <libpspp/hmap.h>
+#include <math/sort.h>
+
+
+#include "minmax.h"
+#include "xalloc.h"
+
+
+static bool
+include_func (const struct ccase *c, void *aux)
+{
+  const struct n_sample_test *nst = aux;
+
+  if (0 < value_compare_3way (&nst->val1, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
+    return false;
+
+  if (0 > value_compare_3way (&nst->val2, case_data (c, nst->indep_var), var_get_width (nst->indep_var)))
+    return false;
+
+  return true;
+}
+
+
+struct rank_entry
+{
+  struct hmap_node node;
+  union value group;
+
+  double sum_of_ranks;
+  double n;
+};
+
+static struct rank_entry *
+find_rank_entry (const struct hmap *map, const union value *group, size_t width)
+{
+  struct rank_entry *re = NULL;
+  size_t hash  = value_hash (group, width, 0);
+
+  HMAP_FOR_EACH_WITH_HASH (re, struct rank_entry, node, hash, map)
+    {
+      if (0 == value_compare_3way (group, &re->group, width))
+	return re;
+    }
+  
+  return re;
+}
+
+static void
+distinct_callback (double v UNUSED, casenumber t, double w UNUSED, void *aux)
+{
+  double *tiebreaker = aux;
+
+  *tiebreaker += pow3 (t) - t;
+}
+
+
+struct kw
+{
+  struct hmap map;
+  double h;
+};
+
+static void show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups);
+static void show_sig_box (const struct n_sample_test *nst, const struct kw *kw);
+
+void
+kruskal_wallis_execute (const struct dataset *ds,
+			struct casereader *input,
+			enum mv_class exclude,
+			const struct npar_test *test,
+			bool exact UNUSED,
+			double timer UNUSED)
+{
+  int i;
+  struct ccase *c;
+  bool warn = true;
+  const struct dictionary *dict = dataset_dict (ds);
+  const struct n_sample_test *nst = UP_CAST (test, const struct n_sample_test, parent);
+  const struct caseproto *proto ;
+  size_t rank_idx ;
+
+  int total_n_groups = 0.0;
+
+  struct kw *kw = xcalloc (nst->n_vars, sizeof *kw);
+
+  /* If the independent variable is missing, then we ignore the case */
+  input = casereader_create_filter_missing (input, 
+					    &nst->indep_var, 1,
+					    exclude,
+					    NULL, NULL);
+
+  input = casereader_create_filter_weight (input, dict, &warn, NULL);
+
+  /* Remove all those cases which are outside the range (val1, val2) */
+  input = casereader_create_filter_func (input, include_func, NULL, nst, NULL);
+
+  proto = casereader_get_proto (input);
+  rank_idx = caseproto_get_n_widths (proto);
+
+  /* Rank cases by the v value */
+  for (i = 0; i < nst->n_vars; ++i)
+    {
+      double tiebreaker = 0.0;
+      bool warn = true;
+      enum rank_error rerr = 0;
+      struct casereader *rr;
+      struct casereader *r = casereader_clone (input);
+
+      r = sort_execute_1var (r, nst->vars[i]);
+
+      /* Ignore missings in the test variable */
+      r = casereader_create_filter_missing (r, &nst->vars[i], 1,
+					    exclude,
+					    NULL, NULL);
+
+      rr = casereader_create_append_rank (r, 
+					  nst->vars[i],
+					  dict_get_weight (dict),
+					  &rerr,
+					  distinct_callback, &tiebreaker);
+
+      hmap_init (&kw[i].map);
+      for (; (c = casereader_read (rr)); case_unref (c))
+	{
+	  const union value *group = case_data (c, nst->indep_var);
+	  const size_t group_var_width = var_get_width (nst->indep_var);
+	  struct rank_entry *rank = find_rank_entry (&kw[i].map, group, group_var_width); 
+
+	  if ( NULL == rank)
+	    {
+	      rank = xzalloc (sizeof *rank);
+	      value_clone (&rank->group, group, group_var_width);
+
+	      hmap_insert (&kw[i].map, &rank->node,
+			   value_hash (&rank->group, group_var_width, 0));
+	    }
+
+	  rank->sum_of_ranks += case_data_idx (c, rank_idx)->f;
+	  rank->n += dict_get_case_weight (dict, c, &warn);
+
+	  /* If this assertion fires, then either the data wasn't sorted or some other
+	     problem occured */
+	  assert (rerr == 0);
+	}
+
+      casereader_destroy (rr);
+
+      {
+	struct rank_entry *mre;
+	double n = 0.0;
+
+	HMAP_FOR_EACH (mre, struct rank_entry, node, &kw[i].map)
+	  {
+	    kw[i].h += pow2 (mre->sum_of_ranks) / mre->n;
+	    n += mre->n;
+
+	    total_n_groups ++;
+	  }
+	kw[i].h *= 12 / (n * ( n + 1));
+	kw[i].h -= 3 * (n + 1) ; 
+
+	kw[i].h /= 1 - tiebreaker/ (pow3 (n) - n);
+      }
+    }
+
+  casereader_destroy (input);
+  
+  show_ranks_box (nst, kw, total_n_groups);
+  show_sig_box (nst, kw);
+
+  free (kw);
+}
+
+
+#include <output/tab.h>
+#include "gettext.h"
+#define _(msgid) gettext (msgid)
+
+
+static void
+show_ranks_box (const struct n_sample_test *nst, const struct kw *kw, int n_groups)
+{
+  int i;
+  const int row_headers = 2;
+  const int column_headers = 1;
+  struct tab_table *table =
+    tab_create (row_headers + 2, column_headers + n_groups + nst->n_vars);
+
+  tab_headers (table, row_headers, 0, column_headers, 0);
+
+  tab_title (table, _("Ranks"));
+
+  /* Vertical lines inside the box */
+  tab_box (table, 1, 0, -1, TAL_1,
+	   row_headers, 0, tab_nc (table) - 1, tab_nr (table) - 1 );
+
+  /* Box around the table */
+  tab_box (table, TAL_2, TAL_2, -1, -1,
+	   0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
+
+  tab_text (table, 1, 0, TAT_TITLE, 
+	    var_to_string (nst->indep_var)
+	    );
+
+  tab_text (table, 3, 0, 0, _("Mean Rank"));
+  tab_text (table, 2, 0, 0, _("N"));
+
+  tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
+  tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
+
+
+  int x = column_headers;
+  for (i = 0 ; i < nst->n_vars ; ++i)
+    {
+      int tot = 0;
+      const struct rank_entry *re;
+
+      if (i > 0)
+	tab_hline (table, TAL_1, 0, tab_nc (table) -1, x);
+      
+      tab_text (table,  0, x,
+		TAT_TITLE, var_to_string (nst->vars[i]));
+
+      HMAP_FOR_EACH (re, const struct rank_entry, node, &kw[i].map)
+	{
+	  struct string str;
+	  ds_init_empty (&str);
+
+	  var_append_value_name (nst->indep_var, &re->group, &str);
+
+	  tab_text   (table, 1, x, TAB_LEFT, ds_cstr (&str));
+	  tab_double (table, 2, x, TAB_LEFT, re->n, &F_8_0);
+	  tab_double (table, 3, x, TAB_LEFT, re->sum_of_ranks / re->n, 0);
+
+	  tot += re->n;
+	  x++;
+	  ds_destroy (&str);
+	}
+      tab_double (table, 2, x, TAB_LEFT,
+		  tot, &F_8_0);
+      tab_text (table, 1, x++, TAB_LEFT, _("Total"));
+    }
+
+  tab_submit (table);
+}
+
+
+static void
+show_sig_box (const struct n_sample_test *nst, const struct kw *kw)
+{
+  int i;
+  const int row_headers = 1;
+  const int column_headers = 1;
+  struct tab_table *table =
+    tab_create (row_headers + nst->n_vars * 2, column_headers + 3);
+
+  tab_headers (table, row_headers, 0, column_headers, 0);
+
+  tab_title (table, _("Test Statistics"));
+
+  tab_text (table,  0, column_headers,
+	    TAT_TITLE | TAB_LEFT , _("Chi-Square"));
+
+  tab_text (table,  0, 1 + column_headers,
+	    TAT_TITLE | TAB_LEFT, _("df"));
+
+  tab_text (table,  0, 2 + column_headers,
+	    TAT_TITLE | TAB_LEFT, _("Asymp. Sig."));
+
+  /* Box around the table */
+  tab_box (table, TAL_2, TAL_2, -1, -1,
+	   0,  0, tab_nc (table) - 1, tab_nr (table) - 1 );
+
+
+  tab_hline (table, TAL_2, 0, tab_nc (table) -1, column_headers);
+  tab_vline (table, TAL_2, row_headers, 0, tab_nr (table) - 1);
+
+  for (i = 0 ; i < nst->n_vars; ++i)
+    {
+      const double df = hmap_count (&kw[i].map) - 1;
+      tab_text (table, column_headers + 1 + i, 0, TAT_TITLE, 
+		var_to_string (nst->vars[i])
+		);
+
+      tab_double (table, column_headers + 1 + i, 1, 0,
+		  kw[i].h, 0);
+
+      tab_double (table, column_headers + 1 + i, 2, 0,
+		  df, &F_8_0);
+
+      tab_double (table, column_headers + 1 + i, 3, 0,
+		  gsl_cdf_chisq_Q (kw[i].h, df),
+		  0);
+    }
+
+  tab_submit (table);
+}
diff --git a/src/language/stats/kruskal-wallis.h b/src/language/stats/kruskal-wallis.h
new file mode 100644
index 00000000..6194759d
--- /dev/null
+++ b/src/language/stats/kruskal-wallis.h
@@ -0,0 +1,42 @@
+/* PSPP - a program for statistical analysis.
+   Copyright (C) 2010 Free Software Foundation, Inc.
+
+   This program is free software: you can redistribute it and/or modify
+   it under the terms of the GNU General Public License as published by
+   the Free Software Foundation, either version 3 of the License, or
+   (at your option) any later version.
+
+   This program is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+   GNU General Public License for more details.
+
+   You should have received a copy of the GNU General Public License
+   along with this program.  If not, see <http://www.gnu.org/licenses/>. */
+
+#if !kruskal_wallis_h
+#define kruskal_wallis_h 1
+
+#include <stddef.h>
+#include <stdbool.h>
+#include <language/stats/npar.h>
+#include <data/case.h>
+
+
+struct kruskal_wallis_test
+{
+  struct two_sample_test parent;
+};
+
+struct casereader;
+struct dataset;
+
+void kruskal_wallis_execute (const struct dataset *ds,
+		       struct casereader *input,
+		       enum mv_class exclude,
+		       const struct npar_test *test,
+		       bool exact,
+		       double timer
+		       );
+
+#endif
diff --git a/src/language/stats/npar.c b/src/language/stats/npar.c
index 6e2fae71..27d2d22d 100644
--- a/src/language/stats/npar.c
+++ b/src/language/stats/npar.c
@@ -36,6 +36,7 @@
 #include <language/lexer/value-parser.h>
 #include <language/stats/binomial.h>
 #include <language/stats/chisquare.h>
+#include <language/stats/kruskal-wallis.h>
 #include <language/stats/wilcoxon.h>
 #include <language/stats/sign.h>
 #include <libpspp/assertion.h>
@@ -169,7 +170,7 @@ parse_npar_tests (struct lexer *lexer, struct dataset *ds, struct cmd_npar_tests
               NOT_REACHED ();
             }
         }
-      else if (lex_match_hyphenated_word (lexer, "K-S") ||
+      else if (lex_match_hyphenated_word (lexer, "K-W") ||
 	       lex_match_hyphenated_word (lexer, "KRUSKAL-WALLIS"))
         {
           lex_match (lexer, '=');
@@ -753,8 +754,6 @@ parse_n_sample_related_test (struct lexer *lexer,
 			     struct pool *pool
 			     )
 {
-  union value val1, val2;
-
   if (!parse_variables_const_pool (lexer, pool,
 				   dict,
 				   &nst->vars, &nst->n_vars,
@@ -769,20 +768,20 @@ parse_n_sample_related_test (struct lexer *lexer,
   if ( ! lex_force_match (lexer, '('))
     return false;
 
-  value_init (&val1, var_get_width (nst->indep_var));
-  if ( ! parse_value (lexer, &val1, var_get_width (nst->indep_var)))
+  value_init (&nst->val1, var_get_width (nst->indep_var));
+  if ( ! parse_value (lexer, &nst->val1, var_get_width (nst->indep_var)))
     {
-      value_destroy (&val1, var_get_width (nst->indep_var));
+      value_destroy (&nst->val1, var_get_width (nst->indep_var));
       return false;
     }
 
   if ( ! lex_force_match (lexer, ','))
     return false;
 
-  value_init (&val2, var_get_width (nst->indep_var));
-  if ( ! parse_value (lexer, &val2, var_get_width (nst->indep_var)))
+  value_init (&nst->val2, var_get_width (nst->indep_var));
+  if ( ! parse_value (lexer, &nst->val2, var_get_width (nst->indep_var)))
     {
-      value_destroy (&val2, var_get_width (nst->indep_var));
+      value_destroy (&nst->val2, var_get_width (nst->indep_var));
       return false;
     }
 
@@ -847,7 +846,7 @@ npar_kruskal_wallis (struct lexer *lexer, struct dataset *ds,
 
   nt->insert_variables = n_sample_insert_variables;
 
-  //  nt->execute = kruskall_wallis_execute;
+  nt->execute = kruskal_wallis_execute;
 
   if (!parse_n_sample_related_test (lexer, dataset_dict (ds),
 				      tp, specs->pool) )
@@ -880,7 +879,6 @@ two_sample_insert_variables (const struct npar_test *test,
 			     struct const_hsh_table *var_hash)
 {
   int i;
-
   const struct two_sample_test *tst = UP_CAST (test, const struct two_sample_test, parent);
 
   for ( i = 0 ; i < tst->n_pairs ; ++i )
diff --git a/src/language/stats/npar.h b/src/language/stats/npar.h
index 082d396b..1c2605fe 100644
--- a/src/language/stats/npar.h
+++ b/src/language/stats/npar.h
@@ -20,9 +20,7 @@
 #include <stddef.h>
 #include <stdbool.h>
 #include <data/missing-values.h>
-
-#include <stddef.h>
-#include <data/missing-values.h>
+#include <data/value.h>
 
 typedef const struct variable *variable_pair[2];
 
@@ -67,6 +65,7 @@ struct n_sample_test
   const struct variable **vars;
   size_t n_vars;
 
+  union value val1, val2;
   const struct variable *indep_var;
 };
 
-- 
2.30.2