1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 1997-9, 2000, 2006 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
21 #include <data/any-writer.h>
22 #include <data/case-ordering.h>
23 #include <data/case.h>
24 #include <data/casegrouper.h>
25 #include <data/casereader.h>
26 #include <data/casewriter.h>
27 #include <data/dictionary.h>
28 #include <data/file-handle-def.h>
29 #include <data/format.h>
30 #include <data/procedure.h>
31 #include <data/settings.h>
32 #include <data/sys-file-writer.h>
33 #include <data/variable.h>
34 #include <language/command.h>
35 #include <language/data-io/file-handle.h>
36 #include <language/lexer/lexer.h>
37 #include <language/lexer/variable-parser.h>
38 #include <language/stats/sort-criteria.h>
39 #include <libpspp/alloc.h>
40 #include <libpspp/assertion.h>
41 #include <libpspp/message.h>
42 #include <libpspp/misc.h>
43 #include <libpspp/pool.h>
44 #include <libpspp/str.h>
45 #include <math/moments.h>
46 #include <math/sort.h>
/* Translation shorthand: expands to a gettext call on MSGID. */
51 #define _(msgid) gettext (msgid)
53 /* Argument for AGGREGATE function. */
/* NOTE(review): the declaration line that encloses these members is not
   visible in this listing.  The arg field of the aggregate variable
   record has type union agr_argument, so these are presumably the
   members of that union -- confirm against the full file. */
/* Numeric form of an argument; the numeric comparisons in
   accumulate_aggregate_info read arg[..].f. */
56 double f; /* Numeric. */
/* String form of an argument; the string comparisons in
   accumulate_aggregate_info pass arg[..].c to memcmp and agr_destroy
   frees it. */
57 char *c; /* Short or long string. */
60 /* Specifies how to make an aggregate variable. */
/* NOTE(review): only part of this record is visible in this listing.
   The functions in this file also use fields named string, saw_missing,
   dbl[0..2], int1 and int2 on the same record; their declarations are
   not shown here. */
63 struct agr_var *next; /* Next in list. */
65 /* Collected during parsing. */
66 const struct variable *src; /* Source variable. */
67 struct variable *dest; /* Target variable. */
/* Function code; parse_aggregate_functions ORs in FSTRING when the
   source variable is alphanumeric, and readers mask with FUNC to get
   the table index. */
68 int function; /* Function. */
69 enum mv_class exclude; /* Classes of missing values to exclude. */
70 union agr_argument arg[2]; /* Arguments. */
72 /* Accumulated during AGGREGATE execution. */
/* Running moments; initialize_aggregate_info creates or clears them,
   accumulate_aggregate_info feeds them with moments1_add, and
   agr_destroy releases them when the function is SD. */
77 struct moments1 *moments;
80 /* Aggregation functions. */
/* These codes are used as indexes into agr_func_tab:
   parse_aggregate_functions derives the code from the position of the
   matching table row, and agr_destroy indexes the table with
   function & FUNC. */
83 NONE, SUM, MEAN, SD, MAX, MIN, PGT, PLT, PIN, POUT, FGT, FLT, FIN,
84 FOUT, N, NU, NMISS, NUMISS, FIRST, LAST,
/* N_AGR_FUNCS follows the last ordinary function.  N_NO_VARS and
   NU_NO_VARS are the codes parse_aggregate_functions assigns when the
   function name is not followed by an open parenthesis. */
85 N_AGR_FUNCS, N_NO_VARS, NU_NO_VARS,
86 FUNC = 0x1f, /* Function mask. */
87 FSTRING = 1<<5, /* String function bit. */
90 /* Attributes of an aggregation function. */
/* One record of this shape describes each row of agr_func_tab. */
/* A null name marks the row at which the lookup loop in
   parse_aggregate_functions stops. */
93 const char *name; /* Aggregation function name. */
/* Number of extra arguments parsed after the source variable list. */
94 size_t n_args; /* Number of arguments. */
/* parse_aggregate_functions clones the source variable for the target
   when this is VAR_STRING, and otherwise creates a fresh variable. */
95 enum var_type alpha_type; /* When given ALPHA arguments, output type. */
/* Output format applied to the target variable with
   var_set_both_formats. */
