1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 1997-9, 2000, 2006 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
21 #include <data/any-writer.h>
22 #include <data/case-ordering.h>
23 #include <data/case.h>
24 #include <data/casegrouper.h>
25 #include <data/casereader.h>
26 #include <data/casewriter.h>
27 #include <data/dictionary.h>
28 #include <data/file-handle-def.h>
29 #include <data/format.h>
30 #include <data/procedure.h>
31 #include <data/settings.h>
32 #include <data/sys-file-writer.h>
33 #include <data/variable.h>
34 #include <language/command.h>
35 #include <language/data-io/file-handle.h>
36 #include <language/lexer/lexer.h>
37 #include <language/lexer/variable-parser.h>
38 #include <language/stats/sort-criteria.h>
39 #include <libpspp/alloc.h>
40 #include <libpspp/assertion.h>
41 #include <libpspp/message.h>
42 #include <libpspp/misc.h>
43 #include <libpspp/pool.h>
44 #include <libpspp/str.h>
45 #include <math/moments.h>
46 #include <math/sort.h>
51 #define _(msgid) gettext (msgid)
53 /* Argument for AGGREGATE function.
   Holds one literal argument of an aggregation function.  The parser
   stores a number in F and a string token in C; the consumers pick the
   member according to whether the source variable is numeric. */
56 double f; /* Numeric. */
57 char *c; /* Short or long string; a heap copy that agr_destroy frees. */
60 /* Specifies how to make an aggregate variable.
   One node exists per target variable; the nodes are chained through
   NEXT, starting at agr_proc.agr_vars. */
63 struct agr_var *next; /* Next in list. */
65 /* Collected during parsing. */
66 const struct variable *src; /* Source variable. */
67 struct variable *dest; /* Target variable. */
68 int function; /* Function code, with FSTRING OR'ed in for string sources. */
69 enum mv_class exclude; /* Classes of missing values to exclude. */
70 union agr_argument arg[2]; /* Arguments; the function table uses 0, 1 or 2. */
72 /* Accumulated during AGGREGATE execution.
   NOTE(review): the dbl[], int1/int2, string and saw_missing members used
   by the accumulation code are declared on lines not visible in this
   excerpt. */
77 struct moments1 *moments; /* Created on first reset, destroyed for SD. */
80 /* Aggregation functions.
   The enumerators are listed in the same order as the rows of
   agr_func_tab: NONE through LAST index the named rows, N_AGR_FUNCS lines
   up with the NULL-named sentinel row, and N_NO_VARS / NU_NO_VARS with the
   two rows that follow it. */
83 NONE, SUM, MEAN, SD, MAX, MIN, PGT, PLT, PIN, POUT, FGT, FLT, FIN,
84 FOUT, N, NU, NMISS, NUMISS, FIRST, LAST,
85 N_AGR_FUNCS, N_NO_VARS, NU_NO_VARS,
86 FUNC = 0x1f, /* Function mask: the low five bits hold the function code. */
87 FSTRING = 1<<5, /* String function bit, OR'ed in when the source is a string. */
90 /* Attributes of an aggregation function (one row of agr_func_tab). */
93 const char *name; /* Aggregation function name; NULL in the sentinel row. */
94 size_t n_args; /* Number of arguments. */
95 enum var_type alpha_type; /* Output type when given string sources. */
96 struct fmt_spec format; /* Output format used when the result is numeric. */
99 /* Attributes of aggregation functions.
   Rows appear in the same order as the aggregation function enumerators.
   Lookup by name walks the table until it reaches the row whose name is
   NULL, so the two rows after that sentinel are reachable only by index.
   NOTE(review): alpha_type -1 is not a named enum var_type value in this
   excerpt; presumably it marks functions with no string form -- confirm. */
