Aggregate: Fixed bug when attempting to create duplicate variable
[pspp] / src / language / stats / aggregate.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 1997-9, 2000, 2006, 2008, 2009, 2010 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <stdlib.h>
20
21 #include <data/any-writer.h>
22 #include <data/case.h>
23 #include <data/casegrouper.h>
24 #include <data/casereader.h>
25 #include <data/casewriter.h>
26 #include <data/dictionary.h>
27 #include <data/file-handle-def.h>
28 #include <data/format.h>
29 #include <data/procedure.h>
30 #include <data/settings.h>
31 #include <data/subcase.h>
32 #include <data/sys-file-writer.h>
33 #include <data/variable.h>
34 #include <language/command.h>
35 #include <language/data-io/file-handle.h>
36 #include <language/lexer/lexer.h>
37 #include <language/lexer/variable-parser.h>
38 #include <language/stats/sort-criteria.h>
39 #include <libpspp/assertion.h>
40 #include <libpspp/message.h>
41 #include <libpspp/misc.h>
42 #include <libpspp/pool.h>
43 #include <libpspp/str.h>
44 #include <math/moments.h>
45 #include <math/sort.h>
46 #include <math/statistic.h>
47 #include <math/percentiles.h>
48
49 #include "aggregate.h"
50
51 #include "minmax.h"
52 #include "xalloc.h"
53
54 #include "gettext.h"
55 #define _(msgid) gettext (msgid)
56 #define N_(msgid) msgid
57
58 /* Argument for AGGREGATE function. */
59 union agr_argument
60   {
61     double f;                           /* Numeric. */
62     char *c;                            /* Short or long string. */
63   };
64
65 /* Specifies how to make an aggregate variable. */
66 struct agr_var
67   {
68     struct agr_var *next;               /* Next in list. */
69
70     /* Collected during parsing. */
71     const struct variable *src; /* Source variable. */
72     struct variable *dest;      /* Target variable. */
73     int function;               /* Function. */
74     enum mv_class exclude;      /* Classes of missing values to exclude. */
75     union agr_argument arg[2];  /* Arguments. */
76
77     /* Accumulated during AGGREGATE execution. */
78     double dbl[3];
79     int int1, int2;
80     char *string;
81     bool saw_missing;
82     struct moments1 *moments;
83     double cc;
84
85     struct variable *subject;
86     struct variable *weight;
87     struct casewriter *writer;
88   };
89
90 /* Aggregation functions. */
91 enum
92   {
93     SUM, MEAN, MEDIAN, SD, MAX, MIN, PGT, PLT, PIN, POUT, FGT, FLT, FIN,
94     FOUT, N, NU, NMISS, NUMISS, FIRST, LAST,
95
96     FUNC = 0x1f, /* Function mask. */
97     FSTRING = 1<<5, /* String function bit. */
98   };
99
100
101 /* Attributes of aggregation functions. */
102 const struct agr_func agr_func_tab[] =
103   {
104     {"SUM",     N_("Sum of values"), AGR_SV_YES, 0, -1,          {FMT_F, 8, 2}},
105     {"MEAN",    N_("Mean average"), AGR_SV_YES, 0, -1,          {FMT_F, 8, 2}},
106     {"MEDIAN",  N_("Median average"), AGR_SV_YES, 0, -1,          {FMT_F, 8, 2}},
107     {"SD",      N_("Standard deviation"), AGR_SV_YES, 0, -1,          {FMT_F, 8, 2}},
108     {"MAX",     N_("Maximum value"), AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
109     {"MIN",     N_("Minimum value"), AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
110     {"PGT",     N_("Percentage greater than"), AGR_SV_YES, 1, VAL_NUMERIC, {FMT_F, 5, 1}},
111     {"PLT",     N_("Percentage less than"), AGR_SV_YES, 1, VAL_NUMERIC, {FMT_F, 5, 1}},
112     {"PIN",     N_("Percentage included in range"), AGR_SV_YES, 2, VAL_NUMERIC, {FMT_F, 5, 1}},
113     {"POUT",    N_("Percentage excluded from range"), AGR_SV_YES, 2, VAL_NUMERIC, {FMT_F, 5, 1}},
114     {"FGT",     N_("Fraction greater than"), AGR_SV_YES, 1, VAL_NUMERIC, {FMT_F, 5, 3}},
115     {"FLT",     N_("Fraction less than"), AGR_SV_YES, 1, VAL_NUMERIC, {FMT_F, 5, 3}},
116     {"FIN",     N_("Fraction included in range"), AGR_SV_YES, 2, VAL_NUMERIC, {FMT_F, 5, 3}},
117     {"FOUT",    N_("Fraction excluded from range"), AGR_SV_YES, 2, VAL_NUMERIC, {FMT_F, 5, 3}},
118     {"N",       N_("Number of cases"), AGR_SV_NO, 0, VAL_NUMERIC, {FMT_F, 7, 0}},
119     {"NU",      N_("Number of cases (unweighted)"), AGR_SV_OPT, 0, VAL_NUMERIC, {FMT_F, 7, 0}},
120     {"NMISS",   N_("Number of missing values"), AGR_SV_YES, 0, VAL_NUMERIC, {FMT_F, 7, 0}},
121     {"NUMISS",  N_("Number of missing values (unweighted)"), AGR_SV_YES, 0, VAL_NUMERIC, {FMT_F, 7, 0}},
122     {"FIRST",   N_("First non-missing value"), AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
123     {"LAST",    N_("Last non-missing value"), AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
124     {NULL,      NULL, AGR_SV_NO, 0, -1,          {-1, -1, -1}},
125   };
126
127 /* Missing value types. */
128 enum missing_treatment
129   {
130     ITEMWISE,           /* Missing values item by item. */
131     COLUMNWISE          /* Missing values column by column. */
132   };
133
134 /* An entire AGGREGATE procedure. */
135 struct agr_proc
136   {
137     /* Break variables. */
138     struct subcase sort;                /* Sort criteria (break variables). */
139     const struct variable **break_vars;       /* Break variables. */
140     size_t break_var_cnt;               /* Number of break variables. */
141
142     enum missing_treatment missing;     /* How to treat missing values. */
143     struct agr_var *agr_vars;           /* First aggregate variable. */
144     struct dictionary *dict;            /* Aggregate dictionary. */
145     const struct dictionary *src_dict;  /* Dict of the source */
146     int case_cnt;                       /* Counts aggregated cases. */
147
148     bool add_variables;                 /* True iff the aggregated variables should
149                                            be appended to the existing dictionary */
150   };
151
152 static void initialize_aggregate_info (struct agr_proc *);
153
154 static void accumulate_aggregate_info (struct agr_proc *,
155                                        const struct ccase *);
156 /* Prototypes. */
157 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
158                                        struct agr_proc *);
159 static void agr_destroy (struct agr_proc *);
160 static void dump_aggregate_info (const struct agr_proc *agr,
161                                  struct casewriter *output,
162                                  const struct ccase *break_case);
163 \f
164 /* Parsing. */
165
166 /* Parses and executes the AGGREGATE procedure. */
167 int
168 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
169 {
170   struct dictionary *dict = dataset_dict (ds);
171   struct agr_proc agr;
172   struct file_handle *out_file = NULL;
173   struct casereader *input = NULL, *group;
174   struct casegrouper *grouper;
175   struct casewriter *output = NULL;
176
177   bool copy_documents = false;
178   bool presorted = false;
179   bool saw_direction;
180   bool ok;
181
182   memset(&agr, 0 , sizeof (agr));
183   agr.missing = ITEMWISE;
184   agr.src_dict = dict;
185   subcase_init_empty (&agr.sort);
186
187   /* OUTFILE subcommand must be first. */
188   lex_match (lexer, '/');
189   if (!lex_force_match_id (lexer, "OUTFILE"))
190     goto error;
191   lex_match (lexer, '=');
192   if (!lex_match (lexer, '*'))
193     {
194       out_file = fh_parse (lexer, FH_REF_FILE | FH_REF_SCRATCH);
195       if (out_file == NULL)
196         goto error;
197     }
198
199   if (out_file == NULL && lex_match_id (lexer, "MODE"))
200     {
201       lex_match (lexer, '=');
202       if (lex_match_id (lexer, "ADDVARIABLES"))
203         {
204           agr.add_variables = true;
205
206           /* presorted is assumed in ADDVARIABLES mode */
207           presorted = true;
208         }
209       else if (lex_match_id (lexer, "REPLACE"))
210         {
211           agr.add_variables = false;
212         }
213       else
214         goto error;
215     }
216
217   if ( agr.add_variables )
218     agr.dict = dict_clone (dict);
219   else
220     agr.dict = dict_create ();    
221
222   dict_set_label (agr.dict, dict_get_label (dict));
223   dict_set_documents (agr.dict, dict_get_documents (dict));
224
225   /* Read most of the subcommands. */
226   for (;;)
227     {
228       lex_match (lexer, '/');
229
230       if (lex_match_id (lexer, "MISSING"))
231         {
232           lex_match (lexer, '=');
233           if (!lex_match_id (lexer, "COLUMNWISE"))
234             {
235               lex_error (lexer, _("while expecting COLUMNWISE"));
236               goto error;
237             }
238           agr.missing = COLUMNWISE;
239         }
240       else if (lex_match_id (lexer, "DOCUMENT"))
241         copy_documents = true;
242       else if (lex_match_id (lexer, "PRESORTED"))
243         presorted = true;
244       else if (lex_match_id (lexer, "BREAK"))
245         {
246           int i;
247
248           lex_match (lexer, '=');
249           if (!parse_sort_criteria (lexer, dict, &agr.sort, &agr.break_vars,
250                                     &saw_direction))
251             goto error;
252           agr.break_var_cnt = subcase_get_n_fields (&agr.sort);
253
254           if  (! agr.add_variables)
255             for (i = 0; i < agr.break_var_cnt; i++)
256               dict_clone_var_assert (agr.dict, agr.break_vars[i]);
257
258           /* BREAK must follow the options. */
259           break;
260         }
261       else
262         {
263           lex_error (lexer, _("expecting BREAK"));
264           goto error;
265         }
266     }
267   if (presorted && saw_direction)
268     msg (SW, _("When PRESORTED is specified, specifying sorting directions "
269                "with (A) or (D) has no effect.  Output data will be sorted "
270                "the same way as the input data."));
271
272   /* Read in the aggregate functions. */
273   lex_match (lexer, '/');
274   if (!parse_aggregate_functions (lexer, dict, &agr))
275     goto error;
276
277   /* Delete documents. */
278   if (!copy_documents)
279     dict_clear_documents (agr.dict);
280
281   /* Cancel SPLIT FILE. */
282   dict_set_split_vars (agr.dict, NULL, 0);
283
284   /* Initialize. */
285   agr.case_cnt = 0;
286
287   if (out_file == NULL)
288     {
289       /* The active file will be replaced by the aggregated data,
290          so TEMPORARY is moot. */
291       proc_cancel_temporary_transformations (ds);
292       proc_discard_output (ds);
293       output = autopaging_writer_create (dict_get_proto (agr.dict));
294     }
295   else
296     {
297       output = any_writer_open (out_file, agr.dict);
298       if (output == NULL)
299         goto error;
300     }
301
302   input = proc_open (ds);
303   if (!subcase_is_empty (&agr.sort) && !presorted)
304     {
305       input = sort_execute (input, &agr.sort);
306       subcase_clear (&agr.sort);
307     }
308
309   for (grouper = casegrouper_create_vars (input, agr.break_vars,
310                                           agr.break_var_cnt);
311        casegrouper_get_next_group (grouper, &group);
312        casereader_destroy (group))
313     {
314       struct casereader *placeholder = NULL;
315       struct ccase *c = casereader_peek (group, 0);
316
317       if (c == NULL)
318         {
319           casereader_destroy (group);
320           continue;
321         }
322
323       initialize_aggregate_info (&agr);
324
325       if ( agr.add_variables )
326         placeholder = casereader_clone (group);
327
328       {
329         struct ccase *cg;
330         for (; (cg = casereader_read (group)) != NULL; case_unref (cg))
331           accumulate_aggregate_info (&agr, cg);
332       }
333
334
335       if  (agr.add_variables)
336         {
337           struct ccase *cg;
338           for (; (cg = casereader_read (placeholder)) != NULL; case_unref (cg))
339             dump_aggregate_info (&agr, output, cg);
340
341           casereader_destroy (placeholder);
342         }
343       else
344         {
345           dump_aggregate_info (&agr, output, c);
346           case_unref (c);
347         }
348     }
349   if (!casegrouper_destroy (grouper))
350     goto error;
351
352   if (!proc_commit (ds))
353     {
354       input = NULL;
355       goto error;
356     }
357   input = NULL;
358
359   if (out_file == NULL)
360     {
361       struct casereader *next_input = casewriter_make_reader (output);
362       if (next_input == NULL)
363         goto error;
364
365       proc_set_active_file (ds, next_input, agr.dict);
366       agr.dict = NULL;
367     }
368   else
369     {
370       ok = casewriter_destroy (output);
371       output = NULL;
372       if (!ok)
373         goto error;
374     }
375
376   agr_destroy (&agr);
377   fh_unref (out_file);
378   return CMD_SUCCESS;
379
380 error:
381   if (input != NULL)
382     proc_commit (ds);
383   casewriter_destroy (output);
384   agr_destroy (&agr);
385   fh_unref (out_file);
386   return CMD_CASCADING_FAILURE;
387 }
388
389 /* Parse all the aggregate functions. */
390 static bool
391 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict,
392                            struct agr_proc *agr)
393 {
394   struct agr_var *tail; /* Tail of linked list starting at agr->vars. */
395
396   /* Parse everything. */
397   tail = NULL;
398   for (;;)
399     {
400       char **dest;
401       char **dest_label;
402       size_t n_dest;
403       struct string function_name;
404
405       enum mv_class exclude;
406       const struct agr_func *function;
407       int func_index;
408
409       union agr_argument arg[2];
410
411       const struct variable **src;
412       size_t n_src;
413
414       size_t i;
415
416       dest = NULL;
417       dest_label = NULL;
418       n_dest = 0;
419       src = NULL;
420       function = NULL;
421       n_src = 0;
422       arg[0].c = NULL;
423       arg[1].c = NULL;
424       ds_init_empty (&function_name);
425
426       /* Parse the list of target variables. */
427       while (!lex_match (lexer, '='))
428         {
429           size_t n_dest_prev = n_dest;
430
431           if (!parse_DATA_LIST_vars (lexer, &dest, &n_dest,
432                                      (PV_APPEND | PV_SINGLE | PV_NO_SCRATCH
433                                       | PV_NO_DUPLICATE)))
434             goto error;
435
436           /* Assign empty labels. */
437           {
438             int j;
439
440             dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
441             for (j = n_dest_prev; j < n_dest; j++)
442               dest_label[j] = NULL;
443           }
444
445
446
447           if (lex_token (lexer) == T_STRING)
448             {
449               struct string label;
450               ds_init_string (&label, lex_tokstr (lexer));
451
452               ds_truncate (&label, 255);
453               dest_label[n_dest - 1] = ds_xstrdup (&label);
454               lex_get (lexer);
455               ds_destroy (&label);
456             }
457         }
458
459       /* Get the name of the aggregation function. */
460       if (lex_token (lexer) != T_ID)
461         {
462           lex_error (lexer, _("expecting aggregation function"));
463           goto error;
464         }
465
466       exclude = MV_ANY;
467
468       ds_assign_string (&function_name, lex_tokstr (lexer));
469
470       ds_chomp (&function_name, '.');
471
472       if (lex_tokid(lexer)[strlen (lex_tokid (lexer)) - 1] == '.')
473         exclude = MV_SYSTEM;
474
475       for (function = agr_func_tab; function->name; function++)
476         if (!strcasecmp (function->name, ds_cstr (&function_name)))
477           break;
478       if (NULL == function->name)
479         {
480           msg (SE, _("Unknown aggregation function %s."),
481                ds_cstr (&function_name));
482           goto error;
483         }
484       ds_destroy (&function_name);
485       func_index = function - agr_func_tab;
486       lex_get (lexer);
487
488       /* Check for leading lparen. */
489       if (!lex_match (lexer, '('))
490         {
491           if (function->src_vars == AGR_SV_YES)
492             {
493               lex_error (lexer, _("expecting `('"));
494               goto error;
495             }
496         }
497       else
498         {
499           /* Parse list of source variables. */
500           {
501             int pv_opts = PV_NO_SCRATCH;
502
503             if (func_index == SUM || func_index == MEAN || func_index == SD)
504               pv_opts |= PV_NUMERIC;
505             else if (function->n_args)
506               pv_opts |= PV_SAME_TYPE;
507
508             if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
509               goto error;
510           }
511
512           /* Parse function arguments, for those functions that
513              require arguments. */
514           if (function->n_args != 0)
515             for (i = 0; i < function->n_args; i++)
516               {
517                 int type;
518
519                 lex_match (lexer, ',');
520                 if (lex_token (lexer) == T_STRING)
521                   {
522                     arg[i].c = ds_xstrdup (lex_tokstr (lexer));
523                     type = VAL_STRING;
524                   }
525                 else if (lex_is_number (lexer))
526                   {
527                     arg[i].f = lex_tokval (lexer);
528                     type = VAL_NUMERIC;
529                   }
530                 else
531                   {
532                     msg (SE, _("Missing argument %zu to %s."),
533                          i + 1, function->name);
534                     goto error;
535                   }
536
537                 lex_get (lexer);
538
539                 if (type != var_get_type (src[0]))
540                   {
541                     msg (SE, _("Arguments to %s must be of same type as "
542                                "source variables."),
543                          function->name);
544                     goto error;
545                   }
546               }
547
548           /* Trailing rparen. */
549           if (!lex_match (lexer, ')'))
550             {
551               lex_error (lexer, _("expecting `)'"));
552               goto error;
553             }
554
555           /* Now check that the number of source variables match
556              the number of target variables.  If we check earlier
557              than this, the user can get very misleading error
558              message, i.e. `AGGREGATE x=SUM(y t).' will get this
559              error message when a proper message would be more
560              like `unknown variable t'. */
561           if (n_src != n_dest)
562             {
563               msg (SE, _("Number of source variables (%zu) does not match "
564                          "number of target variables (%zu)."),
565                     n_src, n_dest);
566               goto error;
567             }
568
569           if ((func_index == PIN || func_index == POUT
570               || func_index == FIN || func_index == FOUT)
571               && (var_is_numeric (src[0])
572                   ? arg[0].f > arg[1].f
573                   : str_compare_rpad (arg[0].c, arg[1].c) > 0))
574             {
575               union agr_argument t = arg[0];
576               arg[0] = arg[1];
577               arg[1] = t;
578
579               msg (SW, _("The value arguments passed to the %s function "
580                          "are out-of-order.  They will be treated as if "
581                          "they had been specified in the correct order."),
582                    function->name);
583             }
584         }
585
586       /* Finally add these to the linked list of aggregation
587          variables. */
588       for (i = 0; i < n_dest; i++)
589         {
590           struct agr_var *v = xzalloc (sizeof *v);
591
592           /* Add variable to chain. */
593           if (agr->agr_vars != NULL)
594             tail->next = v;
595           else
596             agr->agr_vars = v;
597           tail = v;
598           tail->next = NULL;
599           v->moments = NULL;
600
601           /* Create the target variable in the aggregate
602              dictionary. */
603           {
604             struct variable *destvar;
605
606             v->function = func_index;
607
608             if (src)
609               {
610                 v->src = src[i];
611
612                 if (var_is_alpha (src[i]))
613                   {
614                     v->function |= FSTRING;
615                     v->string = xmalloc (var_get_width (src[i]));
616                   }
617
618                 if (function->alpha_type == VAL_STRING)
619                   destvar = dict_clone_var_as (agr->dict, v->src, dest[i]);
620                 else
621                   {
622                     assert (var_is_numeric (v->src)
623                             || function->alpha_type == VAL_NUMERIC);
624                     destvar = dict_create_var (agr->dict, dest[i], 0);
625                     if (destvar != NULL)
626                       {
627                         struct fmt_spec f;
628                         if ((func_index == N || func_index == NMISS)
629                             && dict_get_weight (dict) != NULL)
630                           f = fmt_for_output (FMT_F, 8, 2);
631                         else
632                           f = function->format;
633                         var_set_both_formats (destvar, &f);
634                       }
635                   }
636               } else {
637                 struct fmt_spec f;
638                 v->src = NULL;
639                 destvar = dict_create_var (agr->dict, dest[i], 0);
640                 if (destvar != NULL)
641                   {
642                     if ((func_index == N || func_index == NMISS)
643                         && dict_get_weight (dict) != NULL)
644                       f = fmt_for_output (FMT_F, 8, 2);
645                     else
646                       f = function->format;
647                     var_set_both_formats (destvar, &f);
648                   }
649             }
650
651             if (!destvar)
652               {
653                 msg (SE, _("Variable name %s is not unique within the "
654                            "aggregate file dictionary, which contains "
655                            "the aggregate variables and the break "
656                            "variables."),
657                      dest[i]);
658                 goto error;
659               }
660
661             free (dest[i]);
662             if (dest_label[i])
663               var_set_label (destvar, dest_label[i]);
664
665             v->dest = destvar;
666           }
667
668           v->exclude = exclude;
669
670           if (v->src != NULL)
671             {
672               int j;
673
674               if (var_is_numeric (v->src))
675                 for (j = 0; j < function->n_args; j++)
676                   v->arg[j].f = arg[j].f;
677               else
678                 for (j = 0; j < function->n_args; j++)
679                   v->arg[j].c = xstrdup (arg[j].c);
680             }
681         }
682
683       if (src != NULL && var_is_alpha (src[0]))
684         for (i = 0; i < function->n_args; i++)
685           {
686             free (arg[i].c);
687             arg[i].c = NULL;
688           }
689
690       free (src);
691       free (dest);
692       free (dest_label);
693
694       if (!lex_match (lexer, '/'))
695         {
696           if (lex_token (lexer) == '.')
697             return true;
698
699           lex_error (lexer, "expecting end of command");
700           return false;
701         }
702       continue;
703
704     error:
705       ds_destroy (&function_name);
706       for (i = 0; i < n_dest; i++)
707         {
708           free (dest[i]);
709           free (dest_label[i]);
710         }
711       free (dest);
712       free (dest_label);
713       free (arg[0].c);
714       free (arg[1].c);
715       if (src && n_src && var_is_alpha (src[0]))
716         for (i = 0; i < function->n_args; i++)
717           {
718             free (arg[i].c);
719             arg[i].c = NULL;
720           }
721       free (src);
722
723       return false;
724     }
725 }
726
727 /* Destroys AGR. */
728 static void
729 agr_destroy (struct agr_proc *agr)
730 {
731   struct agr_var *iter, *next;
732
733   subcase_destroy (&agr->sort);
734   free (agr->break_vars);
735   for (iter = agr->agr_vars; iter; iter = next)
736     {
737       next = iter->next;
738
739       if (iter->function & FSTRING)
740         {
741           size_t n_args;
742           size_t i;
743
744           n_args = agr_func_tab[iter->function & FUNC].n_args;
745           for (i = 0; i < n_args; i++)
746             free (iter->arg[i].c);
747           free (iter->string);
748         }
749       else if (iter->function == SD)
750         moments1_destroy (iter->moments);
751
752       dict_destroy_internal_var (iter->subject);
753       dict_destroy_internal_var (iter->weight);
754
755       free (iter);
756     }
757   if (agr->dict != NULL)
758     dict_destroy (agr->dict);
759 }
760 \f
761 /* Execution. */
762
763 /* Accumulates aggregation data from the case INPUT. */
764 static void
765 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
766 {
767   struct agr_var *iter;
768   double weight;
769   bool bad_warn = true;
770
771   weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
772
773   for (iter = agr->agr_vars; iter; iter = iter->next)
774     if (iter->src)
775       {
776         const union value *v = case_data (input, iter->src);
777         int src_width = var_get_width (iter->src);
778
779         if (var_is_value_missing (iter->src, v, iter->exclude))
780           {
781             switch (iter->function)
782               {
783               case NMISS:
784               case NMISS | FSTRING:
785                 iter->dbl[0] += weight;
786                 break;
787               case NUMISS:
788               case NUMISS | FSTRING:
789                 iter->int1++;
790                 break;
791               }
792             iter->saw_missing = true;
793             continue;
794           }
795
796         /* This is horrible.  There are too many possibilities. */
797         switch (iter->function)
798           {
799           case SUM:
800             iter->dbl[0] += v->f * weight;
801             iter->int1 = 1;
802             break;
803           case MEAN:
804             iter->dbl[0] += v->f * weight;
805             iter->dbl[1] += weight;
806             break;
807           case MEDIAN:
808             {
809               double wv ;
810               struct ccase *cout;
811
812               cout = case_create (casewriter_get_proto (iter->writer));
813
814               case_data_rw (cout, iter->subject)->f
815                 = case_data (input, iter->src)->f;
816
817               wv = dict_get_case_weight (agr->src_dict, input, NULL);
818
819               case_data_rw (cout, iter->weight)->f = wv;
820
821               iter->cc += wv;
822
823               casewriter_write (iter->writer, cout);
824             }
825             break;
826           case SD:
827             moments1_add (iter->moments, v->f, weight);
828             break;
829           case MAX:
830             iter->dbl[0] = MAX (iter->dbl[0], v->f);
831             iter->int1 = 1;
832             break;
833           case MAX | FSTRING:
834             if (memcmp (iter->string, value_str (v, src_width), src_width) < 0)
835               memcpy (iter->string, value_str (v, src_width), src_width);
836             iter->int1 = 1;
837             break;
838           case MIN:
839             iter->dbl[0] = MIN (iter->dbl[0], v->f);
840             iter->int1 = 1;
841             break;
842           case MIN | FSTRING:
843             if (memcmp (iter->string, value_str (v, src_width), src_width) > 0)
844               memcpy (iter->string, value_str (v, src_width), src_width);
845             iter->int1 = 1;
846             break;
847           case FGT:
848           case PGT:
849             if (v->f > iter->arg[0].f)
850               iter->dbl[0] += weight;
851             iter->dbl[1] += weight;
852             break;
853           case FGT | FSTRING:
854           case PGT | FSTRING:
855             if (memcmp (iter->arg[0].c,
856                         value_str (v, src_width), src_width) < 0)
857               iter->dbl[0] += weight;
858             iter->dbl[1] += weight;
859             break;
860           case FLT:
861           case PLT:
862             if (v->f < iter->arg[0].f)
863               iter->dbl[0] += weight;
864             iter->dbl[1] += weight;
865             break;
866           case FLT | FSTRING:
867           case PLT | FSTRING:
868             if (memcmp (iter->arg[0].c,
869                         value_str (v, src_width), src_width) > 0)
870               iter->dbl[0] += weight;
871             iter->dbl[1] += weight;
872             break;
873           case FIN:
874           case PIN:
875             if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
876               iter->dbl[0] += weight;
877             iter->dbl[1] += weight;
878             break;
879           case FIN | FSTRING:
880           case PIN | FSTRING:
881             if (memcmp (iter->arg[0].c,
882                         value_str (v, src_width), src_width) <= 0
883                 && memcmp (iter->arg[1].c,
884                            value_str (v, src_width), src_width) >= 0)
885               iter->dbl[0] += weight;
886             iter->dbl[1] += weight;
887             break;
888           case FOUT:
889           case POUT:
890             if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
891               iter->dbl[0] += weight;
892             iter->dbl[1] += weight;
893             break;
894           case FOUT | FSTRING:
895           case POUT | FSTRING:
896             if (memcmp (iter->arg[0].c,
897                         value_str (v, src_width), src_width) > 0
898                 || memcmp (iter->arg[1].c,
899                            value_str (v, src_width), src_width) < 0)
900               iter->dbl[0] += weight;
901             iter->dbl[1] += weight;
902             break;
903           case N:
904           case N | FSTRING:
905             iter->dbl[0] += weight;
906             break;
907           case NU:
908           case NU | FSTRING:
909             iter->int1++;
910             break;
911           case FIRST:
912             if (iter->int1 == 0)
913               {
914                 iter->dbl[0] = v->f;
915                 iter->int1 = 1;
916               }
917             break;
918           case FIRST | FSTRING:
919             if (iter->int1 == 0)
920               {
921                 memcpy (iter->string, value_str (v, src_width), src_width);
922                 iter->int1 = 1;
923               }
924             break;
925           case LAST:
926             iter->dbl[0] = v->f;
927             iter->int1 = 1;
928             break;
929           case LAST | FSTRING:
930             memcpy (iter->string, value_str (v, src_width), src_width);
931             iter->int1 = 1;
932             break;
933           case NMISS:
934           case NMISS | FSTRING:
935           case NUMISS:
936           case NUMISS | FSTRING:
937             /* Our value is not missing or it would have been
938                caught earlier.  Nothing to do. */
939             break;
940           default:
941             NOT_REACHED ();
942           }
943       } else {
944       switch (iter->function)
945         {
946         case N:
947           iter->dbl[0] += weight;
948           break;
949         case NU:
950           iter->int1++;
951           break;
952         default:
953           NOT_REACHED ();
954         }
955     }
956 }
957
958 /* Writes an aggregated record to OUTPUT. */
959 static void
960 dump_aggregate_info (const struct agr_proc *agr, struct casewriter *output, const struct ccase *break_case)
961 {
962   struct ccase *c = case_create (dict_get_proto (agr->dict));
963
964   if ( agr->add_variables)
965     {
966       case_copy (c, 0, break_case, 0, dict_get_var_cnt (agr->src_dict));
967     }
968   else
969     {
970       int value_idx = 0;
971       int i;
972
973       for (i = 0; i < agr->break_var_cnt; i++)
974         {
975           const struct variable *v = agr->break_vars[i];
976           value_copy (case_data_rw_idx (c, value_idx),
977                       case_data (break_case, v),
978                       var_get_width (v));
979           value_idx++;
980         }
981     }
982
983   {
984     struct agr_var *i;
985
986     for (i = agr->agr_vars; i; i = i->next)
987       {
988         union value *v = case_data_rw (c, i->dest);
989         int width = var_get_width (i->dest);
990
991         if (agr->missing == COLUMNWISE && i->saw_missing
992             && (i->function & FUNC) != N && (i->function & FUNC) != NU
993             && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
994           {
995             value_set_missing (v, width);
996             casewriter_destroy (i->writer);
997             continue;
998           }
999
1000         switch (i->function)
1001           {
1002           case SUM:
1003             v->f = i->int1 ? i->dbl[0] : SYSMIS;
1004             break;
1005           case MEAN:
1006             v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
1007             break;
1008           case MEDIAN:
1009             {
1010               if ( i->writer)
1011                 {
1012                   struct percentile *median = percentile_create (0.5, i->cc);
1013                   struct order_stats *os = &median->parent;
1014                   struct casereader *sorted_reader = casewriter_make_reader (i->writer);
1015                   i->writer = NULL;
1016
1017                   order_stats_accumulate (&os, 1,
1018                                           sorted_reader,
1019                                           i->weight,
1020                                           i->subject,
1021                                           i->exclude);
1022                   i->dbl[0] = percentile_calculate (median, PC_HAVERAGE);
1023                   statistic_destroy (&median->parent.parent);
1024                 }
1025               v->f = i->dbl[0];
1026             }
1027             break;
1028           case SD:
1029             {
1030               double variance;
1031
1032               /* FIXME: we should use two passes. */
1033               moments1_calculate (i->moments, NULL, NULL, &variance,
1034                                  NULL, NULL);
1035               if (variance != SYSMIS)
1036                 v->f = sqrt (variance);
1037               else
1038                 v->f = SYSMIS;
1039             }
1040             break;
1041           case MAX:
1042           case MIN:
1043             v->f = i->int1 ? i->dbl[0] : SYSMIS;
1044             break;
1045           case MAX | FSTRING:
1046           case MIN | FSTRING:
1047             if (i->int1)
1048               memcpy (value_str_rw (v, width), i->string, width);
1049             else
1050               value_set_missing (v, width);
1051             break;
1052           case FGT:
1053           case FGT | FSTRING:
1054           case FLT:
1055           case FLT | FSTRING:
1056           case FIN:
1057           case FIN | FSTRING:
1058           case FOUT:
1059           case FOUT | FSTRING:
1060             v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
1061             break;
1062           case PGT:
1063           case PGT | FSTRING:
1064           case PLT:
1065           case PLT | FSTRING:
1066           case PIN:
1067           case PIN | FSTRING:
1068           case POUT:
1069           case POUT | FSTRING:
1070             v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
1071             break;
1072           case N:
1073           case N | FSTRING:
1074               v->f = i->dbl[0];
1075             break;
1076           case NU:
1077           case NU | FSTRING:
1078             v->f = i->int1;
1079             break;
1080           case FIRST:
1081           case LAST:
1082             v->f = i->int1 ? i->dbl[0] : SYSMIS;
1083             break;
1084           case FIRST | FSTRING:
1085           case LAST | FSTRING:
1086             if (i->int1)
1087               memcpy (value_str_rw (v, width), i->string, width);
1088             else
1089               value_set_missing (v, width);
1090             break;
1091           case NMISS:
1092           case NMISS | FSTRING:
1093             v->f = i->dbl[0];
1094             break;
1095           case NUMISS:
1096           case NUMISS | FSTRING:
1097             v->f = i->int1;
1098             break;
1099           default:
1100             NOT_REACHED ();
1101           }
1102       }
1103   }
1104
1105   casewriter_write (output, c);
1106 }
1107
1108 /* Resets the state for all the aggregate functions. */
1109 static void
1110 initialize_aggregate_info (struct agr_proc *agr)
1111 {
1112   struct agr_var *iter;
1113
1114   for (iter = agr->agr_vars; iter; iter = iter->next)
1115     {
1116       iter->saw_missing = false;
1117       iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1118       iter->int1 = iter->int2 = 0;
1119       switch (iter->function)
1120         {
1121         case MIN:
1122           iter->dbl[0] = DBL_MAX;
1123           break;
1124         case MIN | FSTRING:
1125           memset (iter->string, 255, var_get_width (iter->src));
1126           break;
1127         case MAX:
1128           iter->dbl[0] = -DBL_MAX;
1129           break;
1130         case MAX | FSTRING:
1131           memset (iter->string, 0, var_get_width (iter->src));
1132           break;
1133         case MEDIAN:
1134           {
1135             struct caseproto *proto;
1136             struct subcase ordering;
1137
1138             proto = caseproto_create ();
1139             proto = caseproto_add_width (proto, 0);
1140             proto = caseproto_add_width (proto, 0);
1141
1142             if ( ! iter->subject)
1143               iter->subject = dict_create_internal_var (0, 0);
1144
1145             if ( ! iter->weight)
1146               iter->weight = dict_create_internal_var (1, 0);
1147
1148             subcase_init_var (&ordering, iter->subject, SC_ASCEND);
1149             iter->writer = sort_create_writer (&ordering, proto);
1150             subcase_destroy (&ordering);
1151             caseproto_unref (proto);
1152
1153             iter->cc = 0;
1154           }
1155           break;
1156         case SD:
1157           if (iter->moments == NULL)
1158             iter->moments = moments1_create (MOMENT_VARIANCE);
1159           else
1160             moments1_clear (iter->moments);
1161           break;
1162         default:
1163           break;
1164         }
1165     }
1166 }