1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 1997-9, 2000, 2006, 2008, 2009, 2010, 2011, 2012, 2014 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
19 #include "language/stats/aggregate.h"
23 #include "data/any-writer.h"
24 #include "data/case.h"
25 #include "data/casegrouper.h"
26 #include "data/casereader.h"
27 #include "data/casewriter.h"
28 #include "data/dataset.h"
29 #include "data/dictionary.h"
30 #include "data/file-handle-def.h"
31 #include "data/format.h"
32 #include "data/settings.h"
33 #include "data/subcase.h"
34 #include "data/sys-file-writer.h"
35 #include "data/variable.h"
36 #include "language/command.h"
37 #include "language/data-io/file-handle.h"
38 #include "language/lexer/lexer.h"
39 #include "language/lexer/variable-parser.h"
40 #include "language/stats/sort-criteria.h"
41 #include "libpspp/assertion.h"
42 #include "libpspp/i18n.h"
43 #include "libpspp/message.h"
44 #include "libpspp/misc.h"
45 #include "libpspp/pool.h"
46 #include "libpspp/str.h"
47 #include "math/moments.h"
48 #include "math/percentiles.h"
49 #include "math/sort.h"
50 #include "math/statistic.h"
52 #include "gl/c-strcase.h"
53 #include "gl/minmax.h"
54 #include "gl/xalloc.h"
/* Message-translation shorthands: _() passes MSGID through gettext(),
   while N_() expands to MSGID unchanged (conventionally used to mark a
   string for message extraction without translating it at that point). */
57 #define _(msgid) gettext (msgid)
58 #define N_(msgid) msgid
60 /* Argument for AGGREGATE function.
62 Only one of the members is used, so this could be a union, but it's simpler
66 double f; /* Numeric. */
67 struct substring s; /* String. */
70 /* Specifies how to make an aggregate variable. */
/* The parser appends one agr_var per target variable.  The fields in the
   first group are filled in while parsing; the remaining fields are reset,
   updated and read once per break group while the procedure runs. */
73 /* Collected during parsing. */
74 const struct variable *src; /* Source variable. */
75 struct variable *dest; /* Target variable. */
76 enum agr_function function; /* Function. */
77 enum mv_class exclude; /* Classes of missing values to exclude. */
78 struct agr_argument arg[2]; /* Arguments. */
80 /* Accumulated during AGGREGATE execution. */
82 double W; /* Total non-missing weight. */
/* Variance accumulator: accumulate_aggregate_info() feeds it each value
   with its case weight, and agr_destroy() frees it only when FUNCTION is
   AGRF_SD. */
86 struct moments1 *moments;
/* Median support: SUBJECT and WEIGHT are internal variables that address
   a two-column (value, case weight) case.  These cases are written to
   WRITER, a sorting writer ordered by SUBJECT, and are read back by
   dump_aggregate_info() to compute the 0.5 percentile. */
88 struct variable *subject;
89 struct variable *weight;
90 struct casewriter *writer;
93 /* Attributes of aggregation functions. */
94 const struct agr_func agr_func_tab[] =
/* Each AGRF entry is stored at its enum agr_function index through a
   designated initializer.  An output format type (FMT_F) is set only when
   the width W is positive; otherwise the type is -1. */
96 #define AGRF(ENUM, NAME, DESCRIPTION, SRC_VARS, N_ARGS, ALPHA_TYPE, W, D) \
97 [ENUM] = { NAME, DESCRIPTION, SRC_VARS, N_ARGS, ALPHA_TYPE, \
98 { .type = (W) > 0 ? FMT_F : -1, .w = W, .d = D } },
/* Terminator: lookups such as parse_agr_func_name() stop at the first
   entry whose name is NULL. */
101 {NULL, NULL, AGR_SV_NO, 0, -1, {-1, -1, -1}},
104 /* Missing value types. */
/* Under ITEMWISE, dump_aggregate_info() ignores whether a group contained
   missing values.  Under COLUMNWISE (selected by /MISSING=COLUMNWISE), a
   group's result is set missing if any of its source values was missing.
   The counting functions N, NU, NMISS and NUMISS are exempt. */