96 struct fmt_spec format; /* Format spec if alpha_type != ALPHA. */
99 /* Attributes of aggregation functions. */
/* Columns: name, number of arguments, alpha_type, format.  The rows
   are listed in the same order as the aggregation function codes
   (NONE, SUM, ... LAST), so a row position doubles as a function
   code. */
100 static const struct agr_func agr_func_tab[] =
/* NOTE(review): several rows store -1 in the alpha_type column, which
   is declared as enum var_type; the listing does not show which
   enumerator, if any, that corresponds to. */
102 {"<NONE>", 0, -1, {0, 0, 0}},
103 {"SUM", 0, -1, {FMT_F, 8, 2}},
104 {"MEAN", 0, -1, {FMT_F, 8, 2}},
105 {"SD", 0, -1, {FMT_F, 8, 2}},
106 {"MAX", 0, VAR_STRING, {-1, -1, -1}},
107 {"MIN", 0, VAR_STRING, {-1, -1, -1}},
108 {"PGT", 1, VAR_NUMERIC, {FMT_F, 5, 1}},
109 {"PLT", 1, VAR_NUMERIC, {FMT_F, 5, 1}},
110 {"PIN", 2, VAR_NUMERIC, {FMT_F, 5, 1}},
111 {"POUT", 2, VAR_NUMERIC, {FMT_F, 5, 1}},
112 {"FGT", 1, VAR_NUMERIC, {FMT_F, 5, 3}},
113 {"FLT", 1, VAR_NUMERIC, {FMT_F, 5, 3}},
114 {"FIN", 2, VAR_NUMERIC, {FMT_F, 5, 3}},
115 {"FOUT", 2, VAR_NUMERIC, {FMT_F, 5, 3}},
116 {"N", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
117 {"NU", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
118 {"NMISS", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
119 {"NUMISS", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
120 {"FIRST", 0, VAR_STRING, {-1, -1, -1}},
121 {"LAST", 0, VAR_STRING, {-1, -1, -1}},
/* Null-name row: sits at the N_AGR_FUNCS position and ends the name
   lookup loop in parse_aggregate_functions. */
122 {NULL, 0, -1, {-1, -1, -1}},
/* The two rows after the null-name row sit at the N_NO_VARS and
   NU_NO_VARS positions; the name lookup never reaches them because it
   stops at the null name. */
123 {"N", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
124 {"NU", 0, VAR_NUMERIC, {FMT_F, 7, 0}},
127 /* Missing value types. */
/* cmd_aggregate starts with ITEMWISE and switches to COLUMNWISE when
   the MISSING subcommand names it; dump_aggregate_info tests for
   COLUMNWISE together with the saw_missing flag of each aggregate
   variable. */
128 enum missing_treatment
130 ITEMWISE, /* Missing values item by item. */
131 COLUMNWISE /* Missing values column by column. */
134 /* An entire AGGREGATE procedure. */
/* cmd_aggregate zero-fills one of these, fills it in while parsing,
   and passes it to the initialize, accumulate and dump helpers. */
137 /* Break variables. */
138 struct case_ordering *sort; /* Sort criteria. */
/* Filled by case_ordering_get_vars from the parsed sort criteria and
   released with free in agr_destroy. */
139 const struct variable **break_vars; /* Break variables. */
140 size_t break_var_cnt; /* Number of break variables. */
/* Re-cloned from the first case of each group by
   initialize_aggregate_info; dump_aggregate_info copies the break
   values out of it. */
141 struct ccase break_case; /* Last values of break variables. */
143 enum missing_treatment missing; /* How to treat missing values. */
/* Head of the singly linked list walked through the next pointers. */
144 struct agr_var *agr_vars; /* First aggregate variable. */
145 struct dictionary *dict; /* Aggregate dictionary. */
/* Passed to dict_get_case_weight when accumulating. */
146 const struct dictionary *src_dict; /* Dict of the source */
147 int case_cnt; /* Counts aggregated cases. */
/* Forward declarations of the file-local helpers used by
   cmd_aggregate. */
150 static void initialize_aggregate_info (struct agr_proc *,
151 const struct ccase *);
152 static void accumulate_aggregate_info (struct agr_proc *,
153 const struct ccase *);
/* NOTE(review): the continuation line of the next prototype is not
   visible in this listing; the definition later in the file takes a
   lexer, a source dictionary and a struct agr_proc pointer. */
155 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
157 static void agr_destroy (struct agr_proc *);
158 static void dump_aggregate_info (struct agr_proc *agr,
159 struct casewriter *output);
163 /* Parses and executes the AGGREGATE procedure. */
/* NOTE(review): this listing omits braces, some declarations and some
   statements of this function.  The comments added below describe only
   the statements that are visible.

   Visible flow: parse OUTFILE, the optional MISSING, DOCUMENT and
   PRESORTED subcommands, BREAK, and the aggregate function list; pick
   an output writer; read the active file group by group, writing one
   aggregated record per group; then either install the result as the
   active file or close the output file.  The visible failure path
   destroys the writer and returns CMD_CASCADING_FAILURE. */
165 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
167 struct dictionary *dict = dataset_dict (ds);
169 struct file_handle *out_file = NULL;
170 struct casereader *input = NULL, *group;
171 struct casegrouper *grouper;
172 struct casewriter *output = NULL;
174 bool copy_documents = false;
175 bool presorted = false;
/* Start from a zero-filled procedure record with item-by-item missing
   value handling and a nullified break case. */
179 memset(&agr, 0 , sizeof (agr));
180 agr.missing = ITEMWISE;
181 case_nullify (&agr.break_case);
/* The aggregate dictionary is created empty, then given the label and
   documents of the active dictionary. */
183 agr.dict = dict_create ();
185 dict_set_label (agr.dict, dict_get_label (dict));
186 dict_set_documents (agr.dict, dict_get_documents (dict));
188 /* OUTFILE subcommand must be first. */
189 if (!lex_force_match_id (lexer, "OUTFILE"))
191 lex_match (lexer, '=');
/* An asterisk skips file handle parsing, leaving out_file NULL; the
   code further down treats a NULL out_file as a request to replace the
   active file. */
192 if (!lex_match (lexer, '*'))
194 out_file = fh_parse (lexer, FH_REF_FILE | FH_REF_SCRATCH);
195 if (out_file == NULL)
199 /* Read most of the subcommands. */
202 lex_match (lexer, '/');
204 if (lex_match_id (lexer, "MISSING"))
206 lex_match (lexer, '=');
/* COLUMNWISE is the only keyword accepted after MISSING. */
207 if (!lex_match_id (lexer, "COLUMNWISE"))
209 lex_error (lexer, _("while expecting COLUMNWISE"));
212 agr.missing = COLUMNWISE;
214 else if (lex_match_id (lexer, "DOCUMENT"))
215 copy_documents = true;
216 else if (lex_match_id (lexer, "PRESORTED"))
218 else if (lex_match_id (lexer, "BREAK"))
222 lex_match (lexer, '=');
/* BREAK: parse the ordering, fetch the break variables from it, and
   clone each of them into the aggregate dictionary under its own
   name. */
223 agr.sort = parse_case_ordering (lexer, dict,
226 if (agr.sort == NULL)
228 case_ordering_get_vars (agr.sort,
229 &agr.break_vars, &agr.break_var_cnt);
231 for (i = 0; i < agr.break_var_cnt; i++)
232 dict_clone_var_assert (agr.dict, agr.break_vars[i],
233 var_get_name (agr.break_vars[i]));
235 /* BREAK must follow the options. */
240 lex_error (lexer, _("expecting BREAK"));
/* Sort directions are only reported as ineffective, not rejected,
   when PRESORTED was given. */
244 if (presorted && saw_direction)
245 msg (SW, _("When PRESORTED is specified, specifying sorting directions "
246 "with (A) or (D) has no effect. Output data will be sorted "
247 "the same way as the input data."));
249 /* Read in the aggregate functions. */
250 lex_match (lexer, '/');
251 if (!parse_aggregate_functions (lexer, dict, &agr))
254 /* Delete documents. */
/* NOTE(review): the condition guarding this call is not visible;
   presumably it depends on copy_documents -- confirm. */
256 dict_clear_documents (agr.dict);
258 /* Cancel SPLIT FILE. */
259 dict_set_split_vars (agr.dict, NULL, 0);
/* Choose the output writer: an autopaging writer sized for the
   aggregate dictionary when replacing the active file, otherwise a
   writer opened on the named file. */
264 if (out_file == NULL)
266 /* The active file will be replaced by the aggregated data,
267 so TEMPORARY is moot. */
268 proc_cancel_temporary_transformations (ds);
269 proc_discard_output (ds);
270 output = autopaging_writer_create (dict_get_next_value_idx (agr.dict));
274 output = any_writer_open (out_file, agr.dict);
/* Open the active file and sort it by the break criteria unless the
   data was declared PRESORTED. */
279 input = proc_open (ds);
280 if (agr.sort != NULL && !presorted)
282 input = sort_execute (input, agr.sort);
/* Walk the input one break group at a time; each group reader is
   destroyed by the loop increment. */
286 for (grouper = casegrouper_create_vars (input, agr.break_vars,
288 casegrouper_get_next_group (grouper, &group);
289 casereader_destroy (group))
/* Peek at the first case of the group to reset the accumulators and
   capture the break values. */
293 if (!casereader_peek (group, 0, &c))
295 casereader_destroy (group);
298 initialize_aggregate_info (&agr, &c);
/* Feed every case of the group to the accumulators, then write one
   aggregated record for the group. */
301 for (; casereader_read (group, &c); case_destroy (&c))
302 accumulate_aggregate_info (&agr, &c);
303 dump_aggregate_info (&agr, output);
305 if (!casegrouper_destroy (grouper))
308 if (!proc_commit (ds))
/* Finish: turn the writer into a reader and install it, together with
   the aggregate dictionary, as the new active file; or, for a named
   file, close the writer and keep its status. */
315 if (out_file == NULL)
317 struct casereader *next_input = casewriter_make_reader (output);
318 if (next_input == NULL)
321 proc_set_active_file (ds, next_input, agr.dict);
326 ok = casewriter_destroy (output);
/* Visible failure path: discard the writer and report a cascading
   failure. */
338 casewriter_destroy (output);
340 return CMD_CASCADING_FAILURE;
343 /* Parse all the aggregate functions. */
/* NOTE(review): this listing omits braces, several declarations and a
   number of statements of this function, so the comments added below
   describe only what is visible.

   Each visible iteration parses a list of target names with optional
   labels, a function name, a parenthesized source variable list plus
   any function arguments, and then appends one aggregate variable per
   target to the list in AGR, creating the target variable in the
   aggregate dictionary.

   NOTE(review): two comments in the second half of this function
   (the ones that begin with the words Finally add these and Create the
   target variable) have lost their closing lines in this listing, so
   no further comments are added after that point.

   NOTE(review): near the end of this function the lex_error call with
   the text expecting end of command passes a bare string literal,
   while the other messages in this function are wrapped in the
   translation macro -- confirm whether it should be translatable
   too. */
345 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict, struct agr_proc *agr)
347 struct agr_var *tail; /* Tail of linked list starting at agr->vars. */
349 /* Parse everything. */
356 struct string function_name;
358 enum mv_class exclude;
359 const struct agr_func *function;
362 union agr_argument arg[2];
364 const struct variable **src;
377 ds_init_empty (&function_name);
379 /* Parse the list of target variables. */
/* The target list runs until an equals sign is matched. */
380 while (!lex_match (lexer, '='))
382 size_t n_dest_prev = n_dest;
384 if (!parse_DATA_LIST_vars (lexer, &dest, &n_dest,
385 PV_APPEND | PV_SINGLE | PV_NO_SCRATCH))
388 /* Assign empty labels. */
/* Grow the label array to match the target array and give every newly
   added target a NULL label. */
392 dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
393 for (j = n_dest_prev; j < n_dest; j++)
394 dest_label[j] = NULL;
/* A string token after a target name becomes the label of the most
   recently added target, truncated to 255 bytes. */
399 if (lex_token (lexer) == T_STRING)
402 ds_init_string (&label, lex_tokstr (lexer));
404 ds_truncate (&label, 255);
405 dest_label[n_dest - 1] = ds_xstrdup (&label);
411 /* Get the name of the aggregation function. */
412 if (lex_token (lexer) != T_ID)
414 lex_error (lexer, _("expecting aggregation function"));
/* Copy the token text and chomp a trailing period before the table
   lookup. */
420 ds_assign_string (&function_name, lex_tokstr (lexer));
422 ds_chomp (&function_name, '.');
/* Tests whether the identifier as typed ended with a period.
   NOTE(review): the statement controlled by this test is not visible
   in this listing. */
424 if (lex_tokid(lexer)[strlen (lex_tokid (lexer)) - 1] == '.')
/* Case-insensitive linear search of agr_func_tab, stopping at the row
   whose name is NULL. */
427 for (function = agr_func_tab; function->name; function++)
428 if (!strcasecmp (function->name, ds_cstr (&function_name)))
/* Reaching the NULL-name row means no function matched. */
430 if (NULL == function->name)
432 msg (SE, _("Unknown aggregation function %s."),
433 ds_cstr (&function_name));
436 ds_destroy (&function_name);
/* The function code is the position of the matching row in the
   table. */
437 func_index = function - agr_func_tab;
440 /* Check for leading lparen. */
/* Without an open parenthesis, NU is remapped to NU_NO_VARS and the
   branch before it assigns N_NO_VARS (its condition is not visible in
   this listing); the remaining visible branch reports an error. */
441 if (!lex_match (lexer, '('))
444 func_index = N_NO_VARS;
445 else if (func_index == NU)
446 func_index = NU_NO_VARS;
449 lex_error (lexer, _("expecting `('"));
455 /* Parse list of source variables. */
457 int pv_opts = PV_NO_SCRATCH;
/* SUM, MEAN and SD accept numeric sources only; functions that take
   arguments require all sources to share one type. */
459 if (func_index == SUM || func_index == MEAN || func_index == SD)
460 pv_opts |= PV_NUMERIC;
461 else if (function->n_args)
462 pv_opts |= PV_SAME_TYPE;
464 if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
468 /* Parse function arguments, for those functions that
469 require arguments. */
470 if (function->n_args != 0)
471 for (i = 0; i < function->n_args; i++)
/* A comma before each argument is optional.  A string token is
   duplicated into arg[i].c, a number is stored in arg[i].f, and
   anything else is reported as a missing argument. */
475 lex_match (lexer, ',');
476 if (lex_token (lexer) == T_STRING)
478 arg[i].c = ds_xstrdup (lex_tokstr (lexer));
481 else if (lex_is_number (lexer))
483 arg[i].f = lex_tokval (lexer);
488 msg (SE, _("Missing argument %d to %s."),
489 (int) i + 1, function->name);
/* The argument type is checked against the type of the first source
   variable only. */
495 if (type != var_get_type (src[0]))
497 msg (SE, _("Arguments to %s must be of same type as "
498 "source variables."),
504 /* Trailing rparen. */
505 if (!lex_match (lexer, ')'))
507 lex_error (lexer, _("expecting `)'"));
511 /* Now check that the number of source variables match
512 the number of target variables. If we check earlier
513 than this, the user can get very misleading error
514 message, i.e. `AGGREGATE x=SUM(y t).' will get this
515 error message when a proper message would be more
516 like `unknown variable t'. */
519 msg (SE, _("Number of source variables (%u) does not match "
520 "number of target variables (%u)."),
521 (unsigned) n_src, (unsigned) n_dest);
/* For the four range functions, detect a lower bound that compares
   greater than the upper bound: numerically for numeric sources, with
   str_compare_rpad for string sources.  A warning is issued.
   NOTE(review): the temporary declared below suggests the two
   arguments are swapped, but the swap statements are not visible in
   this listing. */
525 if ((func_index == PIN || func_index == POUT
526 || func_index == FIN || func_index == FOUT)
527 && (var_is_numeric (src[0])
528 ? arg[0].f > arg[1].f
529 : str_compare_rpad (arg[0].c, arg[1].c) > 0))
531 union agr_argument t = arg[0];
535 msg (SW, _("The value arguments passed to the %s function "
536 "are out-of-order. They will be treated as if "
537 "they had been specified in the correct order."),
542 /* Finally add these to the linked list of aggregation
544 for (i = 0; i < n_dest; i++)
546 struct agr_var *v = xmalloc (sizeof *v);
548 /* Add variable to chain. */
549 if (agr->agr_vars != NULL)
557 /* Create the target variable in the aggregate
560 struct variable *destvar;
562 v->function = func_index;
568 if (var_is_alpha (src[i]))
570 v->function |= FSTRING;
571 v->string = xmalloc (var_get_width (src[i]));
574 if (function->alpha_type == VAR_STRING)
575 destvar = dict_clone_var (agr->dict, v->src, dest[i]);
578 assert (var_is_numeric (v->src)
579 || function->alpha_type == VAR_NUMERIC);
580 destvar = dict_create_var (agr->dict, dest[i], 0);
584 if ((func_index == N || func_index == NMISS)
585 && dict_get_weight (dict) != NULL)
586 f = fmt_for_output (FMT_F, 8, 2);
588 f = function->format;
589 var_set_both_formats (destvar, &f);
595 destvar = dict_create_var (agr->dict, dest[i], 0);
596 if (func_index == N_NO_VARS && dict_get_weight (dict) != NULL)
597 f = fmt_for_output (FMT_F, 8, 2);
599 f = function->format;
600 var_set_both_formats (destvar, &f);
605 msg (SE, _("Variable name %s is not unique within the "
606 "aggregate file dictionary, which contains "
607 "the aggregate variables and the break "
615 var_set_label (destvar, dest_label[i]);
620 v->exclude = exclude;
626 if (var_is_numeric (v->src))
627 for (j = 0; j < function->n_args; j++)
628 v->arg[j].f = arg[j].f;
630 for (j = 0; j < function->n_args; j++)
631 v->arg[j].c = xstrdup (arg[j].c);
635 if (src != NULL && var_is_alpha (src[0]))
636 for (i = 0; i < function->n_args; i++)
646 if (!lex_match (lexer, '/'))
648 if (lex_token (lexer) == '.')
651 lex_error (lexer, "expecting end of command");
657 ds_destroy (&function_name);
658 for (i = 0; i < n_dest; i++)
661 free (dest_label[i]);
667 if (src && n_src && var_is_alpha (src[0]))
668 for (i = 0; i < function->n_args; i++)
/* Releases what AGR owns, as far as this listing shows: the sort
   criteria, the break variable array, the break case, the string
   arguments or moments of each aggregate variable, and the aggregate
   dictionary.
   NOTE(review): braces and some statements are not visible here, in
   particular the statement that sets next and any frees of the list
   nodes themselves. */
681 agr_destroy (struct agr_proc *agr)
683 struct agr_var *iter, *next;
685 case_ordering_destroy (agr->sort);
686 free (agr->break_vars);
687 case_destroy (&agr->break_case);
/* The loop advances through next rather than iter->next. */
688 for (iter = agr->agr_vars; iter; iter = next)
/* String functions carry the FSTRING bit; their arguments are heap
   strings, one per argument listed for the function in agr_func_tab. */
692 if (iter->function & FSTRING)
697 n_args = agr_func_tab[iter->function & FUNC].n_args;
698 for (i = 0; i < n_args; i++)
699 free (iter->arg[i].c);
/* Numeric SD is the function that owns a moments object. */
702 else if (iter->function == SD)
703 moments1_destroy (iter->moments);
706 if (agr->dict != NULL)
707 dict_destroy (agr->dict);
712 /* Accumulates aggregation data from the case INPUT. */
/* NOTE(review): most case labels, the break statements and the braces
   of this function are not visible in this listing, so the comments
   added below describe individual statements rather than naming the
   function each one belongs to.

   NOTE(review): the string comparisons below call memcmp over
   src_width bytes with arg[..].c as one operand.  The argument strings
   come from ds_xstrdup and xstrdup in parse_aggregate_functions, and no
   visible code pads them to the source width; if an argument can be
   shorter than src_width the comparison would read past its end --
   confirm. */
714 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
716 struct agr_var *iter;
718 bool bad_warn = true;
/* The weight of this case comes from the source dictionary. */
720 weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
722 for (iter = agr->agr_vars; iter; iter = iter->next)
725 const union value *v = case_data (input, iter->src);
726 int src_width = var_get_width (iter->src);
/* A value that is missing under the exclude classes of this variable
   is handled by the first switch; the visible statements add the
   weight to dbl[0] and set saw_missing. */
728 if (var_is_value_missing (iter->src, v, iter->exclude))
730 switch (iter->function)
733 case NMISS | FSTRING:
734 iter->dbl[0] += weight;
737 case NUMISS | FSTRING:
741 iter->saw_missing = true;
745 /* This is horrible. There are too many possibilities. */
746 switch (iter->function)
/* Adds the weighted value to dbl[0]. */
749 iter->dbl[0] += v->f * weight;
/* Adds the weighted value to dbl[0] and the weight to dbl[1]. */
753 iter->dbl[0] += v->f * weight;
754 iter->dbl[1] += weight;
/* Feeds the value and its weight to the moments object. */
757 moments1_add (iter->moments, v->f, weight);
/* Keeps the larger of the stored and the current number. */
760 iter->dbl[0] = MAX (iter->dbl[0], v->f);
/* Replaces the stored string when it compares below the current
   value. */
764 if (memcmp (iter->string, v->s, src_width) < 0)
765 memcpy (iter->string, v->s, src_width);
/* Keeps the smaller of the stored and the current number. */
769 iter->dbl[0] = MIN (iter->dbl[0], v->f);
/* Replaces the stored string when it compares above the current
   value. */
773 if (memcmp (iter->string, v->s, src_width) > 0)
774 memcpy (iter->string, v->s, src_width);
/* In each of the following groups the if governs only the statement
   directly after it: dbl[0] receives the weight of cases that pass the
   test, while dbl[1] receives the weight of every case that reaches
   the group. */
/* Value greater than the first argument (numeric). */
779 if (v->f > iter->arg[0].f)
780 iter->dbl[0] += weight;
781 iter->dbl[1] += weight;
/* Value greater than the first argument (string). */
785 if (memcmp (iter->arg[0].c, v->s, src_width) < 0)
786 iter->dbl[0] += weight;
787 iter->dbl[1] += weight;
/* Value less than the first argument (numeric). */
791 if (v->f < iter->arg[0].f)
792 iter->dbl[0] += weight;
793 iter->dbl[1] += weight;
/* Value less than the first argument (string). */
797 if (memcmp (iter->arg[0].c, v->s, src_width) > 0)
798 iter->dbl[0] += weight;
799 iter->dbl[1] += weight;
/* Value within the closed range of the two arguments (numeric). */
803 if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
804 iter->dbl[0] += weight;
805 iter->dbl[1] += weight;
/* Value within the closed range of the two arguments (string). */
809 if (memcmp (iter->arg[0].c, v->s, src_width) <= 0
810 && memcmp (iter->arg[1].c, v->s, src_width) >= 0)
811 iter->dbl[0] += weight;
812 iter->dbl[1] += weight;
/* Value outside the range of the two arguments (numeric). */
816 if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
817 iter->dbl[0] += weight;
818 iter->dbl[1] += weight;
/* Value outside the range of the two arguments (string). */
822 if (memcmp (iter->arg[0].c, v->s, src_width) > 0
823 || memcmp (iter->arg[1].c, v->s, src_width) < 0)
824 iter->dbl[0] += weight;
825 iter->dbl[1] += weight;
/* Adds the weight to dbl[0]. */
829 iter->dbl[0] += weight;
842 case FIRST | FSTRING:
/* Copies the current string value into the stored string. */
845 memcpy (iter->string, v->s, src_width);
854 memcpy (iter->string, v->s, src_width);
858 case NMISS | FSTRING:
860 case NUMISS | FSTRING:
861 /* Our value is not missing or it would have been
862 caught earlier. Nothing to do. */
/* NOTE(review): this last switch is presumably the branch for
   aggregate variables that have no source variable; the condition
   selecting it is not visible in this listing. */
868 switch (iter->function)
871 iter->dbl[0] += weight;
882 /* Writes an aggregated record to OUTPUT. */
/* NOTE(review): the case labels, several declarations and the braces
   of this function are not visible in this listing; the comments added
   below describe individual statements only. */
884 dump_aggregate_info (struct agr_proc *agr, struct casewriter *output)
/* The output case is sized for the aggregate dictionary. */
888 case_create (&c, dict_get_next_value_idx (agr->dict));
/* Copy the values of each break variable from the saved break case
   into consecutive value slots of the output case. */
894 for (i = 0; i < agr->break_var_cnt; i++)
896 const struct variable *v = agr->break_vars[i];
897 size_t value_cnt = var_get_value_cnt (v);
898 memcpy (case_data_rw_idx (&c, value_idx),
899 case_data (&agr->break_case, v),
900 sizeof (union value) * value_cnt);
901 value_idx += value_cnt;
/* Fill in one value per aggregate variable. */
908 for (i = agr->agr_vars; i; i = i->next)
910 union value *v = case_data_rw (&c, i->dest);
/* Under COLUMNWISE handling, a variable that saw a missing value gets
   special treatment unless its function is N, NU, NMISS or NUMISS; the
   visible statement blanks a string target with spaces. */
912 if (agr->missing == COLUMNWISE && i->saw_missing
913 && (i->function & FUNC) != N && (i->function & FUNC) != NU
914 && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
916 if (var_is_alpha (i->dest))
917 memset (v->s, ' ', var_get_width (i->dest));
/* Yields dbl[0] when int1 is nonzero, otherwise SYSMIS. */
926 v->f = i->int1 ? i->dbl[0] : SYSMIS;
/* Yields dbl[0] divided by dbl[1], or SYSMIS when dbl[1] is zero. */
929 v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
935 /* FIXME: we should use two passes. */
/* Takes the square root of the variance reported by the moments
   object when that variance is not SYSMIS. */
936 moments1_calculate (i->moments, NULL, NULL, &variance,
938 if (variance != SYSMIS)
939 v->f = sqrt (variance);
946 v->f = i->int1 ? i->dbl[0] : SYSMIS;
/* String results: either the stored string or a run of spaces as wide
   as the target variable. */
951 memcpy (v->s, i->string, var_get_width (i->dest));
953 memset (v->s, ' ', var_get_width (i->dest));
/* Ratio of dbl[0] to dbl[1], or SYSMIS when dbl[1] is zero. */
963 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
/* The same ratio scaled by 100. */
973 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
985 v->f = i->int1 ? i->dbl[0] : SYSMIS;
987 case FIRST | FSTRING:
990 memcpy (v->s, i->string, var_get_width (i->dest));
992 memset (v->s, ' ', var_get_width (i->dest));
1001 case NMISS | FSTRING:
1005 case NUMISS | FSTRING:
/* Hand the finished case to the writer. */
1014 casewriter_write (output, &c);
1017 /* Resets the state for all the aggregate functions. */
/* Also replaces the saved break case with a clone of INPUT, so that
   dump_aggregate_info can later copy the break values of the group
   from it.
   NOTE(review): the case labels and braces of the switch are not
   visible in this listing; the notes on the individual statements are
   hedged accordingly. */
1019 initialize_aggregate_info (struct agr_proc *agr, const struct ccase *input)
1021 struct agr_var *iter;
1023 case_destroy (&agr->break_case);
1024 case_clone (&agr->break_case, input);
1026 for (iter = agr->agr_vars; iter; iter = iter->next)
/* Clear the missing flag and every numeric accumulator before the
   function-specific setup. */
1028 iter->saw_missing = false;
1029 iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1030 iter->int1 = iter->int2 = 0;
1031 switch (iter->function)
/* Largest double as the start value; presumably the numeric minimum
   case -- label not visible. */
1034 iter->dbl[0] = DBL_MAX;
/* Fills the stored string with byte value 255; presumably the string
   minimum case -- label not visible. */
1037 memset (iter->string, 255, var_get_width (iter->src));
/* Most negative double as the start value; presumably the numeric
   maximum case -- label not visible. */
1040 iter->dbl[0] = -DBL_MAX;
/* Fills the stored string with zero bytes; presumably the string
   maximum case -- label not visible. */
1043 memset (iter->string, 0, var_get_width (iter->src));
/* The moments object is created when it does not exist yet; the clear
   call presumably belongs to an else branch whose keyword is not
   visible in this listing. */
1046 if (iter->moments == NULL)
1047 iter->moments = moments1_create (MOMENT_VARIANCE);
1049 moments1_clear (iter->moments);