1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 1997-9, 2000, 2006 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
21 #include <data/any-writer.h>
22 #include <data/case-ordering.h>
23 #include <data/case.h>
24 #include <data/casegrouper.h>
25 #include <data/casereader.h>
26 #include <data/casewriter.h>
27 #include <data/dictionary.h>
28 #include <data/file-handle-def.h>
29 #include <data/format.h>
30 #include <data/procedure.h>
31 #include <data/settings.h>
32 #include <data/sys-file-writer.h>
33 #include <data/variable.h>
34 #include <language/command.h>
35 #include <language/data-io/file-handle.h>
36 #include <language/lexer/lexer.h>
37 #include <language/lexer/variable-parser.h>
38 #include <language/stats/sort-criteria.h>
39 #include <libpspp/assertion.h>
40 #include <libpspp/message.h>
41 #include <libpspp/misc.h>
42 #include <libpspp/pool.h>
43 #include <libpspp/str.h>
44 #include <math/moments.h>
45 #include <math/sort.h>
/* Shorthand that expands to a gettext call on MSGID.  The user-visible
   message strings passed to lex_error and msg in this file are wrapped
   in it. */
51 #define _(msgid) gettext (msgid)
53 /* Argument for AGGREGATE function. */
/* parse_aggregate_functions stores a numeric token in F and a duplicated
   copy of a string token in C.  agr_destroy frees C for aggregate
   variables whose function code carries the FSTRING bit. */
56 double f; /* Numeric. */
57 char *c; /* Short or long string. */
60 /* Specifies how to make an aggregate variable. */
/* Nodes are allocated by parse_aggregate_functions, chained through NEXT
   starting at agr_proc's agr_vars, and walked by agr_destroy. */
63 struct agr_var *next; /* Next in list. */
65 /* Collected during parsing. */
66 const struct variable *src; /* Source variable. */
67 struct variable *dest; /* Target variable. */
/* FUNCTION holds an aggregation function number, with FSTRING OR'ed in
   when the source variable is a string variable. */
68 int function; /* Function. */
69 enum mv_class exclude; /* Classes of missing values to exclude. */
70 union agr_argument arg[2]; /* Arguments. */
72 /* Accumulated during AGGREGATE execution. */
/* Variance accumulator: created or cleared by initialize_aggregate_info,
   fed by accumulate_aggregate_info, read by dump_aggregate_info, and
   destroyed by agr_destroy when FUNCTION is SD. */
77 struct moments1 *moments;
80 /* Aggregation functions. */
/* The enumerators are listed in the same order as the entries of
   agr_func_tab, so a function number doubles as an index into that table.
   N_AGR_FUNCS lines up with the table's NULL-named entry, and N_NO_VARS
   and NU_NO_VARS with the two entries after it.  NU_NO_VARS replaces NU
   when the function name is not followed by `('; N_NO_VARS is presumably
   the same substitution for N (the guarding test is not visible here). */
83 NONE, SUM, MEAN, SD, MAX, MIN, PGT, PLT, PIN, POUT, FGT, FLT, FIN,
84 FOUT, N, NU, NMISS, NUMISS, FIRST, LAST,
85 N_AGR_FUNCS, N_NO_VARS, NU_NO_VARS,
/* A stored function code is a function number optionally OR'ed with
   FSTRING; `code & FUNC' recovers the function number. */
86 FUNC = 0x1f, /* Function mask. */
87 FSTRING = 1<<5, /* String function bit. */
90 /* Attributes of an aggregation function. */
93 const char *name; /* Aggregation function name. */
/* Count of value arguments that parse_aggregate_functions reads after the
   source variable list. */
94 size_t n_args; /* Number of arguments. */
/* When this is VAL_STRING, parse_aggregate_functions clones the source
   variable to make the target; otherwise it creates a new variable and
   applies FORMAT (or an F8.2 output format for weighted N/NMISS). */
95 enum val_type alpha_type; /* When given ALPHA arguments, output type. */
96 struct fmt_spec format; /* Format spec if alpha_type != ALPHA. */
99 /* Attributes of aggregation functions. */
/* Entry order must match the aggregation function enum: the parser derives
   a function number as `function - agr_func_tab' and agr_destroy indexes
   the table with `function & FUNC'.  Each initializer is
   {name, n_args, alpha_type, format}. */
