math: Make 'accumulate' a feature of order statistics, not all stats.
[pspp] / src / math / trimmed-mean.c
index 5e48689626817ddf59ad46363a17e1d0e763084d..dff256c0d3319cd4bd0892adbcf250219a0e7505 100644 (file)
@@ -33,66 +33,57 @@ acc (struct statistic *s, const struct ccase *cx UNUSED, double c, double cc, do
   struct trimmed_mean *tm = UP_CAST (s, struct trimmed_mean, parent.parent);
   struct order_stats *os = &tm->parent;
 
-  if ( cc > os->k[0].tc && cc < os->k[1].tc)
-      tm->sum += c * y;
+  if (cc > os->k[0].tc && cc <= os->k[1].tc)
+    tm->sum += c * y;
 
-  if ( tm->cyk1p1 == SYSMIS && cc >os->k[0].tc)
-      tm->cyk1p1 = c * y;
+  if (tm->cyk1p1 == SYSMIS && cc > os->k[0].tc)
+    tm->cyk1p1 = c * y;
 }
 
 static void
 destroy (struct statistic *s)
 {
   struct trimmed_mean *tm = UP_CAST (s, struct trimmed_mean, parent.parent);
-  struct order_stats *os = &tm->parent;
-  free (os->k);
   free (tm);
 }
 
 struct trimmed_mean *
 trimmed_mean_create (double W, double tail)
 {
-  struct trimmed_mean *tm = xzalloc (sizeof (*tm));
-  struct order_stats *os = &tm->parent;
-  struct statistic *stat = &os->parent;
-
-  os->n_k = 2;
-  os->k = xcalloc (sizeof (*os->k), 2);
-
   assert (tail >= 0);
   assert (tail <= 1);
 
-  os->k[0].tc = tail * W;
-  os->k[1].tc = W * (1 - tail);
-
-  stat->accumulate = acc;
-  stat->destroy = destroy;
-
-  tm->cyk1p1 = SYSMIS;
-  tm->w = W;
-  tm->tail = tail;
-
+  struct trimmed_mean *tm = xmalloc (sizeof *tm);
+  *tm = (struct trimmed_mean) {
+    .parent = {
+      .parent = {
+        .destroy = destroy,
+      },
+      .accumulate = acc,
+      .k = tm->k,
+      .n_k = 2,
+    },
+    .k[0] = { .tc = tail * W },
+    .k[1] = { .tc = W * (1 - tail) },
+    .cyk1p1 = SYSMIS,
+    .w = W,
+    .tail = tail,
+  };
   return tm;
 }
 
-
 double
 trimmed_mean_calculate (const struct trimmed_mean *tm)
 {
   const struct order_stats *os = (const struct order_stats *) tm;
 
-  assert (os->cc == tm->w);
-
   return
     (
-     (os->k[0].cc_p1 - os->k[0].tc) * os->k[0].y_p1
-     -
-     (os->k[1].cc - os->k[1].tc) * os->k[1].y_p1
+     (os->k[0].cc - os->k[0].tc) * os->k[0].y_p1
+     +
+      (tm->w - os->k[1].cc - os->k[0].tc) * os->k[1].y_p1
      +
-     tm->sum
-     -
-     tm->cyk1p1
-     )
-    /
-    ( (1.0 - 2 * tm->tail) * tm->w);
+      tm->sum
+)
+    / ((1.0 - tm->tail * 2) * tm->w);
 }