105 enum missing_treatment
107 ITEMWISE, /* Missing values item by item. */
108 COLUMNWISE /* Missing values column by column. */
111 /* An entire AGGREGATE procedure. */
/* State shared by cmd_aggregate() and its helpers; released by
   agr_destroy(). */
114 /* Break variables. */
115 struct subcase sort; /* Sort criteria (break variables). */
116 const struct variable **break_vars; /* Break variables. */
117 size_t break_n_vars; /* Number of break variables. */
119 enum missing_treatment missing; /* How to treat missing values. */
120 struct agr_var *agr_vars; /* Aggregate variables. */
/* Output dictionary.  Under ADDVARIABLES it is a clone of the source
   dictionary; otherwise it is a new dictionary holding only the break and
   target variables. */
122 struct dictionary *dict; /* Aggregate dictionary. */
123 const struct dictionary *src_dict; /* Dict of the source */
124 int n_cases; /* Counts aggregated cases. */
126 bool add_variables; /* True iff the aggregated variables should
127 be appended to the existing dictionary */
/* Helpers for cmd_aggregate().  For each break group they run in this
   order: initialize_aggregate_info(), then accumulate_aggregate_info() for
   every case, then dump_aggregate_info(). */
130 static void initialize_aggregate_info (struct agr_proc *);
132 static void accumulate_aggregate_info (struct agr_proc *,
133 const struct ccase *);
135 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
137 static void agr_destroy (struct agr_proc *);
138 static void dump_aggregate_info (const struct agr_proc *agr,
139 struct casewriter *output,
140 const struct ccase *break_case);
144 /* Parses and executes the AGGREGATE procedure. */
/* Subcommands are handled in this order: OUTFILE (which must come first;
   MODE is accepted only with OUTFILE=*), then /MISSING, /DOCUMENT and
   /PRESORTED, then /BREAK, and finally the aggregate function
   specifications.  The error path returns CMD_CASCADING_FAILURE. */
