8ee7060f6319a55f1899c0b819959dcf4408db0e
[pspp] / src / language / stats / aggregate.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 1997-9, 2000, 2006, 2008, 2009, 2010, 2011, 2012, 2014 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/aggregate.h"
20
21 #include <stdlib.h>
22
23 #include "data/any-writer.h"
24 #include "data/case.h"
25 #include "data/casegrouper.h"
26 #include "data/casereader.h"
27 #include "data/casewriter.h"
28 #include "data/dataset.h"
29 #include "data/dictionary.h"
30 #include "data/file-handle-def.h"
31 #include "data/format.h"
32 #include "data/settings.h"
33 #include "data/subcase.h"
34 #include "data/sys-file-writer.h"
35 #include "data/variable.h"
36 #include "language/command.h"
37 #include "language/data-io/file-handle.h"
38 #include "language/lexer/lexer.h"
39 #include "language/lexer/variable-parser.h"
40 #include "language/stats/sort-criteria.h"
41 #include "libpspp/assertion.h"
42 #include "libpspp/i18n.h"
43 #include "libpspp/message.h"
44 #include "libpspp/misc.h"
45 #include "libpspp/pool.h"
46 #include "libpspp/str.h"
47 #include "math/moments.h"
48 #include "math/percentiles.h"
49 #include "math/sort.h"
50 #include "math/statistic.h"
51
52 #include "gl/c-strcase.h"
53 #include "gl/minmax.h"
54 #include "gl/xalloc.h"
55
56 #include "gettext.h"
57 #define _(msgid) gettext (msgid)
58 #define N_(msgid) msgid
59
60 /* Argument for AGGREGATE function. */
61 union agr_argument
62   {
63     double f;                           /* Numeric. */
64     char *c;                            /* Short or long string. */
65   };
66
67 /* Specifies how to make an aggregate variable. */
68 struct agr_var
69   {
70     struct agr_var *next;               /* Next in list. */
71
72     /* Collected during parsing. */
73     const struct variable *src; /* Source variable. */
74     struct variable *dest;      /* Target variable. */
75     int function;               /* Function. */
76     enum mv_class exclude;      /* Classes of missing values to exclude. */
77     union agr_argument arg[2];  /* Arguments. */
78
79     /* Accumulated during AGGREGATE execution. */
80     double dbl[3];
81     int int1;
82     char *string;
83     bool saw_missing;
84     struct moments1 *moments;
85     double cc;
86
87     struct variable *subject;
88     struct variable *weight;
89     struct casewriter *writer;
90   };
91
92
93 /* Attributes of aggregation functions. */
94 const struct agr_func agr_func_tab[] =
95   {
96     {"SUM",     N_("Sum of values"),                         AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
97     {"MEAN",    N_("Mean average"),                          AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
98     {"MEDIAN",  N_("Median average"),                        AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
99     {"SD",      N_("Standard deviation"),                    AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
100     {"MAX",     N_("Maximum value"),                         AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
101     {"MIN",     N_("Minimum value"),                         AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
102     {"PGT",     N_("Percentage greater than"),               AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
103     {"PLT",     N_("Percentage less than"),                  AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
104     {"PIN",     N_("Percentage included in range"),          AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
105     {"POUT",    N_("Percentage excluded from range"),        AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
106     {"FGT",     N_("Fraction greater than"),                 AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
107     {"FLT",     N_("Fraction less than"),                    AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
108     {"FIN",     N_("Fraction included in range"),            AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
109     {"FOUT",    N_("Fraction excluded from range"),          AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
110     {"N",       N_("Number of cases"),                       AGR_SV_NO,  0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
111     {"NU",      N_("Number of cases (unweighted)"),          AGR_SV_OPT, 0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
112     {"NMISS",   N_("Number of missing values"),              AGR_SV_YES, 0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
113     {"NUMISS",  N_("Number of missing values (unweighted)"), AGR_SV_YES, 0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
114     {"FIRST",   N_("First non-missing value"),               AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
115     {"LAST",    N_("Last non-missing value"),                AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
116     {NULL,      NULL,                                        AGR_SV_NO,  0, -1,          {-1, -1, -1}},
117   };
118
119 /* Missing value types. */
120 enum missing_treatment
121   {
122     ITEMWISE,           /* Missing values item by item. */
123     COLUMNWISE          /* Missing values column by column. */
124   };
125
126 /* An entire AGGREGATE procedure. */
127 struct agr_proc
128   {
129     /* Break variables. */
130     struct subcase sort;                /* Sort criteria (break variables). */
131     const struct variable **break_vars;       /* Break variables. */
132     size_t break_n_vars;                /* Number of break variables. */
133
134     enum missing_treatment missing;     /* How to treat missing values. */
135     struct agr_var *agr_vars;           /* First aggregate variable. */
136     struct dictionary *dict;            /* Aggregate dictionary. */
137     const struct dictionary *src_dict;  /* Dict of the source */
138     int n_cases;                        /* Counts aggregated cases. */
139
140     bool add_variables;                 /* True iff the aggregated variables should
141                                            be appended to the existing dictionary */
142   };
143
144 static void initialize_aggregate_info (struct agr_proc *);
145
146 static void accumulate_aggregate_info (struct agr_proc *,
147                                        const struct ccase *);
148 /* Prototypes. */
149 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
150                                        struct agr_proc *);
151 static void agr_destroy (struct agr_proc *);
152 static void dump_aggregate_info (const struct agr_proc *agr,
153                                  struct casewriter *output,
154                                  const struct ccase *break_case);
155 \f
156 /* Parsing. */
157
158 /* Parses and executes the AGGREGATE procedure. */
159 int
160 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
161 {
162   struct dictionary *dict = dataset_dict (ds);
163   struct agr_proc agr;
164   struct file_handle *out_file = NULL;
165   struct casereader *input = NULL, *group;
166   struct casegrouper *grouper;
167   struct casewriter *output = NULL;
168
169   bool copy_documents = false;
170   bool presorted = false;
171   bool saw_direction;
172   bool ok;
173
174   memset(&agr, 0 , sizeof (agr));
175   agr.missing = ITEMWISE;
176   agr.src_dict = dict;
177   subcase_init_empty (&agr.sort);
178
179   /* OUTFILE subcommand must be first. */
180   lex_match (lexer, T_SLASH);
181   if (!lex_force_match_id (lexer, "OUTFILE"))
182     goto error;
183   lex_match (lexer, T_EQUALS);
184   if (!lex_match (lexer, T_ASTERISK))
185     {
186       out_file = fh_parse (lexer, FH_REF_FILE, dataset_session (ds));
187       if (out_file == NULL)
188         goto error;
189     }
190
191   if (out_file == NULL && lex_match_id (lexer, "MODE"))
192     {
193       lex_match (lexer, T_EQUALS);
194       if (lex_match_id (lexer, "ADDVARIABLES"))
195         {
196           agr.add_variables = true;
197
198           /* presorted is assumed in ADDVARIABLES mode */
199           presorted = true;
200         }
201       else if (lex_match_id (lexer, "REPLACE"))
202         {
203           agr.add_variables = false;
204         }
205       else
206         goto error;
207     }
208
209   if (agr.add_variables)
210     agr.dict = dict_clone (dict);
211   else
212     agr.dict = dict_create (dict_get_encoding (dict));
213
214   dict_set_label (agr.dict, dict_get_label (dict));
215   dict_set_documents (agr.dict, dict_get_documents (dict));
216
217   /* Read most of the subcommands. */
218   for (;;)
219     {
220       lex_match (lexer, T_SLASH);
221
222       if (lex_match_id (lexer, "MISSING"))
223         {
224           lex_match (lexer, T_EQUALS);
225           if (!lex_match_id (lexer, "COLUMNWISE"))
226             {
227               lex_error_expecting (lexer, "COLUMNWISE");
228               goto error;
229             }
230           agr.missing = COLUMNWISE;
231         }
232       else if (lex_match_id (lexer, "DOCUMENT"))
233         copy_documents = true;
234       else if (lex_match_id (lexer, "PRESORTED"))
235         presorted = true;
236       else if (lex_force_match_id (lexer, "BREAK"))
237         {
238           int i;
239
240           lex_match (lexer, T_EQUALS);
241           if (!parse_sort_criteria (lexer, dict, &agr.sort, &agr.break_vars,
242                                     &saw_direction))
243             goto error;
244           agr.break_n_vars = subcase_get_n_fields (&agr.sort);
245
246           if  (! agr.add_variables)
247             for (i = 0; i < agr.break_n_vars; i++)
248               dict_clone_var_assert (agr.dict, agr.break_vars[i]);
249
250           /* BREAK must follow the options. */
251           break;
252         }
253       else
254         goto error;
255
256     }
257   if (presorted && saw_direction)
258     msg (SW, _("When PRESORTED is specified, specifying sorting directions "
259                "with (A) or (D) has no effect.  Output data will be sorted "
260                "the same way as the input data."));
261
262   /* Read in the aggregate functions. */
263   lex_match (lexer, T_SLASH);
264   if (!parse_aggregate_functions (lexer, dict, &agr))
265     goto error;
266
267   /* Delete documents. */
268   if (!copy_documents)
269     dict_clear_documents (agr.dict);
270
271   /* Cancel SPLIT FILE. */
272   dict_clear_split_vars (agr.dict);
273
274   /* Initialize. */
275   agr.n_cases = 0;
276
277   if (out_file == NULL)
278     {
279       /* The active dataset will be replaced by the aggregated data,
280          so TEMPORARY is moot. */
281       proc_cancel_temporary_transformations (ds);
282       proc_discard_output (ds);
283       output = autopaging_writer_create (dict_get_proto (agr.dict));
284     }
285   else
286     {
287       output = any_writer_open (out_file, agr.dict);
288       if (output == NULL)
289         goto error;
290     }
291
292   input = proc_open (ds);
293   if (!subcase_is_empty (&agr.sort) && !presorted)
294     {
295       input = sort_execute (input, &agr.sort);
296       subcase_clear (&agr.sort);
297     }
298
299   for (grouper = casegrouper_create_vars (input, agr.break_vars,
300                                           agr.break_n_vars);
301        casegrouper_get_next_group (grouper, &group);
302        casereader_destroy (group))
303     {
304       struct casereader *placeholder = NULL;
305       struct ccase *c = casereader_peek (group, 0);
306
307       if (c == NULL)
308         {
309           casereader_destroy (group);
310           continue;
311         }
312
313       initialize_aggregate_info (&agr);
314
315       if (agr.add_variables)
316         placeholder = casereader_clone (group);
317
318       {
319         struct ccase *cg;
320         for (; (cg = casereader_read (group)) != NULL; case_unref (cg))
321           accumulate_aggregate_info (&agr, cg);
322       }
323
324
325       if  (agr.add_variables)
326         {
327           struct ccase *cg;
328           for (; (cg = casereader_read (placeholder)) != NULL; case_unref (cg))
329             dump_aggregate_info (&agr, output, cg);
330
331           casereader_destroy (placeholder);
332         }
333       else
334         {
335           dump_aggregate_info (&agr, output, c);
336         }
337       case_unref (c);
338     }
339   if (!casegrouper_destroy (grouper))
340     goto error;
341
342   if (!proc_commit (ds))
343     {
344       input = NULL;
345       goto error;
346     }
347   input = NULL;
348
349   if (out_file == NULL)
350     {
351       struct casereader *next_input = casewriter_make_reader (output);
352       if (next_input == NULL)
353         goto error;
354
355       dataset_set_dict (ds, agr.dict);
356       dataset_set_source (ds, next_input);
357       agr.dict = NULL;
358     }
359   else
360     {
361       ok = casewriter_destroy (output);
362       output = NULL;
363       if (!ok)
364         goto error;
365     }
366
367   agr_destroy (&agr);
368   fh_unref (out_file);
369   return CMD_SUCCESS;
370
371 error:
372   if (input != NULL)
373     proc_commit (ds);
374   casewriter_destroy (output);
375   agr_destroy (&agr);
376   fh_unref (out_file);
377   return CMD_CASCADING_FAILURE;
378 }
379
380 /* Parse all the aggregate functions. */
381 static bool
382 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict,
383                            struct agr_proc *agr)
384 {
385   struct agr_var *tail; /* Tail of linked list starting at agr->vars. */
386
387   /* Parse everything. */
388   tail = NULL;
389   for (;;)
390     {
391       char **dest;
392       char **dest_label;
393       size_t n_dest;
394       struct string function_name;
395
396       enum mv_class exclude;
397       const struct agr_func *function;
398       int func_index;
399
400       union agr_argument arg[2];
401
402       const struct variable **src;
403       size_t n_src;
404
405       size_t i;
406
407       dest = NULL;
408       dest_label = NULL;
409       n_dest = 0;
410       src = NULL;
411       function = NULL;
412       n_src = 0;
413       arg[0].c = NULL;
414       arg[1].c = NULL;
415       ds_init_empty (&function_name);
416
417       /* Parse the list of target variables. */
418       while (!lex_match (lexer, T_EQUALS))
419         {
420           size_t n_dest_prev = n_dest;
421
422           if (!parse_DATA_LIST_vars (lexer, dict, &dest, &n_dest,
423                                      (PV_APPEND | PV_SINGLE | PV_NO_SCRATCH
424                                       | PV_NO_DUPLICATE)))
425             goto error;
426
427           /* Assign empty labels. */
428           {
429             int j;
430
431             dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
432             for (j = n_dest_prev; j < n_dest; j++)
433               dest_label[j] = NULL;
434           }
435
436
437
438           if (lex_is_string (lexer))
439             {
440               dest_label[n_dest - 1] = xstrdup (lex_tokcstr (lexer));
441               lex_get (lexer);
442             }
443         }
444
445       /* Get the name of the aggregation function. */
446       if (lex_token (lexer) != T_ID)
447         {
448           lex_error (lexer, _("Syntax error expecting aggregation function."));
449           goto error;
450         }
451
452       ds_assign_substring (&function_name, lex_tokss (lexer));
453       exclude = ds_chomp_byte (&function_name, '.') ? MV_SYSTEM : MV_ANY;
454
455       for (function = agr_func_tab; function->name; function++)
456         if (!c_strcasecmp (function->name, ds_cstr (&function_name)))
457           break;
458       if (NULL == function->name)
459         {
460           lex_error (lexer, _("Unknown aggregation function %s."),
461                      ds_cstr (&function_name));
462           goto error;
463         }
464       ds_destroy (&function_name);
465       func_index = function - agr_func_tab;
466       lex_get (lexer);
467
468       /* Check for leading lparen. */
469       if (!lex_match (lexer, T_LPAREN))
470         {
471           if (function->src_vars == AGR_SV_YES)
472             {
473               goto error;
474             }
475         }
476       else
477         {
478           /* Parse list of source variables. */
479           int pv_opts = PV_NO_SCRATCH;
480           if (func_index == SUM || func_index == MEAN || func_index == SD)
481             pv_opts |= PV_NUMERIC;
482           else if (function->n_args)
483             pv_opts |= PV_SAME_TYPE;
484
485           int vars_start_ofs = lex_ofs (lexer);
486           if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
487             goto error;
488           int vars_end_ofs = lex_ofs (lexer) - 1;
489
490           /* Parse function arguments, for those functions that
491              require arguments. */
492           int args_start_ofs = 0;
493           if (function->n_args != 0)
494             for (i = 0; i < function->n_args; i++)
495               {
496                 int type;
497
498                 lex_match (lexer, T_COMMA);
499                 if (i == 0)
500                   args_start_ofs = lex_ofs (lexer);
501                 if (lex_is_string (lexer))
502                   {
503                     arg[i].c = recode_string (dict_get_encoding (agr->dict),
504                                               "UTF-8", lex_tokcstr (lexer),
505                                               -1);
506                     type = VAL_STRING;
507                   }
508                 else if (lex_is_number (lexer))
509                   {
510                     arg[i].f = lex_tokval (lexer);
511                     type = VAL_NUMERIC;
512                   }
513                 else
514                   {
515                     lex_error (lexer, _("Missing argument %zu to %s."),
516                                i + 1, function->name);
517                     goto error;
518                   }
519                 if (type != var_get_type (src[0]))
520                   {
521                     msg (SE, _("Arguments to %s must be of same type as "
522                                "source variables."),
523                          function->name);
524                     if (type == VAL_NUMERIC)
525                       {
526                         lex_next_msg (lexer, SN, 0, 0,
527                                       _("The argument is numeric."));
528                         lex_ofs_msg (lexer, SN, vars_start_ofs, vars_end_ofs,
529                                      _("The variables have string type."));
530                       }
531                     else
532                       {
533                         lex_next_msg (lexer, SN, 0, 0,
534                                       _("The argument is a string."));
535                         lex_ofs_msg (lexer, SN, vars_start_ofs, vars_end_ofs,
536                                      _("The variables are numeric."));
537                       }
538                     goto error;
539                   }
540
541                 lex_get (lexer);
542               }
543           int args_end_ofs = lex_ofs (lexer) - 1;
544
545           /* Trailing rparen. */
546           if (!lex_force_match (lexer, T_RPAREN))
547             goto error;
548
549           /* Now check that the number of source variables match
550              the number of target variables.  If we check earlier
551              than this, the user can get very misleading error
552              message, i.e. `AGGREGATE x=SUM(y t).' will get this
553              error message when a proper message would be more
554              like `unknown variable t'. */
555           if (n_src != n_dest)
556             {
557               msg (SE, _("Number of source variables (%zu) does not match "
558                          "number of target variables (%zu)."),
559                     n_src, n_dest);
560               goto error;
561             }
562
563           if ((func_index == PIN || func_index == POUT
564               || func_index == FIN || func_index == FOUT)
565               && (var_is_numeric (src[0])
566                   ? arg[0].f > arg[1].f
567                   : str_compare_rpad (arg[0].c, arg[1].c) > 0))
568             {
569               union agr_argument t = arg[0];
570               arg[0] = arg[1];
571               arg[1] = t;
572
573               lex_ofs_msg (lexer, SW, args_start_ofs, args_end_ofs,
574                            _("The value arguments passed to the %s function "
575                              "are out of order.  They will be treated as if "
576                              "they had been specified in the correct order."),
577                            function->name);
578             }
579         }
580
581       /* Finally add these to the linked list of aggregation
582          variables. */
583       for (i = 0; i < n_dest; i++)
584         {
585           struct agr_var *v = XZALLOC (struct agr_var);
586
587           /* Add variable to chain. */
588           if (agr->agr_vars != NULL)
589             tail->next = v;
590           else
591             agr->agr_vars = v;
592           tail = v;
593           tail->next = NULL;
594           v->moments = NULL;
595
596           /* Create the target variable in the aggregate
597              dictionary. */
598           {
599             struct variable *destvar;
600
601             v->function = func_index;
602
603             if (src)
604               {
605                 v->src = src[i];
606
607                 if (var_is_alpha (src[i]))
608                   {
609                     v->function |= FSTRING;
610                     v->string = xmalloc (var_get_width (src[i]));
611                   }
612
613                 if (function->alpha_type == VAL_STRING)
614                   destvar = dict_clone_var_as (agr->dict, v->src, dest[i]);
615                 else
616                   {
617                     assert (var_is_numeric (v->src)
618                             || function->alpha_type == VAL_NUMERIC);
619                     destvar = dict_create_var (agr->dict, dest[i], 0);
620                     if (destvar != NULL)
621                       {
622                         struct fmt_spec f;
623                         if ((func_index == N || func_index == NMISS)
624                             && dict_get_weight (dict) != NULL)
625                           f = fmt_for_output (FMT_F, 8, 2);
626                         else
627                           f = function->format;
628                         var_set_both_formats (destvar, &f);
629                       }
630                   }
631               } else {
632                 struct fmt_spec f;
633                 v->src = NULL;
634                 destvar = dict_create_var (agr->dict, dest[i], 0);
635                 if (destvar != NULL)
636                   {
637                     if ((func_index == N || func_index == NMISS)
638                         && dict_get_weight (dict) != NULL)
639                       f = fmt_for_output (FMT_F, 8, 2);
640                     else
641                       f = function->format;
642                     var_set_both_formats (destvar, &f);
643                   }
644             }
645
646             if (!destvar)
647               {
648                 msg (SE, _("Variable name %s is not unique within the "
649                            "aggregate file dictionary, which contains "
650                            "the aggregate variables and the break "
651                            "variables."),
652                      dest[i]);
653                 goto error;
654               }
655
656             free (dest[i]);
657             if (dest_label[i])
658               var_set_label (destvar, dest_label[i]);
659
660             v->dest = destvar;
661           }
662
663           v->exclude = exclude;
664
665           if (v->src != NULL)
666             {
667               int j;
668
669               if (var_is_numeric (v->src))
670                 for (j = 0; j < function->n_args; j++)
671                   v->arg[j].f = arg[j].f;
672               else
673                 for (j = 0; j < function->n_args; j++)
674                   v->arg[j].c = xstrdup (arg[j].c);
675             }
676         }
677
678       if (src != NULL && var_is_alpha (src[0]))
679         for (i = 0; i < function->n_args; i++)
680           {
681             free (arg[i].c);
682             arg[i].c = NULL;
683           }
684
685       free (src);
686       free (dest);
687       free (dest_label);
688
689       if (!lex_match (lexer, T_SLASH))
690         {
691           if (lex_token (lexer) == T_ENDCMD)
692             return true;
693
694           lex_error (lexer, "Syntax error expecting end of command.");
695           return false;
696         }
697       continue;
698
699     error:
700       ds_destroy (&function_name);
701       for (i = 0; i < n_dest; i++)
702         {
703           free (dest[i]);
704           free (dest_label[i]);
705         }
706       free (dest);
707       free (dest_label);
708       free (arg[0].c);
709       free (arg[1].c);
710       if (src && n_src && var_is_alpha (src[0]))
711         for (i = 0; i < function->n_args; i++)
712           {
713             free (arg[i].c);
714             arg[i].c = NULL;
715           }
716       free (src);
717
718       return false;
719     }
720 }
721
722 /* Destroys AGR. */
723 static void
724 agr_destroy (struct agr_proc *agr)
725 {
726   struct agr_var *iter, *next;
727
728   subcase_uninit (&agr->sort);
729   free (agr->break_vars);
730   for (iter = agr->agr_vars; iter; iter = next)
731     {
732       next = iter->next;
733
734       if (iter->function & FSTRING)
735         {
736           size_t n_args;
737           size_t i;
738
739           n_args = agr_func_tab[iter->function & FUNC].n_args;
740           for (i = 0; i < n_args; i++)
741             free (iter->arg[i].c);
742           free (iter->string);
743         }
744       else if (iter->function == SD)
745         moments1_destroy (iter->moments);
746
747       dict_destroy_internal_var (iter->subject);
748       dict_destroy_internal_var (iter->weight);
749
750       free (iter);
751     }
752   if (agr->dict != NULL)
753     dict_unref (agr->dict);
754 }
755 \f
756 /* Execution. */
757
758 /* Accumulates aggregation data from the case INPUT. */
759 static void
760 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
761 {
762   struct agr_var *iter;
763   double weight;
764   bool bad_warn = true;
765
766   weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
767
768   for (iter = agr->agr_vars; iter; iter = iter->next)
769     if (iter->src)
770       {
771         const union value *v = case_data (input, iter->src);
772         int src_width = var_get_width (iter->src);
773
774         if (var_is_value_missing (iter->src, v) & iter->exclude)
775           {
776             switch (iter->function)
777               {
778               case NMISS:
779               case NMISS | FSTRING:
780                 iter->dbl[0] += weight;
781                 break;
782               case NUMISS:
783               case NUMISS | FSTRING:
784                 iter->int1++;
785                 break;
786               }
787             iter->saw_missing = true;
788             continue;
789           }
790
791         /* This is horrible.  There are too many possibilities. */
792         switch (iter->function)
793           {
794           case SUM:
795             iter->dbl[0] += v->f * weight;
796             iter->int1 = 1;
797             break;
798           case MEAN:
799             iter->dbl[0] += v->f * weight;
800             iter->dbl[1] += weight;
801             break;
802           case MEDIAN:
803             {
804               double wv ;
805               struct ccase *cout;
806
807               cout = case_create (casewriter_get_proto (iter->writer));
808
809               *case_num_rw (cout, iter->subject) = case_num (input, iter->src);
810
811               wv = dict_get_case_weight (agr->src_dict, input, NULL);
812
813               *case_num_rw (cout, iter->weight) = wv;
814
815               iter->cc += wv;
816
817               casewriter_write (iter->writer, cout);
818             }
819             break;
820           case SD:
821             moments1_add (iter->moments, v->f, weight);
822             break;
823           case MAX:
824             iter->dbl[0] = MAX (iter->dbl[0], v->f);
825             iter->int1 = 1;
826             break;
827           case MAX | FSTRING:
828             /* Need to do some kind of Unicode collation thingy here */
829             if (memcmp (iter->string, v->s, src_width) < 0)
830               memcpy (iter->string, v->s, src_width);
831             iter->int1 = 1;
832             break;
833           case MIN:
834             iter->dbl[0] = MIN (iter->dbl[0], v->f);
835             iter->int1 = 1;
836             break;
837           case MIN | FSTRING:
838             if (memcmp (iter->string, v->s, src_width) > 0)
839               memcpy (iter->string, v->s, src_width);
840             iter->int1 = 1;
841             break;
842           case FGT:
843           case PGT:
844             if (v->f > iter->arg[0].f)
845               iter->dbl[0] += weight;
846             iter->dbl[1] += weight;
847             break;
848           case FGT | FSTRING:
849           case PGT | FSTRING:
850             if (memcmp (iter->arg[0].c, v->s, src_width) < 0)
851               iter->dbl[0] += weight;
852             iter->dbl[1] += weight;
853             break;
854           case FLT:
855           case PLT:
856             if (v->f < iter->arg[0].f)
857               iter->dbl[0] += weight;
858             iter->dbl[1] += weight;
859             break;
860           case FLT | FSTRING:
861           case PLT | FSTRING:
862             if (memcmp (iter->arg[0].c, v->s, src_width) > 0)
863               iter->dbl[0] += weight;
864             iter->dbl[1] += weight;
865             break;
866           case FIN:
867           case PIN:
868             if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
869               iter->dbl[0] += weight;
870             iter->dbl[1] += weight;
871             break;
872           case FIN | FSTRING:
873           case PIN | FSTRING:
874             if (memcmp (iter->arg[0].c, v->s, src_width) <= 0
875                 && memcmp (iter->arg[1].c, v->s, src_width) >= 0)
876               iter->dbl[0] += weight;
877             iter->dbl[1] += weight;
878             break;
879           case FOUT:
880           case POUT:
881             if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
882               iter->dbl[0] += weight;
883             iter->dbl[1] += weight;
884             break;
885           case FOUT | FSTRING:
886           case POUT | FSTRING:
887             if (memcmp (iter->arg[0].c, v->s, src_width) > 0
888                 || memcmp (iter->arg[1].c, v->s, src_width) < 0)
889               iter->dbl[0] += weight;
890             iter->dbl[1] += weight;
891             break;
892           case N:
893           case N | FSTRING:
894             iter->dbl[0] += weight;
895             break;
896           case NU:
897           case NU | FSTRING:
898             iter->int1++;
899             break;
900           case FIRST:
901             if (iter->int1 == 0)
902               {
903                 iter->dbl[0] = v->f;
904                 iter->int1 = 1;
905               }
906             break;
907           case FIRST | FSTRING:
908             if (iter->int1 == 0)
909               {
910                 memcpy (iter->string, v->s, src_width);
911                 iter->int1 = 1;
912               }
913             break;
914           case LAST:
915             iter->dbl[0] = v->f;
916             iter->int1 = 1;
917             break;
918           case LAST | FSTRING:
919             memcpy (iter->string, v->s, src_width);
920             iter->int1 = 1;
921             break;
922           case NMISS:
923           case NMISS | FSTRING:
924           case NUMISS:
925           case NUMISS | FSTRING:
926             /* Our value is not missing or it would have been
927                caught earlier.  Nothing to do. */
928             break;
929           default:
930             NOT_REACHED ();
931           }
932       } else {
933       switch (iter->function)
934         {
935         case N:
936           iter->dbl[0] += weight;
937           break;
938         case NU:
939           iter->int1++;
940           break;
941         default:
942           NOT_REACHED ();
943         }
944     }
945 }
946
947 /* Writes an aggregated record to OUTPUT. */
948 static void
949 dump_aggregate_info (const struct agr_proc *agr, struct casewriter *output, const struct ccase *break_case)
950 {
951   struct ccase *c = case_create (dict_get_proto (agr->dict));
952
953   if (agr->add_variables)
954     {
955       case_copy (c, 0, break_case, 0, dict_get_n_vars (agr->src_dict));
956     }
957   else
958     {
959       int value_idx = 0;
960       int i;
961
962       for (i = 0; i < agr->break_n_vars; i++)
963         {
964           const struct variable *v = agr->break_vars[i];
965           value_copy (case_data_rw_idx (c, value_idx),
966                       case_data (break_case, v),
967                       var_get_width (v));
968           value_idx++;
969         }
970     }
971
972   {
973     struct agr_var *i;
974
975     for (i = agr->agr_vars; i; i = i->next)
976       {
977         union value *v = case_data_rw (c, i->dest);
978         int width = var_get_width (i->dest);
979
980         if (agr->missing == COLUMNWISE && i->saw_missing
981             && (i->function & FUNC) != N && (i->function & FUNC) != NU
982             && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
983           {
984             value_set_missing (v, width);
985             casewriter_destroy (i->writer);
986             continue;
987           }
988
989         switch (i->function)
990           {
991           case SUM:
992             v->f = i->int1 ? i->dbl[0] : SYSMIS;
993             break;
994           case MEAN:
995             v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
996             break;
997           case MEDIAN:
998             {
999               if (i->writer)
1000                 {
1001                   struct percentile *median = percentile_create (0.5, i->cc);
1002                   struct order_stats *os = &median->parent;
1003                   struct casereader *sorted_reader = casewriter_make_reader (i->writer);
1004                   i->writer = NULL;
1005
1006                   order_stats_accumulate (&os, 1,
1007                                           sorted_reader,
1008                                           i->weight,
1009                                           i->subject,
1010                                           i->exclude);
1011                   i->dbl[0] = percentile_calculate (median, PC_HAVERAGE);
1012                   statistic_destroy (&median->parent.parent);
1013                 }
1014               v->f = i->dbl[0];
1015             }
1016             break;
1017           case SD:
1018             {
1019               double variance;
1020
1021               /* FIXME: we should use two passes. */
1022               moments1_calculate (i->moments, NULL, NULL, &variance,
1023                                  NULL, NULL);
1024               if (variance != SYSMIS)
1025                 v->f = sqrt (variance);
1026               else
1027                 v->f = SYSMIS;
1028             }
1029             break;
1030           case MAX:
1031           case MIN:
1032             v->f = i->int1 ? i->dbl[0] : SYSMIS;
1033             break;
1034           case MAX | FSTRING:
1035           case MIN | FSTRING:
1036             if (i->int1)
1037               memcpy (v->s, i->string, width);
1038             else
1039               value_set_missing (v, width);
1040             break;
1041           case FGT:
1042           case FGT | FSTRING:
1043           case FLT:
1044           case FLT | FSTRING:
1045           case FIN:
1046           case FIN | FSTRING:
1047           case FOUT:
1048           case FOUT | FSTRING:
1049             v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
1050             break;
1051           case PGT:
1052           case PGT | FSTRING:
1053           case PLT:
1054           case PLT | FSTRING:
1055           case PIN:
1056           case PIN | FSTRING:
1057           case POUT:
1058           case POUT | FSTRING:
1059             v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
1060             break;
1061           case N:
1062           case N | FSTRING:
1063               v->f = i->dbl[0];
1064             break;
1065           case NU:
1066           case NU | FSTRING:
1067             v->f = i->int1;
1068             break;
1069           case FIRST:
1070           case LAST:
1071             v->f = i->int1 ? i->dbl[0] : SYSMIS;
1072             break;
1073           case FIRST | FSTRING:
1074           case LAST | FSTRING:
1075             if (i->int1)
1076               memcpy (v->s, i->string, width);
1077             else
1078               value_set_missing (v, width);
1079             break;
1080           case NMISS:
1081           case NMISS | FSTRING:
1082             v->f = i->dbl[0];
1083             break;
1084           case NUMISS:
1085           case NUMISS | FSTRING:
1086             v->f = i->int1;
1087             break;
1088           default:
1089             NOT_REACHED ();
1090           }
1091       }
1092   }
1093
1094   casewriter_write (output, c);
1095 }
1096
1097 /* Resets the state for all the aggregate functions. */
1098 static void
1099 initialize_aggregate_info (struct agr_proc *agr)
1100 {
1101   struct agr_var *iter;
1102
1103   for (iter = agr->agr_vars; iter; iter = iter->next)
1104     {
1105       iter->saw_missing = false;
1106       iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1107       iter->int1 = 0;
1108       switch (iter->function)
1109         {
1110         case MIN:
1111           iter->dbl[0] = DBL_MAX;
1112           break;
1113         case MIN | FSTRING:
1114           memset (iter->string, 255, var_get_width (iter->src));
1115           break;
1116         case MAX:
1117           iter->dbl[0] = -DBL_MAX;
1118           break;
1119         case MAX | FSTRING:
1120           memset (iter->string, 0, var_get_width (iter->src));
1121           break;
1122         case MEDIAN:
1123           {
1124             struct caseproto *proto;
1125             struct subcase ordering;
1126
1127             proto = caseproto_create ();
1128             proto = caseproto_add_width (proto, 0);
1129             proto = caseproto_add_width (proto, 0);
1130
1131             if (! iter->subject)
1132               iter->subject = dict_create_internal_var (0, 0);
1133
1134             if (! iter->weight)
1135               iter->weight = dict_create_internal_var (1, 0);
1136
1137             subcase_init_var (&ordering, iter->subject, SC_ASCEND);
1138             iter->writer = sort_create_writer (&ordering, proto);
1139             subcase_uninit (&ordering);
1140             caseproto_unref (proto);
1141
1142             iter->cc = 0;
1143           }
1144           break;
1145         case SD:
1146           if (iter->moments == NULL)
1147             iter->moments = moments1_create (MOMENT_VARIANCE);
1148           else
1149             moments1_clear (iter->moments);
1150           break;
1151         default:
1152           break;
1153         }
1154     }
1155 }