100 static const struct agr_func agr_func_tab[] =
102 {"<NONE>", 0, -1, {0, 0, 0}},
103 {"SUM", 0, -1, {FMT_F, 8, 2}},
104 {"MEAN", 0, -1, {FMT_F, 8, 2}},
105 {"SD", 0, -1, {FMT_F, 8, 2}},
106 {"MAX", 0, VAL_STRING, {-1, -1, -1}},
107 {"MIN", 0, VAL_STRING, {-1, -1, -1}},
108 {"PGT", 1, VAL_NUMERIC, {FMT_F, 5, 1}},
109 {"PLT", 1, VAL_NUMERIC, {FMT_F, 5, 1}},
110 {"PIN", 2, VAL_NUMERIC, {FMT_F, 5, 1}},
111 {"POUT", 2, VAL_NUMERIC, {FMT_F, 5, 1}},
112 {"FGT", 1, VAL_NUMERIC, {FMT_F, 5, 3}},
113 {"FLT", 1, VAL_NUMERIC, {FMT_F, 5, 3}},
114 {"FIN", 2, VAL_NUMERIC, {FMT_F, 5, 3}},
115 {"FOUT", 2, VAL_NUMERIC, {FMT_F, 5, 3}},
116 {"N", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
117 {"NU", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
118 {"NMISS", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
119 {"NUMISS", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
120 {"FIRST", 0, VAL_STRING, {-1, -1, -1}},
121 {"LAST", 0, VAL_STRING, {-1, -1, -1}},
/* The NULL name stops the parser's name-lookup loop, so the two entries
   below are never matched by name; they are reached only through the
   N_NO_VARS and NU_NO_VARS function numbers. */
122 {NULL, 0, -1, {-1, -1, -1}},
123 {"N", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
124 {"NU", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
127 /* Missing value types. */
/* cmd_aggregate starts with ITEMWISE and switches to COLUMNWISE when the
   MISSING subcommand is parsed.  dump_aggregate_info tests for COLUMNWISE
   together with each aggregate variable's saw_missing flag. */
128 enum missing_treatment
130 ITEMWISE, /* Missing values item by item. */
131 COLUMNWISE /* Missing values column by column. */
134 /* An entire AGGREGATE procedure. */
/* One instance lives on cmd_aggregate's stack; it is zero-filled before
   use and handed to the parse, initialize, accumulate and dump helpers. */
137 /* Break variables. */
138 struct case_ordering *sort; /* Sort criteria. */
/* BREAK_VARS and BREAK_VAR_CNT are filled from SORT by
   case_ordering_get_vars; agr_destroy frees the array. */
139 const struct variable **break_vars; /* Break variables. */
140 size_t break_var_cnt; /* Number of break variables. */
/* Re-cloned from the first case of each group by
   initialize_aggregate_info and read back by dump_aggregate_info. */
141 struct ccase break_case; /* Last values of break variables. */
143 enum missing_treatment missing; /* How to treat missing values. */
144 struct agr_var *agr_vars; /* First aggregate variable. */
145 struct dictionary *dict; /* Aggregate dictionary. */
/* Used by accumulate_aggregate_info to look up each case's weight. */
146 const struct dictionary *src_dict; /* Dict of the source */
147 int case_cnt; /* Counts aggregated cases. */
/* Forward declarations for the file-local helpers that cmd_aggregate
   calls before their definitions appear. */
150 static void initialize_aggregate_info (struct agr_proc *,
151 const struct ccase *);
152 static void accumulate_aggregate_info (struct agr_proc *,
153 const struct ccase *);
155 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
157 static void agr_destroy (struct agr_proc *);
158 static void dump_aggregate_info (struct agr_proc *agr,
159 struct casewriter *output);
163 /* Parses and executes the AGGREGATE procedure. */
/* LEXER supplies the command's tokens and DS is the dataset whose
   dictionary and cases are aggregated.  The aggregated cases go to the
   file named on OUTFILE, or replace the active file when OUTFILE is `*'.
   The failure path visible at the end destroys the output writer and
   returns CMD_CASCADING_FAILURE. */
165 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
167 struct dictionary *dict = dataset_dict (ds);
169 struct file_handle *out_file = NULL;
170 struct casereader *input = NULL, *group;
171 struct casegrouper *grouper;
172 struct casewriter *output = NULL;
174 bool copy_documents = false;
175 bool presorted = false;
/* Start from a zero-filled AGR with ITEMWISE missing-value treatment and
   a null break case. */
179 memset(&agr, 0 , sizeof (agr));
180 agr.missing = ITEMWISE;
181 case_nullify (&agr.break_case);
/* The aggregate dictionary is created fresh but takes over the source
   dictionary's label and documents. */
183 agr.dict = dict_create ();
185 dict_set_label (agr.dict, dict_get_label (dict));
186 dict_set_documents (agr.dict, dict_get_documents (dict));
188 /* OUTFILE subcommand must be first. */
189 if (!lex_force_match_id (lexer, "OUTFILE"))
191 lex_match (lexer, '=');
/* `*' leaves OUT_FILE null, which later code takes to mean "replace the
   active file"; anything else must parse as a file handle. */
192 if (!lex_match (lexer, '*'))
194 out_file = fh_parse (lexer, FH_REF_FILE | FH_REF_SCRATCH);
195 if (out_file == NULL)
199 /* Read most of the subcommands. */
202 lex_match (lexer, '/');
204 if (lex_match_id (lexer, "MISSING"))
206 lex_match (lexer, '=');
207 if (!lex_match_id (lexer, "COLUMNWISE"))
209 lex_error (lexer, _("while expecting COLUMNWISE"));
212 agr.missing = COLUMNWISE;
214 else if (lex_match_id (lexer, "DOCUMENT"))
215 copy_documents = true;
216 else if (lex_match_id (lexer, "PRESORTED"))
218 else if (lex_match_id (lexer, "BREAK"))
222 lex_match (lexer, '=');
223 agr.sort = parse_case_ordering (lexer, dict,
226 if (agr.sort == NULL)
/* The break variables come from the parsed sort criteria; each one is
   cloned into the aggregate dictionary under its own name. */
228 case_ordering_get_vars (agr.sort,
229 &agr.break_vars, &agr.break_var_cnt);
231 for (i = 0; i < agr.break_var_cnt; i++)
232 dict_clone_var_assert (agr.dict, agr.break_vars[i],
233 var_get_name (agr.break_vars[i]));
235 /* BREAK must follow the options. */
240 lex_error (lexer, _("expecting BREAK"));
244 if (presorted && saw_direction)
245 msg (SW, _("When PRESORTED is specified, specifying sorting directions "
246 "with (A) or (D) has no effect. Output data will be sorted "
247 "the same way as the input data."));
249 /* Read in the aggregate functions. */
250 lex_match (lexer, '/');
251 if (!parse_aggregate_functions (lexer, dict, &agr))
254 /* Delete documents. */
256 dict_clear_documents (agr.dict);
258 /* Cancel SPLIT FILE. */
259 dict_set_split_vars (agr.dict, NULL, 0);
/* Choose the output writer: one created by autopaging_writer_create when
   the active file is being replaced, otherwise a writer opened on
   OUT_FILE. */
264 if (out_file == NULL)
266 /* The active file will be replaced by the aggregated data,
267 so TEMPORARY is moot. */
268 proc_cancel_temporary_transformations (ds);
269 proc_discard_output (ds);
270 output = autopaging_writer_create (dict_get_next_value_idx (agr.dict));
274 output = any_writer_open (out_file, agr.dict);
/* Sort the input on the break criteria unless PRESORTED was given. */
279 input = proc_open (ds);
280 if (agr.sort != NULL && !presorted)
282 input = sort_execute (input, agr.sort);
/* Process the input one group at a time; the grouper is keyed on the
   break variables.  Each group's reader is destroyed by the loop's
   increment expression. */
286 for (grouper = casegrouper_create_vars (input, agr.break_vars,
288 casegrouper_get_next_group (grouper, &group);
289 casereader_destroy (group))
/* Peek at the group's first case to seed the break values and reset the
   accumulators.
   NOTE(review): when the peek fails, GROUP is destroyed here and the
   loop's increment expression destroys GROUP as well.  If the elided
   statement that follows is `continue', the same reader is destroyed
   twice -- confirm against the full source. */
293 if (!casereader_peek (group, 0, &c))
295 casereader_destroy (group);
298 initialize_aggregate_info (&agr, &c);
/* Feed every case of the group to the accumulators, then emit one
   aggregated case for the group. */
301 for (; casereader_read (group, &c); case_destroy (&c))
302 accumulate_aggregate_info (&agr, &c);
303 dump_aggregate_info (&agr, output);
305 if (!casegrouper_destroy (grouper))
308 if (!proc_commit (ds))
/* OUTFILE=*: convert the writer into a reader and install it, together
   with the aggregate dictionary, as the new active file.  Otherwise just
   close the writer and keep its status. */
315 if (out_file == NULL)
317 struct casereader *next_input = casewriter_make_reader (output);
318 if (next_input == NULL)
321 proc_set_active_file (ds, next_input, agr.dict);
326 ok = casewriter_destroy (output);
339 casewriter_destroy (output);
342 return CMD_CASCADING_FAILURE;
345 /* Parse all the aggregate functions. */
/* Reads specifications of the form
     target... ["label"] = FUNCTION (source... [args])
   separated by `/', and creates one agr_var plus one variable in AGR's
   dictionary per target.  LEXER is the token source, DICT is the
   dictionary in which source variables are looked up, and AGR receives
   the resulting list.  Returns a bool; cmd_aggregate tests it with `!'
   (presumably false means a parse error -- the return statements are not
   visible here). */
347 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict, struct agr_proc *agr)
349 struct agr_var *tail; /* Tail of linked list starting at agr->vars. */
351 /* Parse everything. */
358 struct string function_name;
360 enum mv_class exclude;
361 const struct agr_func *function;
364 union agr_argument arg[2];
366 const struct variable **src;
379 ds_init_empty (&function_name);
381 /* Parse the list of target variables. */
382 while (!lex_match (lexer, '='))
384 size_t n_dest_prev = n_dest;
386 if (!parse_DATA_LIST_vars (lexer, &dest, &n_dest,
387 PV_APPEND | PV_SINGLE | PV_NO_SCRATCH))
390 /* Assign empty labels. */
/* DEST_LABEL is kept the same length as DEST; newly added targets start
   with a null label. */
394 dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
395 for (j = n_dest_prev; j < n_dest; j++)
396 dest_label[j] = NULL;
/* A string token here becomes the label of the most recently parsed
   target, truncated to length 255. */
401 if (lex_token (lexer) == T_STRING)
404 ds_init_string (&label, lex_tokstr (lexer));
406 ds_truncate (&label, 255);
407 dest_label[n_dest - 1] = ds_xstrdup (&label);
413 /* Get the name of the aggregation function. */
414 if (lex_token (lexer) != T_ID)
416 lex_error (lexer, _("expecting aggregation function"));
/* The name is looked up with any trailing `.' chomped off; the raw
   identifier is then checked for that `.' (the statement guarded by the
   check is not visible here). */
422 ds_assign_string (&function_name, lex_tokstr (lexer));
424 ds_chomp (&function_name, '.');
426 if (lex_tokid(lexer)[strlen (lex_tokid (lexer)) - 1] == '.')
/* Case-insensitive linear search of agr_func_tab, ending at the entry
   whose name is NULL. */
429 for (function = agr_func_tab; function->name; function++)
430 if (!strcasecmp (function->name, ds_cstr (&function_name)))
432 if (NULL == function->name)
434 msg (SE, _("Unknown aggregation function %s."),
435 ds_cstr (&function_name));
438 ds_destroy (&function_name);
/* The table position of the match is the function number. */
439 func_index = function - agr_func_tab;
442 /* Check for leading lparen. */
/* Without `(' only the no-variable forms are acceptable: the function
   number is switched to N_NO_VARS or NU_NO_VARS, and anything else is an
   error. */
443 if (!lex_match (lexer, '('))
446 func_index = N_NO_VARS;
447 else if (func_index == NU)
448 func_index = NU_NO_VARS;
451 lex_error (lexer, _("expecting `('"));
457 /* Parse list of source variables. */
/* SUM, MEAN and SD accept only numeric sources; functions that take value
   arguments require all sources to share one type. */
459 int pv_opts = PV_NO_SCRATCH;
461 if (func_index == SUM || func_index == MEAN || func_index == SD)
462 pv_opts |= PV_NUMERIC;
463 else if (function->n_args)
464 pv_opts |= PV_SAME_TYPE;
466 if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
470 /* Parse function arguments, for those functions that
471 require arguments. */
/* Each argument is either a string token, which is duplicated into
   arg[i].c, or a number stored in arg[i].f; a comma before it is
   optional. */
472 if (function->n_args != 0)
473 for (i = 0; i < function->n_args; i++)
477 lex_match (lexer, ',');
478 if (lex_token (lexer) == T_STRING)
480 arg[i].c = ds_xstrdup (lex_tokstr (lexer));
483 else if (lex_is_number (lexer))
485 arg[i].f = lex_tokval (lexer);
490 msg (SE, _("Missing argument %zu to %s."),
491 i + 1, function->name);
497 if (type != var_get_type (src[0]))
499 msg (SE, _("Arguments to %s must be of same type as "
500 "source variables."),
506 /* Trailing rparen. */
507 if (!lex_match (lexer, ')'))
509 lex_error (lexer, _("expecting `)'"));
513 /* Now check that the number of source variables match
514 the number of target variables. If we check earlier
515 than this, the user can get very misleading error
516 message, i.e. `AGGREGATE x=SUM(y t).' will get this
517 error message when a proper message would be more
518 like `unknown variable t'. */
521 msg (SE, _("Number of source variables (%zu) does not match "
522 "number of target variables (%zu)."),
/* For the two-argument range functions, arguments given in descending
   order are detected here (numerically, or with str_compare_rpad for
   strings) and a warning is issued. */
527 if ((func_index == PIN || func_index == POUT
528 || func_index == FIN || func_index == FOUT)
529 && (var_is_numeric (src[0])
530 ? arg[0].f > arg[1].f
531 : str_compare_rpad (arg[0].c, arg[1].c) > 0))
533 union agr_argument t = arg[0];
537 msg (SW, _("The value arguments passed to the %s function "
538 "are out-of-order. They will be treated as if "
539 "they had been specified in the correct order."),
544 /* Finally add these to the linked list of aggregation
546 for (i = 0; i < n_dest; i++)
548 struct agr_var *v = xmalloc (sizeof *v);
550 /* Add variable to chain. */
551 if (agr->agr_vars != NULL)
559 /* Create the target variable in the aggregate
562 struct variable *destvar;
564 v->function = func_index;
570 if (var_is_alpha (src[i]))
572 v->function |= FSTRING;
573 v->string = xmalloc (var_get_width (src[i]));
576 if (function->alpha_type == VAL_STRING)
577 destvar = dict_clone_var (agr->dict, v->src, dest[i]);
580 assert (var_is_numeric (v->src)
581 || function->alpha_type == VAL_NUMERIC);
582 destvar = dict_create_var (agr->dict, dest[i], 0);
586 if ((func_index == N || func_index == NMISS)
587 && dict_get_weight (dict) != NULL)
588 f = fmt_for_output (FMT_F, 8, 2);
590 f = function->format;
591 var_set_both_formats (destvar, &f);
597 destvar = dict_create_var (agr->dict, dest[i], 0);
598 if (func_index == N_NO_VARS && dict_get_weight (dict) != NULL)
599 f = fmt_for_output (FMT_F, 8, 2);
601 f = function->format;
602 var_set_both_formats (destvar, &f);
607 msg (SE, _("Variable name %s is not unique within the "
608 "aggregate file dictionary, which contains "
609 "the aggregate variables and the break "
617 var_set_label (destvar, dest_label[i]);
622 v->exclude = exclude;
628 if (var_is_numeric (v->src))
629 for (j = 0; j < function->n_args; j++)
630 v->arg[j].f = arg[j].f;
632 for (j = 0; j < function->n_args; j++)
633 v->arg[j].c = xstrdup (arg[j].c);
637 if (src != NULL && var_is_alpha (src[0]))
638 for (i = 0; i < function->n_args; i++)
648 if (!lex_match (lexer, '/'))
650 if (lex_token (lexer) == '.')
653 lex_error (lexer, "expecting end of command");
659 ds_destroy (&function_name);
660 for (i = 0; i < n_dest; i++)
663 free (dest_label[i]);
669 if (src && n_src && var_is_alpha (src[0]))
670 for (i = 0; i < function->n_args; i++)
/* Releases what AGR owns: the sort criteria, the break variable array,
   the saved break case, per-variable string arguments or moments, and the
   aggregate dictionary if one is set. */
683 agr_destroy (struct agr_proc *agr)
685 struct agr_var *iter, *next;
687 case_ordering_destroy (agr->sort);
688 free (agr->break_vars);
689 case_destroy (&agr->break_case);
/* NEXT is a separate cursor so that the walk does not depend on ITER
   after ITER has been dealt with. */
690 for (iter = agr->agr_vars; iter; iter = next)
/* String-typed functions own duplicated string arguments; the count of
   arguments comes from the function's table entry. */
694 if (iter->function & FSTRING)
699 n_args = agr_func_tab[iter->function & FUNC].n_args;
700 for (i = 0; i < n_args; i++)
701 free (iter->arg[i].c);
/* Only SD variables carry a moments accumulator. */
704 else if (iter->function == SD)
705 moments1_destroy (iter->moments);
708 if (agr->dict != NULL)
709 dict_destroy (agr->dict);
714 /* Accumulates aggregation data from the case INPUT. */
/* For every aggregate variable in AGR, folds INPUT's value of the source
   variable into that variable's running state (dbl[], string, moments),
   scaled by the case weight taken from AGR's source dictionary.  Most
   `case' labels are not visible in this listing, so the per-statement
   notes below describe only what each statement itself does. */
716 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
718 struct agr_var *iter;
720 bool bad_warn = true;
722 weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
724 for (iter = agr->agr_vars; iter; iter = iter->next)
727 const union value *v = case_data (input, iter->src);
728 int src_width = var_get_width (iter->src);
/* Value is missing in one of the excluded classes: update the
   missing-value counters and flag the variable via saw_missing. */
730 if (var_is_value_missing (iter->src, v, iter->exclude))
732 switch (iter->function)
735 case NMISS | FSTRING:
736 iter->dbl[0] += weight;
739 case NUMISS | FSTRING:
743 iter->saw_missing = true;
747 /* This is horrible. There are too many possibilities. */
748 switch (iter->function)
/* Weighted sum. */
751 iter->dbl[0] += v->f * weight;
/* Weighted sum in dbl[0] and total weight in dbl[1]. */
755 iter->dbl[0] += v->f * weight;
756 iter->dbl[1] += weight;
759 moments1_add (iter->moments, v->f, weight);
/* Running maximum / minimum; the string forms compare and copy exactly
   SRC_WIDTH bytes. */
762 iter->dbl[0] = MAX (iter->dbl[0], v->f);
766 if (memcmp (iter->string, v->s, src_width) < 0)
767 memcpy (iter->string, v->s, src_width);
771 iter->dbl[0] = MIN (iter->dbl[0], v->f);
775 if (memcmp (iter->string, v->s, src_width) > 0)
776 memcpy (iter->string, v->s, src_width);
/* Threshold and range tests: dbl[0] collects the weight of cases that
   satisfy the test and dbl[1] the weight of all cases tested.
   NOTE(review): the string forms memcmp SRC_WIDTH bytes of arg[].c.  The
   parser code visible in this file duplicates the argument from the token
   without padding it to the variable's width, so a shorter argument would
   be read past its end -- confirm against the full source. */
781 if (v->f > iter->arg[0].f)
782 iter->dbl[0] += weight;
783 iter->dbl[1] += weight;
787 if (memcmp (iter->arg[0].c, v->s, src_width) < 0)
788 iter->dbl[0] += weight;
789 iter->dbl[1] += weight;
793 if (v->f < iter->arg[0].f)
794 iter->dbl[0] += weight;
795 iter->dbl[1] += weight;
799 if (memcmp (iter->arg[0].c, v->s, src_width) > 0)
800 iter->dbl[0] += weight;
801 iter->dbl[1] += weight;
805 if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
806 iter->dbl[0] += weight;
807 iter->dbl[1] += weight;
811 if (memcmp (iter->arg[0].c, v->s, src_width) <= 0
812 && memcmp (iter->arg[1].c, v->s, src_width) >= 0)
813 iter->dbl[0] += weight;
814 iter->dbl[1] += weight;
818 if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
819 iter->dbl[0] += weight;
820 iter->dbl[1] += weight;
824 if (memcmp (iter->arg[0].c, v->s, src_width) > 0
825 || memcmp (iter->arg[1].c, v->s, src_width) < 0)
826 iter->dbl[0] += weight;
827 iter->dbl[1] += weight;
831 iter->dbl[0] += weight;
844 case FIRST | FSTRING:
847 memcpy (iter->string, v->s, src_width);
856 memcpy (iter->string, v->s, src_width);
860 case NMISS | FSTRING:
862 case NUMISS | FSTRING:
863 /* Our value is not missing or it would have been
864 caught earlier. Nothing to do. */
870 switch (iter->function)
873 iter->dbl[0] += weight;
884 /* Writes an aggregated record to OUTPUT. */
/* Builds one case sized for AGR's dictionary: the break variables' values
   are copied from AGR's saved break case, each aggregate variable's
   result is computed from its accumulated state, and the case is written
   to OUTPUT.  Most `case' labels are not visible in this listing. */
886 dump_aggregate_info (struct agr_proc *agr, struct casewriter *output)
890 case_create (&c, dict_get_next_value_idx (agr->dict));
/* Break values are packed at consecutive value indexes, in break
   variable order. */
896 for (i = 0; i < agr->break_var_cnt; i++)
898 const struct variable *v = agr->break_vars[i];
899 size_t value_cnt = var_get_value_cnt (v);
900 memcpy (case_data_rw_idx (&c, value_idx),
901 case_data (&agr->break_case, v),
902 sizeof (union value) * value_cnt);
903 value_idx += value_cnt;
910 for (i = agr->agr_vars; i; i = i->next)
912 union value *v = case_data_rw (&c, i->dest);
/* Under COLUMNWISE, a variable that saw a missing value in this group is
   handled specially unless its function is one of the counts N, NU, NMISS
   or NUMISS.  String targets are filled with spaces; the numeric branch
   is not visible here. */
914 if (agr->missing == COLUMNWISE && i->saw_missing
915 && (i->function & FUNC) != N && (i->function & FUNC) != NU
916 && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
918 if (var_is_alpha (i->dest))
919 memset (v->s, ' ', var_get_width (i->dest));
/* int1 presumably records that at least one value was accumulated (its
   assignment is not visible here); when it is zero the result is
   SYSMIS. */
928 v->f = i->int1 ? i->dbl[0] : SYSMIS;
/* dbl[0] divided by dbl[1], or SYSMIS when dbl[1] is zero. */
931 v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
937 /* FIXME: we should use two passes. */
938 moments1_calculate (i->moments, NULL, NULL, &variance,
940 if (variance != SYSMIS)
941 v->f = sqrt (variance);
948 v->f = i->int1 ? i->dbl[0] : SYSMIS;
953 memcpy (v->s, i->string, var_get_width (i->dest));
955 memset (v->s, ' ', var_get_width (i->dest));
/* Satisfying weight over tested weight: first as a plain ratio, then
   multiplied by 100; SYSMIS when nothing was tested. */
965 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
975 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
987 v->f = i->int1 ? i->dbl[0] : SYSMIS;
989 case FIRST | FSTRING:
992 memcpy (v->s, i->string, var_get_width (i->dest));
994 memset (v->s, ' ', var_get_width (i->dest));
1003 case NMISS | FSTRING:
1007 case NUMISS | FSTRING:
1016 casewriter_write (output, &c);
1019 /* Resets the state for all the aggregate functions. */
/* Called by cmd_aggregate with the first case of each group.  INPUT is
   cloned into AGR's break case, replacing the previous one, and every
   aggregate variable's accumulators are cleared. */
1021 initialize_aggregate_info (struct agr_proc *agr, const struct ccase *input)
1023 struct agr_var *iter;
1025 case_destroy (&agr->break_case);
1026 case_clone (&agr->break_case, input);
1028 for (iter = agr->agr_vars; iter; iter = iter->next)
1030 iter->saw_missing = false;
1031 iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1032 iter->int1 = iter->int2 = 0;
1033 switch (iter->function)
/* Extreme-value seeds (the `case' labels are not visible here): DBL_MAX
   and all-255 bytes are the starting points that any smaller value
   replaces, presumably for MIN; -DBL_MAX and all-zero bytes are the
   starting points that any larger value replaces, presumably for MAX. */
1036 iter->dbl[0] = DBL_MAX;
1039 memset (iter->string, 255, var_get_width (iter->src));
1042 iter->dbl[0] = -DBL_MAX;
1045 memset (iter->string, 0, var_get_width (iter->src));
/* The moments accumulator is created on first use with MOMENT_VARIANCE
   and cleared by the other statement (the `else' between them is
   presumably elided). */
1048 if (iter->moments == NULL)
1049 iter->moments = moments1_create (MOMENT_VARIANCE);
1051 moments1_clear (iter->moments);