146 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
148 struct dictionary *dict = dataset_dict (ds);
149 struct agr_proc agr = {
153 struct file_handle *out_file = NULL;
154 struct casereader *input = NULL;
155 struct casewriter *output = NULL;
157 bool copy_documents = false;
158 bool presorted = false;
159 int addvariables_ofs = 0;
161 /* OUTFILE subcommand must be first. */
162 if (lex_match_phrase (lexer, "/OUTFILE") || lex_match_id (lexer, "OUTFILE"))
164 lex_match (lexer, T_EQUALS);
165 if (!lex_match (lexer, T_ASTERISK))
167 out_file = fh_parse (lexer, FH_REF_FILE, dataset_session (ds));
168 if (out_file == NULL)
/* MODE is accepted only for OUTFILE=*, i.e. when no output file was
   named.  ADDVARIABLES keeps the source variables and appends the
   aggregate variables; REPLACE builds a new dictionary. */
172 if (!out_file && lex_match_id (lexer, "MODE"))
174 lex_match (lexer, T_EQUALS);
175 if (lex_match_id (lexer, "ADDVARIABLES"))
/* Remember where the keyword was, so the note about presorted input
   further down can point at it. */
177 addvariables_ofs = lex_ofs (lexer) - 1;
178 agr.add_variables = true;
181 else if (lex_match_id (lexer, "REPLACE"))
182 agr.add_variables = false;
185 lex_error_expecting (lexer, "ADDVARIABLES", "REPLACE");
192 agr.add_variables = true;
/* COLUMNWISE is the only value accepted after /MISSING. */
196 if (lex_match_phrase (lexer, "/MISSING"))
198 lex_match (lexer, T_EQUALS);
199 if (!lex_match_id (lexer, "COLUMNWISE"))
201 lex_error_expecting (lexer, "COLUMNWISE");
204 agr.missing = COLUMNWISE;
/* Token offset of the PRESORTED keyword, kept so that a later note can
   point at it. */
207 int presorted_ofs = 0;
209 if (lex_match_phrase (lexer, "/DOCUMENT"))
210 copy_documents = true;
211 else if (lex_match_phrase (lexer, "/PRESORTED"))
214 presorted_ofs = lex_ofs (lexer) - 1;
/* The output dictionary starts either as a clone of the input dictionary
   (ADDVARIABLES) or as an empty dictionary with the input's encoding.
   The input's file label and documents are copied into it. */
219 if (agr.add_variables)
220 agr.dict = dict_clone (dict);
222 agr.dict = dict_create (dict_get_encoding (dict));
224 dict_set_label (agr.dict, dict_get_label (dict));
225 dict_set_documents (agr.dict, dict_get_documents (dict));
227 if (lex_match_phrase (lexer, "/BREAK"))
229 lex_match (lexer, T_EQUALS);
231 int break_start = lex_ofs (lexer);
232 if (!parse_sort_criteria (lexer, dict, &agr.sort, &agr.break_vars,
235 int break_end = lex_ofs (lexer) - 1;
236 agr.break_n_vars = subcase_get_n_fields (&agr.sort);
/* A cloned (ADDVARIABLES) dictionary already contains the break
   variables; an empty one needs them added. */
238 if (! agr.add_variables)
239 for (size_t i = 0; i < agr.break_n_vars; i++)
240 dict_clone_var_assert (agr.dict, agr.break_vars[i]);
/* (A) and (D) matter only when this command sorts the input itself.  Warn
   when they will be ignored, and explain why the input counts as
   presorted.
   NOTE(review): the PRESORTED note below reads "subcommand state"; it
   should probably read "states". */
242 if (presorted && saw_direction)
244 lex_ofs_msg (lexer, SW, break_start, break_end,
245 _("When the input data is presorted, specifying "
246 "sorting directions with (A) or (D) has no effect. "
247 "Output data will be sorted the same way as the "
250 lex_ofs_msg (lexer, SN, presorted_ofs, presorted_ofs,
251 _("The PRESORTED subcommand state that the "
252 "input data is presorted."));
253 else if (addvariables_ofs)
254 lex_ofs_msg (lexer, SN, addvariables_ofs, addvariables_ofs,
255 _("ADDVARIABLES implies that the input data "
258 msg (SN, _("The input data must be presorted because the "
259 "OUTFILE subcommand is not specified."));
263 /* Read in the aggregate functions. */
264 if (!parse_aggregate_functions (lexer, dict, &agr))
267 /* Delete documents. */
269 dict_clear_documents (agr.dict);
271 /* Cancel SPLIT FILE. */
272 dict_clear_split_vars (agr.dict);
/* Without an output file, the aggregated cases go to an autopaging writer
   that later becomes the active dataset's source.  Otherwise they are
   written to OUT_FILE. */
277 if (out_file == NULL)
279 /* The active dataset will be replaced by the aggregated data,
280 so TEMPORARY is moot. */
281 proc_cancel_temporary_transformations (ds);
282 proc_discard_output (ds);
283 output = autopaging_writer_create (dict_get_proto (agr.dict));
287 output = any_writer_open (out_file, agr.dict);
292 input = proc_open (ds);
/* Sort by the break variables unless the input was declared presorted. */
293 if (!subcase_is_empty (&agr.sort) && !presorted)
295 input = sort_execute (input, &agr.sort);
296 subcase_clear (&agr.sort);
/* Process the input one break group at a time.  For each group, reset the
   accumulators and feed every case of the group.  Then emit one output
   case for the group (REPLACE), or one output case per input case
   (ADDVARIABLES). */
299 struct casegrouper *grouper;
300 struct casereader *group;
301 for (grouper = casegrouper_create_vars (input, agr.break_vars,
303 casegrouper_get_next_group (grouper, &group);
304 casereader_destroy (group))
306 struct casereader *placeholder = NULL;
307 struct ccase *c = casereader_peek (group, 0);
311 casereader_destroy (group);
315 initialize_aggregate_info (&agr);
/* ADDVARIABLES writes every input case back out, so keep a second reader
   over the group before the first one is consumed below. */
317 if (agr.add_variables)
318 placeholder = casereader_clone (group);
322 for (; (cg = casereader_read (group)) != NULL; case_unref (cg))
323 accumulate_aggregate_info (&agr, cg);
327 if (agr.add_variables)
330 for (; (cg = casereader_read (placeholder)) != NULL; case_unref (cg))
331 dump_aggregate_info (&agr, output, cg);
333 casereader_destroy (placeholder);
337 dump_aggregate_info (&agr, output, c);
341 if (!casegrouper_destroy (grouper))
344 bool ok = proc_commit (ds);
/* Without OUTFILE, the aggregated data, described by AGR.dict, becomes
   the active dataset. */
349 if (out_file == NULL)
351 struct casereader *next_input = casewriter_make_reader (output);
352 if (next_input == NULL)
355 dataset_set_dict (ds, agr.dict);
356 dataset_set_source (ds, next_input);
361 ok = casewriter_destroy (output);
374 casewriter_destroy (output);
377 return CMD_CASCADING_FAILURE;
/* Parses the aggregation function name in the current token.
   *EXCLUDE is set before the lookup.  A trailing '.' on the name (for
   example "MEAN.") selects MV_SYSTEM; otherwise MV_ANY is used.
   Presumably the '.' form exists so that user-missing values are included
   in the aggregation.
   On a match, the entry's index in agr_func_tab is stored in *FUNC_INDEX.
   Table names are compared with ss_equals_case.
   A syntax error is reported when the token is not an identifier or does
   not name a known function.  The caller treats a false result as
   failure. */
381 parse_agr_func_name (struct lexer *lexer, int *func_index,
382 enum mv_class *exclude)
384 if (lex_token (lexer) != T_ID)
386 lex_error (lexer, _("Syntax error expecting aggregation function."));
390 struct substring name = lex_tokss (lexer);
/* Only the local copy NAME loses its trailing '.'; the token itself is
   not modified. */
391 *exclude = ss_chomp_byte (&name, '.') ? MV_SYSTEM : MV_ANY;
393 for (const struct agr_func *f = agr_func_tab; f->name; f++)
394 if (ss_equals_case (ss_cstr (f->name), name))
396 *func_index = f - agr_func_tab;
400 lex_error (lexer, _("Unknown aggregation function %s."), lex_tokcstr (lexer));
404 /* Parse all the aggregate functions. */
/* Each specification has the form
     /TARGET... = FUNCTION[(SOURCE... [, ARG [, ARG]])]
   A quoted string after a target name becomes that variable's label.
   For every target, an agr_var is appended to AGR->agr_vars and the target
   variable is created in AGR->dict.  DICT is the source dictionary.
   Specifications are read as long as they are separated by '/'; parsing
   stops at the end of the command.  cmd_aggregate() treats a false return
   as failure. */
406 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict,
407 struct agr_proc *agr)
409 if (!lex_force_match (lexer, T_SLASH))
/* Indexes at or above STARTING_N_VARS are meant to identify target
   variables created by this function.  The duplicate-name check below
   relies on this.
   NOTE(review): the count comes from the *source* dictionary, but it is
   compared with indexes in AGR->dict.  In REPLACE mode AGR->dict starts
   out holding only the break variables.  A repeated target name whose
   first instance sits below this index would then be reported as
   duplicating a break variable.  dict_get_n_vars (agr->dict) looks like
   the intended count; confirm. */
412 size_t starting_n_vars = dict_get_n_vars (dict);
413 size_t allocated_agr_vars = 0;
415 /* Parse everything. */
419 char **dest_label = NULL;
422 struct agr_argument arg[2] = { { .f = 0 }, { .f = 0 } };
424 const struct variable **src = NULL;
426 /* Parse the list of target variables. */
427 int dst_start_ofs = lex_ofs (lexer);
428 while (!lex_match (lexer, T_EQUALS))
430 size_t n_vars_prev = n_vars;
432 if (!parse_DATA_LIST_vars (lexer, dict, &dest, &n_vars,
433 (PV_APPEND | PV_SINGLE | PV_NO_SCRATCH
437 /* Assign empty labels. */
438 dest_label = xnrealloc (dest_label, n_vars, sizeof *dest_label);
439 for (size_t j = n_vars_prev; j < n_vars; j++)
440 dest_label[j] = NULL;
/* A quoted string right after a target name becomes the label of the
   most recently parsed target. */
442 if (lex_is_string (lexer))
444 dest_label[n_vars - 1] = xstrdup (lex_tokcstr (lexer));
/* The loop above has already consumed '=', so the last target token is
   two tokens back. */
448 int dst_end_ofs = lex_ofs (lexer) - 2;
450 /* Get the name of the aggregation function. */
452 enum mv_class exclude;
453 if (!parse_agr_func_name (lexer, &func_index, &exclude))
455 const struct agr_func *function = &agr_func_tab[func_index];
457 /* Check for leading lparen. */
458 if (!lex_match (lexer, T_LPAREN))
460 if (function->src_vars == AGR_SV_YES)
/* '(' is known to be absent here.  This call only reports the
   "expecting (" error, so its result is deliberately unused. */
462 bool ok UNUSED = lex_force_match (lexer, T_LPAREN);
468 /* Parse list of source variables. */
469 int pv_opts = PV_NO_SCRATCH;
/* SUM, MEAN, MEDIAN and SD require numeric sources.  Functions that take
   arguments require all sources to share one type, because the arguments
   are checked below against the type of the first source. */
470 if (func_index == AGRF_SUM || func_index == AGRF_MEAN
471 || func_index == AGRF_MEDIAN || func_index == AGRF_SD)
472 pv_opts |= PV_NUMERIC;
473 else if (function->n_args)
474 pv_opts |= PV_SAME_TYPE;
476 int src_start_ofs = lex_ofs (lexer);
478 if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
480 int src_end_ofs = lex_ofs (lexer) - 1;
482 /* Parse function arguments, for those functions that
483 require arguments. */
484 int args_start_ofs = 0;
485 if (function->n_args != 0)
486 for (size_t i = 0; i < function->n_args; i++)
488 lex_match (lexer, T_COMMA);
491 if (lex_is_string (lexer))
493 else if (lex_is_number (lexer))
497 lex_error (lexer, _("Missing argument %zu to %s."),
498 i + 1, function->name);
502 if (type != var_get_type (src[0]))
504 msg (SE, _("Arguments to %s must be of same type as "
505 "source variables."),
507 if (type == VAL_NUMERIC)
509 lex_next_msg (lexer, SN, 0, 0,
510 _("The argument is numeric."));
511 lex_ofs_msg (lexer, SN, src_start_ofs, src_end_ofs,
512 _("The variables have string type."));
516 lex_next_msg (lexer, SN, 0, 0,
517 _("The argument is a string."));
518 lex_ofs_msg (lexer, SN, src_start_ofs, src_end_ofs,
519 _("The variables are numeric."));
525 args_start_ofs = lex_ofs (lexer);
526 if (type == VAL_NUMERIC)
527 arg[i].f = lex_tokval (lexer);
/* String arguments are recoded from UTF-8 into the aggregate dictionary's
   encoding, presumably so that they can be compared directly with source
   values. */
529 arg[i].s = recode_substring_pool (dict_get_encoding (agr->dict),
530 "UTF-8", lex_tokss (lexer),
534 int args_end_ofs = lex_ofs (lexer) - 1;
536 /* Trailing rparen. */
537 if (!lex_force_match (lexer, T_RPAREN))
540 /* Now check that the number of source variables match
541 the number of target variables. If we check earlier
542 than this, the user can get very misleading error
543 message, i.e. `AGGREGATE x=SUM(y t).' will get this
544 error message when a proper message would be more
545 like `unknown variable t'. */
548 msg (SE, _("Number of source variables (%zu) does not match "
549 "number of target variables (%zu)."),
551 lex_ofs_msg (lexer, SN, src_start_ofs, src_end_ofs,
552 _("These are the source variables."));
553 lex_ofs_msg (lexer, SN, dst_start_ofs, dst_end_ofs,
554 _("These are the target variables."));
/* PIN, POUT, FIN and FOUT take a lower and an upper bound.  If the bounds
   were given in descending order, they are swapped and a warning is
   issued. */
558 if ((func_index == AGRF_PIN || func_index == AGRF_POUT
559 || func_index == AGRF_FIN || func_index == AGRF_FOUT)
560 && (var_is_numeric (src[0])
561 ? arg[0].f > arg[1].f
562 : buf_compare_rpad (arg[0].s.string, arg[0].s.length,
563 arg[1].s.string, arg[1].s.length) > 0))
565 struct agr_argument tmp = arg[0];
569 lex_ofs_msg (lexer, SW, args_start_ofs, args_end_ofs,
570 _("The value arguments passed to the %s function "
571 "are out of order. They will be treated as if "
572 "they had been specified in the correct order."),
577 /* Finally add these to the aggregation variables. */
578 for (size_t i = 0; i < n_vars; i++)
580 const struct variable *existing_var = dict_lookup_var (agr->dict,
584 if (var_get_dict_index (existing_var) >= starting_n_vars)
585 lex_ofs_error (lexer, dst_start_ofs, dst_end_ofs,
586 _("Duplicate target variable name %s."),
588 else if (agr->add_variables)
589 lex_ofs_error (lexer, dst_start_ofs, dst_end_ofs,
590 _("Variable name %s duplicates the name of a "
591 "variable in the active file dictionary."),
594 lex_ofs_error (lexer, dst_start_ofs, dst_end_ofs,
595 _("Variable name %s duplicates the name of a "
596 "break variable."), dest[i]);
601 if (agr->n_agr_vars >= allocated_agr_vars)
602 agr->agr_vars = x2nrealloc (agr->agr_vars, &allocated_agr_vars,
603 sizeof *agr->agr_vars);
604 struct agr_var *v = &agr->agr_vars[agr->n_agr_vars++];
605 *v = (struct agr_var) {
608 .function = func_index,
609 .src = src ? src[i] : NULL,
612 /* Create the target variable in the aggregate dictionary. */
/* String sources get a buffer as wide as the source.  It holds the running
   string value while cases are accumulated.  The target is a copy of the
   source variable when the function's result type is string; otherwise it
   is a new variable created with width 0. */
613 if (v->src && var_is_alpha (v->src))
614 v->string = xmalloc (var_get_width (v->src));
616 if (v->src && function->alpha_type == VAL_STRING)
617 v->dest = dict_clone_var_as_assert (agr->dict, v->src, dest[i]);
620 v->dest = dict_create_var_assert (agr->dict, dest[i], 0);
/* When the source data are weighted, N and NMISS use F8.2 instead of the
   table's default format, presumably because weighted counts need not be
   integers. */
623 if ((func_index == AGRF_N || func_index == AGRF_NMISS)
624 && dict_get_weight (dict) != NULL)
625 f = fmt_for_output (FMT_F, 8, 2);
627 f = function->format;
628 var_set_both_formats (v->dest, &f);
631 var_set_label (v->dest, dest_label[i]);
/* Each agr_var owns its own copy of the arguments (string arguments are
   cloned).  The parser's copies are released right after this loop. */
634 for (size_t j = 0; j < function->n_args; j++)
635 v->arg[j] = (struct agr_argument) {
637 .s = arg[j].s.string ? ss_clone (arg[j].s) : ss_empty (),
641 ss_dealloc (&arg[0].s);
642 ss_dealloc (&arg[1].s);
645 for (size_t i = 0; i < n_vars; i++)
648 free (dest_label[i]);
/* Another specification follows a '/'; otherwise the command must end
   here.
   NOTE(review): unlike the other messages in this file, the error text
   below is not wrapped in _(), so it is never translated. */
653 if (!lex_match (lexer, T_SLASH))
655 if (lex_token (lexer) == T_ENDCMD)
658 lex_error (lexer, "Syntax error expecting end of command.");
664 for (size_t i = 0; i < n_vars; i++)
667 free (dest_label[i]);
671 ss_dealloc (&arg[0].s);
672 ss_dealloc (&arg[1].s);
/* Frees everything AGR owns:
   - the sort criteria;
   - the break-variable array;
   - each aggregate variable's arguments;
   - the moments accumulator of SD variables;
   - the internal variables used for the median;
   - the aggregate variable array;
   - AGR's reference to its output dictionary, if one was created. */
681 agr_destroy (struct agr_proc *agr)
683 subcase_uninit (&agr->sort);
684 free (agr->break_vars);
685 for (size_t i = 0; i < agr->n_agr_vars; i++)
687 struct agr_var *av = &agr->agr_vars[i];
689 ss_dealloc (&av->arg[0].s);
690 ss_dealloc (&av->arg[1].s);
693 if (av->function == AGRF_SD)
694 moments1_destroy (av->moments);
696 dict_destroy_internal_var (av->subject);
697 dict_destroy_internal_var (av->weight);
699 free (agr->agr_vars);
700 if (agr->dict != NULL)
701 dict_unref (agr->dict);
706 /* Accumulates aggregation data from the case INPUT. */
/* INPUT contributes with the case weight obtained from the source
   dictionary.  A value whose missing-value class intersects the variable's
   EXCLUDE set goes through the missing-value branch.  That branch sets
   SAW_MISSING, which is consulted later under COLUMNWISE treatment. */
708 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
710 bool bad_warn = true;
711 double weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
712 for (size_t i = 0; i < agr->n_agr_vars; i++)
714 struct agr_var *av = &agr->agr_vars[i];
717 bool is_string = var_is_alpha (av->src);
718 const union value *v = case_data (input, av->src);
719 int src_width = var_get_width (av->src);
720 const struct substring vs = (src_width > 0
721 ? value_ss (v, src_width)
724 if (var_is_value_missing (av->src, v) & av->exclude)
726 switch (av->function)
760 av->saw_missing = true;
764 /* This is horrible. There are too many possibilities. */
766 switch (av->function)
769 av->dbl += v->f * weight;
774 av->dbl += v->f * weight;
/* Median: append (value, weight) to the sorting writer created by
   initialize_aggregate_info().  The median itself is computed when the
   group's result is written. */
779 struct ccase *cout = case_create (casewriter_get_proto (av->writer));
780 *case_num_rw (cout, av->subject) = case_num (input, av->src);
781 *case_num_rw (cout, av->weight) = weight;
782 casewriter_write (av->writer, cout);
787 moments1_add (av->moments, v->f, weight);
792 av->dbl = MAX (av->dbl, v->f);
793 else if (memcmp (av->string, v->s, src_width) < 0)
794 memcpy (av->string, v->s, src_width);
800 av->dbl = MIN (av->dbl, v->f);
801 else if (memcmp (av->string, v->s, src_width) > 0)
802 memcpy (av->string, v->s, src_width);
/* NOTE(review): this repeats the numeric MIN update above, but it also
   runs for string sources, where v->f does not hold a meaningful number.
   String results are written from AV->string, so this line looks
   redundant; confirm and remove. */
803 av->dbl = MIN (av->dbl, v->f);
/* The four tests below compare the value with the function arguments,
   for string sources (right-padded comparison) or numeric sources.  They
   check, in order: greater than arg[0]; less than arg[0]; within
   [arg[0], arg[1]]; outside that range. */
811 ? ss_compare_rpad (av->arg[0].s, vs) < 0
812 : v->f > av->arg[0].f)
820 ? ss_compare_rpad (av->arg[0].s, vs) > 0
821 : v->f < av->arg[0].f)
829 ? (ss_compare_rpad (av->arg[0].s, vs) <= 0
830 && ss_compare_rpad (av->arg[1].s, vs) >= 0)
831 : av->arg[0].f <= v->f && v->f <= av->arg[1].f)
839 ? (ss_compare_rpad (av->arg[0].s, vs) > 0
840 || ss_compare_rpad (av->arg[1].s, vs) < 0)
841 : av->arg[0].f > v->f || v->f > av->arg[1].f)
857 memcpy (av->string, v->s, src_width);
866 memcpy (av->string, v->s, src_width);
874 /* Our value is not missing or it would have been
875 caught earlier. Nothing to do. */
882 switch (av->function)
919 /* Writes an aggregated record to OUTPUT. */
/* The record is built from BREAK_CASE plus the accumulated state of every
   aggregate variable.  Under ADDVARIABLES, the leading variables (as many
   as the source dictionary has) are copied wholesale from BREAK_CASE.
   Otherwise only the break variables are copied. */
921 dump_aggregate_info (const struct agr_proc *agr, struct casewriter *output, const struct ccase *break_case)
923 struct ccase *c = case_create (dict_get_proto (agr->dict));
925 if (agr->add_variables)
927 case_copy (c, 0, break_case, 0, dict_get_n_vars (agr->src_dict));
933 for (size_t i = 0; i < agr->break_n_vars; i++)
935 const struct variable *v = agr->break_vars[i];
936 value_copy (case_data_rw_idx (c, value_idx),
937 case_data (break_case, v),
943 for (size_t i = 0; i < agr->n_agr_vars; i++)
945 struct agr_var *av = &agr->agr_vars[i];
946 union value *v = case_data_rw (c, av->dest);
947 int width = var_get_width (av->dest);
/* COLUMNWISE: a group that contained any missing source value yields a
   missing result, except for the counting functions.  Any median sort
   writer for this variable is destroyed as well. */
949 if (agr->missing == COLUMNWISE && av->saw_missing
950 && av->function != AGRF_N
951 && av->function != AGRF_NU
952 && av->function != AGRF_NMISS
953 && av->function != AGRF_NUMISS)
955 value_set_missing (v, width);
956 casewriter_destroy (av->writer);
960 switch (av->function)
963 v->f = av->int1 ? av->dbl : SYSMIS;
967 v->f = av->W != 0.0 ? av->dbl / av->W : SYSMIS;
/* Median: read back the (value, weight) cases, already sorted by value,
   and take the 0.5 percentile using the PC_HAVERAGE method. */
974 struct percentile *median = percentile_create (0.5, av->W);
975 struct order_stats *os = &median->parent;
976 struct casereader *sorted_reader = casewriter_make_reader (av->writer);
979 order_stats_accumulate (&os, 1,
984 av->dbl = percentile_calculate (median, PC_HAVERAGE);
985 statistic_destroy (&median->parent.parent);
/* SD is the square root of the variance from the moments accumulator,
   or SYSMIS when the variance is SYSMIS. */
995 moments1_calculate (av->moments, NULL, NULL, &variance,
997 v->f = variance != SYSMIS ? sqrt (variance) : SYSMIS;
1006 v->f = av->int1 ? av->dbl : SYSMIS;
1010 memcpy (v->s, av->string, width);
1012 value_set_missing (v, width);
/* Ratio of the accumulated weight to the total non-missing weight W,
   followed by the same ratio as a percentage.  Both are SYSMIS when W is
   zero. */
1020 v->f = av->W ? av->dbl / av->W : SYSMIS;
1027 v->f = av->W ? av->dbl / av->W * 100.0 : SYSMIS;
1052 casewriter_write (output, c);
1055 /* Resets the state for all the aggregate functions. */
/* cmd_aggregate() calls this before each break group.  It:
   - clears SAW_MISSING, DBL and W;
   - fills string buffers with 0xFF or zero bytes (presumably for MIN and
     MAX respectively, so that the first real value always replaces them);
   - sets up the sorting writer used for the median;
   - creates or clears the moments accumulator. */
1057 initialize_aggregate_info (struct agr_proc *agr)
1059 for (size_t i = 0; i < agr->n_agr_vars; i++)
1061 struct agr_var *av = &agr->agr_vars[i];
1062 av->saw_missing = false;
1063 av->dbl = av->W = 0.0;
1066 int width = av->src ? var_get_width (av->src) : 0;
1067 switch (av->function)
1073 memset (av->string, 255, width);
1080 memset (av->string, 0, width);
/* Median: cases with two width-0 columns, the value (SUBJECT) and its case
   weight (WEIGHT).  The internal variables are created with arguments 0
   and 1, and the writer sorts ascending by SUBJECT. */
1085 struct caseproto *proto = caseproto_create ();
1086 proto = caseproto_add_width (proto, 0);
1087 proto = caseproto_add_width (proto, 0);
1090 av->subject = dict_create_internal_var (0, 0);
1093 av->weight = dict_create_internal_var (1, 0);
1095 struct subcase ordering;
1096 subcase_init_var (&ordering, av->subject, SC_ASCEND);
1097 av->writer = sort_create_writer (&ordering, proto);
1098 subcase_uninit (&ordering);
1099 caseproto_unref (proto);
1104 if (av->moments == NULL)
1105 av->moments = moments1_create (MOMENT_VARIANCE);
1107 moments1_clear (av->moments);