100 static const struct agr_func agr_func_tab[] =
102 {"<NONE>", 0, -1, {0, 0, 0}},
103 {"SUM", 0, -1, {FMT_F, 8, 2}},
104 {"MEAN", 0, -1, {FMT_F, 8, 2}},
105 {"SD", 0, -1, {FMT_F, 8, 2}},
106 {"MAX", 0, VAR_STRING, {-1, -1, -1}},
107 {"MIN", 0, VAR_STRING, {-1, -1, -1}},
108 {"PGT", 1, VAR_NUMERIC, {FMT_F, 5, 1}},
109 {"PLT", 1, VAR_NUMERIC, {FMT_F, 5, 1}},
110 {"PIN", 2, VAR_NUMERIC, {FMT_F, 5, 1}},
111 {"POUT", 2, VAR_NUMERIC, {FMT_F, 5, 1}},
112 {"FGT", 1, VAR_NUMERIC, {FMT_F, 5, 3}},
113 {"FLT", 1, VAR_NUMERIC, {FMT_F, 5, 3}},
114 {"FIN", 2, VAR_NUMERIC, {FMT_F, 5, 3}},
115 {"FOUT", 2, VAR_NUMERIC, {FMT_F, 5, 3}},
116 {"N", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
117 {"NU", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
118 {"NMISS", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
119 {"NUMISS", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
120 {"FIRST", 0, VAR_STRING, {-1, -1, -1}},
121 {"LAST", 0, VAR_STRING, {-1, -1, -1}},
122 {NULL, 0, -1, {-1, -1, -1}}, /* Sentinel (position N_AGR_FUNCS). */
123 {"N", 0, VAR_NUMERIC, {FMT_F, 7, 0}}, /* Position N_NO_VARS. */
124 {"NU", 0, VAR_NUMERIC, {FMT_F, 7, 0}}, /* Position NU_NO_VARS. */
127 /* Missing value types.
   cmd_aggregate starts with ITEMWISE and switches to COLUMNWISE when the
   MISSING subcommand names COLUMNWISE. */
128 enum missing_treatment
130 ITEMWISE, /* Missing values item by item. */
131 COLUMNWISE /* Missing values column by column. */
134 /* An entire AGGREGATE procedure.
   Filled in by cmd_aggregate while parsing and then shared with the
   initialize/accumulate/dump helpers during execution. */
137 /* Break variables. */
138 struct case_ordering *sort; /* Sort criteria. */
139 const struct variable **break_vars; /* Break variables. */
140 size_t break_var_cnt; /* Number of break variables. */
141 struct ccase break_case; /* Last values of break variables. */
143 enum missing_treatment missing; /* How to treat missing values. */
144 struct agr_var *agr_vars; /* First aggregate variable. */
145 struct dictionary *dict; /* Aggregate dictionary. */
146 const struct dictionary *src_dict; /* Dictionary of the source data. */
147 int case_cnt; /* Counts aggregated cases. */
/* Forward declarations for the helpers defined later in this file.
   NOTE(review): the continuation line of the parse_aggregate_functions
   prototype is not visible in this excerpt. */
150 static void initialize_aggregate_info (struct agr_proc *,
151 const struct ccase *);
152 static void accumulate_aggregate_info (struct agr_proc *,
153 const struct ccase *);
155 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
157 static void agr_destroy (struct agr_proc *);
158 static void dump_aggregate_info (struct agr_proc *agr,
159 struct casewriter *output);
163 /* Parses and executes the AGGREGATE procedure.
   Reads the OUTFILE, option, BREAK and aggregate-function subcommands
   from LEXER, then reads DS's cases grouped by the break variables and
   writes one aggregated record per group, either to the named output
   file or to a writer that becomes the new active file.
   The visible failure path returns CMD_CASCADING_FAILURE; the success
   return is on a line not visible in this excerpt. */
165 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
167 struct dictionary *dict = dataset_dict (ds);
169 struct file_handle *out_file = NULL;
170 struct casereader *input = NULL, *group;
171 struct casegrouper *grouper;
172 struct casewriter *output = NULL;
174 bool copy_documents = false;
175 bool presorted = false;
/* Start from an all-zero procedure with item-wise missing handling, an
   empty break case and a fresh aggregate dictionary. */
179 memset(&agr, 0 , sizeof (agr));
180 agr.missing = ITEMWISE;
181 case_nullify (&agr.break_case);
183 agr.dict = dict_create ();
/* The aggregate dictionary inherits the source label and documents. */
185 dict_set_label (agr.dict, dict_get_label (dict));
186 dict_set_documents (agr.dict, dict_get_documents (dict));
188 /* OUTFILE subcommand must be first. */
189 if (!lex_force_match_id (lexer, "OUTFILE"))
191 lex_match (lexer, '=');
/* "*" leaves out_file NULL, which later selects replacing the active
   file; anything else must parse as a file handle. */
192 if (!lex_match (lexer, '*'))
194 out_file = fh_parse (lexer, FH_REF_FILE | FH_REF_SCRATCH);
195 if (out_file == NULL)
199 /* Read most of the subcommands. */
202 lex_match (lexer, '/');
204 if (lex_match_id (lexer, "MISSING"))
206 lex_match (lexer, '=');
207 if (!lex_match_id (lexer, "COLUMNWISE"))
209 lex_error (lexer, _("while expecting COLUMNWISE"));
212 agr.missing = COLUMNWISE;
214 else if (lex_match_id (lexer, "DOCUMENT"))
215 copy_documents = true;
216 else if (lex_match_id (lexer, "PRESORTED"))
218 else if (lex_match_id (lexer, "BREAK"))
222 lex_match (lexer, '=');
223 agr.sort = parse_case_ordering (lexer, dict,
226 if (agr.sort == NULL)
228 case_ordering_get_vars (agr.sort,
229 &agr.break_vars, &agr.break_var_cnt);
/* Each break variable is cloned into the aggregate dictionary under
   its own name. */
231 for (i = 0; i < agr.break_var_cnt; i++)
232 dict_clone_var_assert (agr.dict, agr.break_vars[i],
233 var_get_name (agr.break_vars[i]));
235 /* BREAK must follow the options. */
240 lex_error (lexer, _("expecting BREAK"));
244 if (presorted && saw_direction)
245 msg (SW, _("When PRESORTED is specified, specifying sorting directions "
246 "with (A) or (D) has no effect. Output data will be sorted "
247 "the same way as the input data."));
249 /* Read in the aggregate functions. */
250 lex_match (lexer, '/');
251 if (!parse_aggregate_functions (lexer, dict, &agr))
254 /* Delete documents. */
256 dict_clear_documents (agr.dict);
258 /* Cancel SPLIT FILE. */
259 dict_set_split_vars (agr.dict, NULL, 0);
/* Choose the destination: an in-memory writer when the active file is
   being replaced, otherwise a writer for the requested output file. */
264 if (out_file == NULL)
266 /* The active file will be replaced by the aggregated data,
267 so TEMPORARY is moot. */
268 proc_cancel_temporary_transformations (ds);
269 proc_discard_output (ds);
270 output = autopaging_writer_create (dict_get_next_value_idx (agr.dict));
274 output = any_writer_open (out_file, agr.dict);
/* Open the source data and sort it by the break criteria unless the
   user declared it PRESORTED. */
279 input = proc_open (ds);
280 if (agr.sort != NULL && !presorted)
282 input = sort_execute (input, agr.sort);
/* One iteration per run of cases grouped on the break variables. */
286 for (grouper = casegrouper_create_vars (input, agr.break_vars,
288 casegrouper_get_next_group (grouper, &group);
289 casereader_destroy (group))
/* The first case of the group is peeked, not consumed, to reset the
   accumulators and capture the break values. */
293 if (!casereader_peek (group, 0, &c))
295 initialize_aggregate_info (&agr, &c);
/* Feed every case of the group, then emit the group's record. */
298 for (; casereader_read (group, &c); case_destroy (&c))
299 accumulate_aggregate_info (&agr, &c);
300 dump_aggregate_info (&agr, output);
302 if (!casegrouper_destroy (grouper))
305 if (!proc_commit (ds))
/* When replacing the active file, the written cases are turned into a
   reader and installed together with the aggregate dictionary. */
312 if (out_file == NULL)
314 struct casereader *next_input = casewriter_make_reader (output);
315 if (next_input == NULL)
318 proc_set_active_file (ds, next_input, agr.dict);
323 ok = casewriter_destroy (output);
/* Failure path: drop the writer and report a cascading failure. */
335 casewriter_destroy (output);
337 return CMD_CASCADING_FAILURE;
340 /* Parse all the aggregate functions.
   For each specification read from LEXER this collects the target
   variable names (with optional labels), the function name, the source
   variables from DICT and any literal arguments, then appends one
   agr_var per target to AGR and creates the target variable in
   AGR->dict.
   NOTE(review): the "expecting end of command" message near the end of
   this function is passed to lex_error without the _() wrapper that the
   other messages here use.
   NOTE(review): many lines of this function (error exits, returns) are
   not visible in this excerpt. */
342 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict, struct agr_proc *agr)
344 struct agr_var *tail; /* Tail of linked list starting at agr->agr_vars. */
346 /* Parse everything. */
353 struct string function_name;
355 enum mv_class exclude;
356 const struct agr_func *function;
359 union agr_argument arg[2];
361 const struct variable **src;
374 ds_init_empty (&function_name);
376 /* Parse the list of target variables. */
377 while (!lex_match (lexer, '='))
379 size_t n_dest_prev = n_dest;
381 if (!parse_DATA_LIST_vars (lexer, &dest, &n_dest,
382 PV_APPEND | PV_SINGLE | PV_NO_SCRATCH))
385 /* Assign empty labels. */
389 dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
390 for (j = n_dest_prev; j < n_dest; j++)
391 dest_label[j] = NULL;
/* A string token here is a label for the most recently parsed target;
   it is cut to at most 255 bytes before being stored. */
396 if (lex_token (lexer) == T_STRING)
399 ds_init_string (&label, lex_tokstr (lexer));
401 ds_truncate (&label, 255);
402 dest_label[n_dest - 1] = ds_xstrdup (&label);
408 /* Get the name of the aggregation function. */
409 if (lex_token (lexer) != T_ID)
411 lex_error (lexer, _("expecting aggregation function"));
/* A trailing '.' is stripped from the name before the table lookup; what
   happens when the '.' is present is on lines not visible here. */
417 ds_assign_string (&function_name, lex_tokstr (lexer));
419 ds_chomp (&function_name, '.');
421 if (lex_tokid(lexer)[strlen (lex_tokid (lexer)) - 1] == '.')
/* Case-insensitive search of the table up to its NULL-named sentinel. */
424 for (function = agr_func_tab; function->name; function++)
425 if (!strcasecmp (function->name, ds_cstr (&function_name)))
427 if (NULL == function->name)
429 msg (SE, _("Unknown aggregation function %s."),
430 ds_cstr (&function_name));
433 ds_destroy (&function_name);
434 func_index = function - agr_func_tab;
437 /* Check for leading lparen. */
438 if (!lex_match (lexer, '('))
/* Without a parenthesis the function index is remapped to its
   no-variable counterpart; otherwise a '(' is reported as expected. */
441 func_index = N_NO_VARS;
442 else if (func_index == NU)
443 func_index = NU_NO_VARS;
446 lex_error (lexer, _("expecting `('"));
452 /* Parse list of source variables. */
454 int pv_opts = PV_NO_SCRATCH;
/* SUM, MEAN and SD accept numeric sources only; functions that take
   arguments require all sources to share one type. */
456 if (func_index == SUM || func_index == MEAN || func_index == SD)
457 pv_opts |= PV_NUMERIC;
458 else if (function->n_args)
459 pv_opts |= PV_SAME_TYPE;
461 if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
465 /* Parse function arguments, for those functions that
466 require arguments. */
467 if (function->n_args != 0)
468 for (i = 0; i < function->n_args; i++)
472 lex_match (lexer, ',');
473 if (lex_token (lexer) == T_STRING)
475 arg[i].c = ds_xstrdup (lex_tokstr (lexer));
478 else if (lex_is_number (lexer))
480 arg[i].f = lex_tokval (lexer);
485 msg (SE, _("Missing argument %d to %s."),
486 (int) i + 1, function->name);
492 if (type != var_get_type (src[0]))
494 msg (SE, _("Arguments to %s must be of same type as "
495 "source variables."),
501 /* Trailing rparen. */
502 if (!lex_match (lexer, ')'))
504 lex_error (lexer, _("expecting `)'"));
508 /* Now check that the number of source variables match
509 the number of target variables. If we check earlier
510 than this, the user can get very misleading error
511 message, i.e. `AGGREGATE x=SUM(y t).' will get this
512 error message when a proper message would be more
513 like `unknown variable t'. */
516 msg (SE, _("Number of source variables (%u) does not match "
517 "number of target variables (%u)."),
518 (unsigned) n_src, (unsigned) n_dest);
/* For the two-argument range functions, detect arguments given in
   descending order (numeric or padded string comparison) and warn. */
522 if ((func_index == PIN || func_index == POUT
523 || func_index == FIN || func_index == FOUT)
524 && (var_is_numeric (src[0])
525 ? arg[0].f > arg[1].f
526 : str_compare_rpad (arg[0].c, arg[1].c) > 0))
528 union agr_argument t = arg[0];
532 msg (SW, _("The value arguments passed to the %s function "
533 "are out-of-order. They will be treated as if "
534 "they had been specified in the correct order."),
539 /* Finally add these to the linked list of aggregation
541 for (i = 0; i < n_dest; i++)
543 struct agr_var *v = xmalloc (sizeof *v);
545 /* Add variable to chain. */
546 if (agr->agr_vars != NULL)
554 /* Create the target variable in the aggregate
557 struct variable *destvar;
559 v->function = func_index;
565 if (var_is_alpha (src[i]))
567 v->function |= FSTRING;
568 v->string = xmalloc (var_get_width (src[i]));
571 if (function->alpha_type == VAR_STRING)
572 destvar = dict_clone_var (agr->dict, v->src, dest[i]);
575 assert (var_is_numeric (v->src)
576 || function->alpha_type == VAR_NUMERIC);
577 destvar = dict_create_var (agr->dict, dest[i], 0);
581 if ((func_index == N || func_index == NMISS)
582 && dict_get_weight (dict) != NULL)
583 f = fmt_for_output (FMT_F, 8, 2);
585 f = function->format;
586 var_set_both_formats (destvar, &f);
592 destvar = dict_create_var (agr->dict, dest[i], 0);
593 if (func_index == N_NO_VARS && dict_get_weight (dict) != NULL)
594 f = fmt_for_output (FMT_F, 8, 2);
596 f = function->format;
597 var_set_both_formats (destvar, &f);
602 msg (SE, _("Variable name %s is not unique within the "
603 "aggregate file dictionary, which contains "
604 "the aggregate variables and the break "
612 var_set_label (destvar, dest_label[i]);
617 v->exclude = exclude;
623 if (var_is_numeric (v->src))
624 for (j = 0; j < function->n_args; j++)
625 v->arg[j].f = arg[j].f;
627 for (j = 0; j < function->n_args; j++)
628 v->arg[j].c = xstrdup (arg[j].c);
632 if (src != NULL && var_is_alpha (src[0]))
633 for (i = 0; i < function->n_args; i++)
643 if (!lex_match (lexer, '/'))
645 if (lex_token (lexer) == '.')
648 lex_error (lexer, "expecting end of command");
654 ds_destroy (&function_name);
655 for (i = 0; i < n_dest; i++)
658 free (dest_label[i]);
664 if (src && n_src && var_is_alpha (src[0]))
665 for (i = 0; i < function->n_args; i++)
/* Releases the resources held by AGR: the sort criteria, the break
   variable array, the break case, per-variable data for every node of
   the agr_vars list, and finally the aggregate dictionary if one exists.
   NOTE(review): the lines that release each list node itself are not
   visible in this excerpt. */
678 agr_destroy (struct agr_proc *agr)
680 struct agr_var *iter, *next;
682 case_ordering_destroy (agr->sort);
683 free (agr->break_vars);
684 case_destroy (&agr->break_case);
685 for (iter = agr->agr_vars; iter; iter = next)
/* String functions own heap copies of their arguments; the function
   table supplies how many there are. */
689 if (iter->function & FSTRING)
694 n_args = agr_func_tab[iter->function & FUNC].n_args;
695 for (i = 0; i < n_args; i++)
696 free (iter->arg[i].c);
/* Numeric SD owns a moments accumulator instead. */
699 else if (iter->function == SD)
700 moments1_destroy (iter->moments);
703 if (agr->dict != NULL)
704 dict_destroy (agr->dict);
709 /* Accumulates aggregation data from the case INPUT.
   The case weight is taken from AGR->src_dict.  For every aggregate
   variable the source value is first tested against the variable's
   excluded missing-value classes; otherwise the value is folded into the
   accumulators according to the function code.
   NOTE(review): most case labels of the switches are not visible in this
   excerpt, so the per-statement comments below avoid naming functions. */
711 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
713 struct agr_var *iter;
715 bool bad_warn = true;
717 weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
719 for (iter = agr->agr_vars; iter; iter = iter->next)
722 const union value *v = case_data (input, iter->src);
723 int src_width = var_get_width (iter->src);
/* Missing source value: the missing-count accumulator grows by the
   weight and the variable is flagged as having seen a missing value. */
725 if (var_is_value_missing (iter->src, v, iter->exclude))
727 switch (iter->function)
730 case NMISS | FSTRING:
731 iter->dbl[0] += weight;
734 case NUMISS | FSTRING:
738 iter->saw_missing = true;
742 /* This is horrible. There are too many possibilities. */
743 switch (iter->function)
/* Weighted sum. */
746 iter->dbl[0] += v->f * weight;
/* Weighted sum plus total weight. */
750 iter->dbl[0] += v->f * weight;
751 iter->dbl[1] += weight;
754 moments1_add (iter->moments, v->f, weight);
/* Running numeric maximum, then its string counterpart, which keeps the
   bytewise greater value. */
757 iter->dbl[0] = MAX (iter->dbl[0], v->f);
761 if (memcmp (iter->string, v->s, src_width) < 0)
762 memcpy (iter->string, v->s, src_width);
/* Running numeric minimum, then its string counterpart. */
766 iter->dbl[0] = MIN (iter->dbl[0], v->f);
770 if (memcmp (iter->string, v->s, src_width) > 0)
771 memcpy (iter->string, v->s, src_width);
/* Threshold and range tests: dbl[0] collects the weight of the cases
   that satisfy the test and dbl[1] the weight of all cases. */
776 if (v->f > iter->arg[0].f)
777 iter->dbl[0] += weight;
778 iter->dbl[1] += weight;
/* NOTE(review): the string tests below read src_width bytes from
   arg[].c, which the parser creates by plain string duplication; confirm
   those strings are at least src_width bytes long. */
782 if (memcmp (iter->arg[0].c, v->s, src_width) < 0)
783 iter->dbl[0] += weight;
784 iter->dbl[1] += weight;
788 if (v->f < iter->arg[0].f)
789 iter->dbl[0] += weight;
790 iter->dbl[1] += weight;
794 if (memcmp (iter->arg[0].c, v->s, src_width) > 0)
795 iter->dbl[0] += weight;
796 iter->dbl[1] += weight;
800 if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
801 iter->dbl[0] += weight;
802 iter->dbl[1] += weight;
806 if (memcmp (iter->arg[0].c, v->s, src_width) <= 0
807 && memcmp (iter->arg[1].c, v->s, src_width) >= 0)
808 iter->dbl[0] += weight;
809 iter->dbl[1] += weight;
813 if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
814 iter->dbl[0] += weight;
815 iter->dbl[1] += weight;
819 if (memcmp (iter->arg[0].c, v->s, src_width) > 0
820 || memcmp (iter->arg[1].c, v->s, src_width) < 0)
821 iter->dbl[0] += weight;
822 iter->dbl[1] += weight;
826 iter->dbl[0] += weight;
839 case FIRST | FSTRING:
842 memcpy (iter->string, v->s, src_width);
851 memcpy (iter->string, v->s, src_width);
855 case NMISS | FSTRING:
857 case NUMISS | FSTRING:
858 /* Our value is not missing or it would have been
859 caught earlier. Nothing to do. */
/* Second switch: adds the case weight to a counter.  NOTE(review): its
   enclosing condition and labels are not visible in this excerpt. */
865 switch (iter->function)
868 iter->dbl[0] += weight;
879 /* Writes an aggregated record to OUTPUT.
   Builds a case sized for AGR->dict, copies the break values saved in
   AGR->break_case into it, fills in one value per aggregate variable from
   the accumulators, and hands the case to the writer.
   NOTE(review): most case labels of the switch are not visible in this
   excerpt. */
881 dump_aggregate_info (struct agr_proc *agr, struct casewriter *output)
885 case_create (&c, dict_get_next_value_idx (agr->dict));
/* Break variables are laid out consecutively, value_cnt values each. */
891 for (i = 0; i < agr->break_var_cnt; i++)
893 const struct variable *v = agr->break_vars[i];
894 size_t value_cnt = var_get_value_cnt (v);
895 memcpy (case_data_rw_idx (&c, value_idx),
896 case_data (&agr->break_case, v),
897 sizeof (union value) * value_cnt);
898 value_idx += value_cnt;
905 for (i = agr->agr_vars; i; i = i->next)
907 union value *v = case_data_rw (&c, i->dest);
/* Column-wise missing handling: if any source value was missing, every
   function except the four counting ones gets a blank string result
   here (the numeric branch is not visible in this excerpt). */
909 if (agr->missing == COLUMNWISE && i->saw_missing
910 && (i->function & FUNC) != N && (i->function & FUNC) != NU
911 && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
913 if (var_is_alpha (i->dest))
914 memset (v->s, ' ', var_get_width (i->dest));
/* int1 acts as a "result available" flag; otherwise system-missing. */
923 v->f = i->int1 ? i->dbl[0] : SYSMIS;
/* Ratio of dbl[0] to dbl[1], system-missing when dbl[1] is zero. */
926 v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
932 /* FIXME: we should use two passes. */
933 moments1_calculate (i->moments, NULL, NULL, &variance,
/* The standard deviation is the square root of the variance when the
   variance is available. */
935 if (variance != SYSMIS)
936 v->f = sqrt (variance);
943 v->f = i->int1 ? i->dbl[0] : SYSMIS;
948 memcpy (v->s, i->string, var_get_width (i->dest));
950 memset (v->s, ' ', var_get_width (i->dest));
/* Fraction form, then the same ratio scaled by 100. */
960 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
970 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
982 v->f = i->int1 ? i->dbl[0] : SYSMIS;
984 case FIRST | FSTRING:
987 memcpy (v->s, i->string, var_get_width (i->dest));
989 memset (v->s, ' ', var_get_width (i->dest));
998 case NMISS | FSTRING:
1002 case NUMISS | FSTRING:
1011 casewriter_write (output, &c);
1014 /* Resets the state for all the aggregate functions.
   Called with the first case of a break group: INPUT replaces the saved
   break case, and every aggregate variable's accumulators are cleared
   and then given a function-specific starting value.
   NOTE(review): the case labels of the switch are not visible in this
   excerpt. */
1016 initialize_aggregate_info (struct agr_proc *agr, const struct ccase *input)
1018 struct agr_var *iter;
1020 case_destroy (&agr->break_case);
1021 case_clone (&agr->break_case, input);
1023 for (iter = agr->agr_vars; iter; iter = iter->next)
1025 iter->saw_missing = false;
1026 iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1027 iter->int1 = iter->int2 = 0;
1028 switch (iter->function)
/* Starting values: the largest double and all-0xff bytes, then the most
   negative double and all-zero bytes (presumably the seeds for the
   running minimum and maximum respectively -- confirm against the
   hidden labels). */
1031 iter->dbl[0] = DBL_MAX;
1034 memset (iter->string, 255, var_get_width (iter->src));
1037 iter->dbl[0] = -DBL_MAX;
1040 memset (iter->string, 0, var_get_width (iter->src));
/* The moments accumulator is created on first use and cleared on later
   resets. */
1043 if (iter->moments == NULL)
1044 iter->moments = moments1_create (MOMENT_VARIANCE);
1046 moments1_clear (iter->moments);