1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 1997-9, 2000, 2006, 2008 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
21 #include <data/any-writer.h>
22 #include <data/case.h>
23 #include <data/casegrouper.h>
24 #include <data/casereader.h>
25 #include <data/casewriter.h>
26 #include <data/dictionary.h>
27 #include <data/file-handle-def.h>
28 #include <data/format.h>
29 #include <data/procedure.h>
30 #include <data/settings.h>
31 #include <data/subcase.h>
32 #include <data/sys-file-writer.h>
33 #include <data/variable.h>
34 #include <language/command.h>
35 #include <language/data-io/file-handle.h>
36 #include <language/lexer/lexer.h>
37 #include <language/lexer/variable-parser.h>
38 #include <language/stats/sort-criteria.h>
39 #include <libpspp/assertion.h>
40 #include <libpspp/message.h>
41 #include <libpspp/misc.h>
42 #include <libpspp/pool.h>
43 #include <libpspp/str.h>
44 #include <math/moments.h>
45 #include <math/sort.h>
46 #include <math/statistic.h>
47 #include <math/percentiles.h>
53 #define _(msgid) gettext (msgid)
55 /* Argument for AGGREGATE function.  Which member is meaningful follows
   the type of the source variable: parse_aggregate_functions stores f
   for numeric sources and a heap-allocated string copy in c otherwise. */
58 double f; /* Numeric. */
59 char *c; /* Short or long string; released with free() in agr_destroy. */
62 /* Specifies how to make an aggregate variable.  One of these is
   allocated per target variable by parse_aggregate_functions and linked
   through the next member.
   NOTE(review): several members used elsewhere in the file (dbl, int1,
   int2, string, saw_missing, cc) are declared on lines that are not
   visible in this listing. */
65 struct agr_var *next; /* Next in list. */
67 /* Collected during parsing. */
68 const struct variable *src; /* Source variable. */
69 struct variable *dest; /* Target variable. */
70 int function; /* Function code, OR'ed with FSTRING for alpha sources. */
71 enum mv_class exclude; /* Classes of missing values to exclude. */
72 union agr_argument arg[2]; /* Arguments; the function's n_args says how many are used. */
74 /* Accumulated during AGGREGATE execution. */
79 struct moments1 *moments; /* Created on first use in initialize_aggregate_info; fed by moments1_add. */
/* The next three members support the computation that writes
   (value, weight) cases to a sorting casewriter and later reads them
   back for percentile_create (0.5, ...): subject and weight are internal
   variables created with var_create_internal (0) and (1), and writer is
   re-created by initialize_aggregate_info. */
82 struct variable *subject;
83 struct variable *weight;
84 struct casewriter *writer;
87 /* Aggregation functions.  The enumerators from NONE through NU_NO_VARS
   are indexes into agr_func_tab, which lists its entries in the same
   order; parse_aggregate_functions derives the function code as
   function - agr_func_tab.  N_AGR_FUNCS lines up with the NULL-named
   entry that ends the name search, and N_NO_VARS and NU_NO_VARS with
   the two entries after it. */
90 NONE, SUM, MEAN, MEDIAN, SD, MAX, MIN, PGT, PLT, PIN, POUT, FGT, FLT, FIN,
91 FOUT, N, NU, NMISS, NUMISS, FIRST, LAST,
92 N_AGR_FUNCS, N_NO_VARS, NU_NO_VARS,
93 FUNC = 0x1f, /* Function mask: selects the function index from agr_var.function. */
94 FSTRING = 1<<5, /* String function bit: set when the source variable is alpha. */
97 /* Attributes of an aggregation function (one row of agr_func_tab). */
100 const char *name; /* Aggregation function name, matched case-insensitively; NULL ends the search. */
101 size_t n_args; /* Number of value arguments parsed after the source variables. */
102 enum val_type alpha_type; /* When given ALPHA arguments, output type.  VAL_STRING
   makes the target a clone of the source variable; the table uses -1
   for some entries (presumably functions with no alpha form; TODO
   confirm). */
103 struct fmt_spec format; /* Format given to a newly created numeric target,
   i.e. when alpha_type is not VAL_STRING. */
106 /* Attributes of aggregation functions.  Each row is
   {name, n_args, alpha_type, format}; the format triple is presumably
   {type, width, decimals} (compare fmt_for_output (FMT_F, 8, 2) used
   elsewhere in this file; TODO confirm against struct fmt_spec).
   Rows are in the same order as the aggregation function enumerators,
   so an enumerator can be used directly as an index. */
107 static const struct agr_func agr_func_tab[] =
109 {"<NONE>", 0, -1, {0, 0, 0}},
110 {"SUM", 0, -1, {FMT_F, 8, 2}},
111 {"MEAN", 0, -1, {FMT_F, 8, 2}},
112 {"MEDIAN", 0, -1, {FMT_F, 8, 2}},
113 {"SD", 0, -1, {FMT_F, 8, 2}},
114 {"MAX", 0, VAL_STRING, {-1, -1, -1}},
115 {"MIN", 0, VAL_STRING, {-1, -1, -1}},
116 {"PGT", 1, VAL_NUMERIC, {FMT_F, 5, 1}},
117 {"PLT", 1, VAL_NUMERIC, {FMT_F, 5, 1}},
118 {"PIN", 2, VAL_NUMERIC, {FMT_F, 5, 1}},
119 {"POUT", 2, VAL_NUMERIC, {FMT_F, 5, 1}},
120 {"FGT", 1, VAL_NUMERIC, {FMT_F, 5, 3}},
121 {"FLT", 1, VAL_NUMERIC, {FMT_F, 5, 3}},
122 {"FIN", 2, VAL_NUMERIC, {FMT_F, 5, 3}},
123 {"FOUT", 2, VAL_NUMERIC, {FMT_F, 5, 3}},
124 {"N", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
125 {"NU", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
126 {"NMISS", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
127 {"NUMISS", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
128 {"FIRST", 0, VAL_STRING, {-1, -1, -1}},
129 {"LAST", 0, VAL_STRING, {-1, -1, -1}},
130 {NULL, 0, -1, {-1, -1, -1}}, /* Index N_AGR_FUNCS: NULL name stops the lookup loop. */
131 {"N", 0, VAL_NUMERIC, {FMT_F, 7, 0}}, /* Index N_NO_VARS: N given without a source list. */
132 {"NU", 0, VAL_NUMERIC, {FMT_F, 7, 0}}, /* Index NU_NO_VARS: NU given without a source list. */
135 /* Missing value types.  cmd_aggregate starts with ITEMWISE and switches
   to COLUMNWISE when the MISSING subcommand names COLUMNWISE.  In
   COLUMNWISE mode dump_aggregate_info treats a variable that saw any
   missing source value specially, except for the N, NU, NMISS and
   NUMISS functions. */
136 enum missing_treatment
138 ITEMWISE, /* Missing values item by item. */
139 COLUMNWISE /* Missing values column by column. */
142 /* An entire AGGREGATE procedure: the parsed break criteria, the list of
   aggregate variables, and the dictionaries involved.  Set up by
   cmd_aggregate and released by agr_destroy. */
145 /* Break variables. */
146 struct subcase sort; /* Sort criteria (break variables). */
147 const struct variable **break_vars; /* Break variables; array released with free() in agr_destroy. */
148 size_t break_var_cnt; /* Number of break variables. */
149 struct ccase break_case; /* Clone of the case passed to initialize_aggregate_info
   for the current group; supplies the break values written by
   dump_aggregate_info. */
151 enum missing_treatment missing; /* How to treat missing values. */
152 struct agr_var *agr_vars; /* First aggregate variable. */
153 struct dictionary *dict; /* Aggregate dictionary. */
154 const struct dictionary *src_dict; /* Dictionary of the source data; used to look up case weights. */
155 int case_cnt; /* Counts aggregated cases. */
/* Forward declarations for the static helpers defined later in this
   file. */
158 static void initialize_aggregate_info (struct agr_proc *,
159 const struct ccase *);
161 static void accumulate_aggregate_info (struct agr_proc *,
162 const struct ccase *);
164 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
166 static void agr_destroy (struct agr_proc *);
167 static void dump_aggregate_info (struct agr_proc *agr,
168 struct casewriter *output);
172 /* Parses and executes the AGGREGATE procedure.

   Parsing, in order: the OUTFILE subcommand (a file handle, or an
   asterisk for the active file), then any of MISSING=COLUMNWISE,
   DOCUMENT and PRESORTED, then BREAK with its sort criteria, then the
   aggregate function specifications (parse_aggregate_functions).

   Execution: the input is sorted on the break criteria unless PRESORTED
   was given, split into groups on the break variables, and for each
   group the aggregate state is initialized from the first case, fed
   every case, and dumped as one output record.  When OUTFILE named the
   active file the output becomes the new active file; otherwise it is
   written to the named file.

   The visible cleanup tail destroys the output writer and returns
   CMD_CASCADING_FAILURE.
   NOTE(review): a number of lines of this function (local declarations,
   braces, goto/return statements) are not visible in this listing. */
174 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
176 struct dictionary *dict = dataset_dict (ds);
178 struct file_handle *out_file = NULL;
179 struct casereader *input = NULL, *group;
180 struct casegrouper *grouper;
181 struct casewriter *output = NULL;
183 bool copy_documents = false;
184 bool presorted = false;
/* Start from an all-zero procedure state; ITEMWISE is the default
   treatment of missing values. */
188 memset(&agr, 0 , sizeof (agr));
189 agr.missing = ITEMWISE;
190 case_nullify (&agr.break_case);
/* The aggregate dictionary is created empty and then given the source
   dictionary's label and documents. */
192 agr.dict = dict_create ();
194 subcase_init_empty (&agr.sort);
195 dict_set_label (agr.dict, dict_get_label (dict));
196 dict_set_documents (agr.dict, dict_get_documents (dict));
198 /* OUTFILE subcommand must be first. */
199 if (!lex_force_match_id (lexer, "OUTFILE"))
201 lex_match (lexer, '=');
/* An asterisk leaves out_file NULL, which later selects replacement of
   the active file. */
202 if (!lex_match (lexer, '*'))
204 out_file = fh_parse (lexer, FH_REF_FILE | FH_REF_SCRATCH);
205 if (out_file == NULL)
209 /* Read most of the subcommands. */
212 lex_match (lexer, '/');
214 if (lex_match_id (lexer, "MISSING"))
216 lex_match (lexer, '=');
217 if (!lex_match_id (lexer, "COLUMNWISE"))
219 lex_error (lexer, _("while expecting COLUMNWISE"));
222 agr.missing = COLUMNWISE;
224 else if (lex_match_id (lexer, "DOCUMENT"))
225 copy_documents = true;
226 else if (lex_match_id (lexer, "PRESORTED"))
228 else if (lex_match_id (lexer, "BREAK"))
232 lex_match (lexer, '=');
233 if (!parse_sort_criteria (lexer, dict, &agr.sort, &agr.break_vars,
236 agr.break_var_cnt = subcase_get_n_fields (&agr.sort);
/* Give the aggregate dictionary a clone of every break variable under
   its original name. */
238 for (i = 0; i < agr.break_var_cnt; i++)
239 dict_clone_var_assert (agr.dict, agr.break_vars[i],
240 var_get_name (agr.break_vars[i]));
242 /* BREAK must follow the options. */
247 lex_error (lexer, _("expecting BREAK"));
251 if (presorted && saw_direction)
252 msg (SW, _("When PRESORTED is specified, specifying sorting directions "
253 "with (A) or (D) has no effect. Output data will be sorted "
254 "the same way as the input data."));
256 /* Read in the aggregate functions. */
257 lex_match (lexer, '/');
258 if (!parse_aggregate_functions (lexer, dict, &agr))
261 /* Delete documents. */
263 dict_clear_documents (agr.dict);
265 /* Cancel SPLIT FILE. */
266 dict_set_split_vars (agr.dict, NULL, 0);
/* Choose the output sink: a writer sized for the aggregate dictionary
   when the active file is to be replaced, otherwise a writer for the
   named output file. */
271 if (out_file == NULL)
273 /* The active file will be replaced by the aggregated data,
274 so TEMPORARY is moot. */
275 proc_cancel_temporary_transformations (ds);
276 proc_discard_output (ds);
277 output = autopaging_writer_create (dict_get_next_value_idx (agr.dict));
281 output = any_writer_open (out_file, agr.dict);
/* Open the input and, unless PRESORTED was given or there are no sort
   criteria, sort it on the break criteria. */
286 input = proc_open (ds);
287 if (!subcase_is_empty (&agr.sort) && !presorted)
289 input = sort_execute (input, &agr.sort);
290 subcase_clear (&agr.sort);
/* Process one break group per iteration; the loop's step expression
   destroys the group's reader. */
293 for (grouper = casegrouper_create_vars (input, agr.break_vars,
295 casegrouper_get_next_group (grouper, &group);
296 casereader_destroy (group))
/* NOTE(review): when no first case can be peeked, the group is destroyed
   here as well as in the loop's step expression above.  The statement
   that follows is not visible in this listing; confirm that this path
   cannot destroy the same reader twice. */
300 if (!casereader_peek (group, 0, &c))
302 casereader_destroy (group);
/* Reset the accumulators using the group's first case, feed every case
   of the group, then emit the group's aggregated record. */
305 initialize_aggregate_info (&agr, &c);
308 for (; casereader_read (group, &c); case_destroy (&c))
309 accumulate_aggregate_info (&agr, &c);
310 dump_aggregate_info (&agr, output);
312 if (!casegrouper_destroy (grouper))
315 if (!proc_commit (ds))
/* Replacing the active file: turn the collected output into a reader
   and install it together with the aggregate dictionary. */
322 if (out_file == NULL)
324 struct casereader *next_input = casewriter_make_reader (output);
325 if (next_input == NULL)
328 proc_set_active_file (ds, next_input, agr.dict);
/* Writing to a file: closing the writer yields the success status. */
333 ok = casewriter_destroy (output);
/* Cleanup tail: discard the output writer and report failure. */
346 casewriter_destroy (output);
349 return CMD_CASCADING_FAILURE;
352 /* Parse all the aggregate functions.

   Each specification is a list of target variable names (a string
   token after the names becomes the label of the last target so far,
   truncated to 255 bytes), an equals sign, an aggregation function
   name, and, in parentheses, the source variables followed by the
   function's value arguments.  Specifications are separated by slashes.

   DICT is the dictionary in which source variables are looked up.  For
   each target an agr_var is allocated and linked into AGR's list, and a
   destination variable is created in AGR->dict:
   - for an alpha source, FSTRING is OR'ed into the function code and a
     buffer of the source's width is allocated;
   - if the function's alpha_type is VAL_STRING the target is a clone of
     the source variable, otherwise it is a new numeric variable with the
     function's format, except that F8.2 is used for N and NMISS (and for
     N without sources) when DICT has a weight variable;
   - a target name that cannot be created is reported as not unique.
   Value arguments are copied into each agr_var (string arguments are
   duplicated) and the temporary copies are freed afterwards.

   The result is a bool; the caller tests it with the logical-not
   operator to detect failure.

   NOTE(review): the message "expecting end of command" near the end of
   this function is passed to lex_error without the _() wrapper that the
   other messages here use; confirm whether it should be translatable.
   NOTE(review): many lines of this function are not visible in this
   listing (declarations, braces, goto/return statements, and the closing
   lines of two block comments), so the notes above cover only the
   visible statements. */
354 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict,
355 struct agr_proc *agr)
357 struct agr_var *tail; /* Tail of linked list starting at agr->agr_vars. */
359 /* Parse everything. */
366 struct string function_name;
368 enum mv_class exclude;
369 const struct agr_func *function;
372 union agr_argument arg[2];
374 const struct variable **src;
387 ds_init_empty (&function_name);
389 /* Parse the list of target variables. */
390 while (!lex_match (lexer, '='))
392 size_t n_dest_prev = n_dest;
394 if (!parse_DATA_LIST_vars (lexer, &dest, &n_dest,
395 PV_APPEND | PV_SINGLE | PV_NO_SCRATCH))
398 /* Assign empty labels. */
402 dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
403 for (j = n_dest_prev; j < n_dest; j++)
404 dest_label[j] = NULL;
/* A string token here is a label for the most recently parsed target. */
409 if (lex_token (lexer) == T_STRING)
412 ds_init_string (&label, lex_tokstr (lexer));
414 ds_truncate (&label, 255);
415 dest_label[n_dest - 1] = ds_xstrdup (&label);
421 /* Get the name of the aggregation function. */
422 if (lex_token (lexer) != T_ID)
424 lex_error (lexer, _("expecting aggregation function"));
/* A trailing period is stripped from the name before the table lookup.
   The statement controlled by the trailing-period test below is not
   visible in this listing; presumably it selects the missing-value
   classes stored in exclude (TODO confirm). */
430 ds_assign_string (&function_name, lex_tokstr (lexer));
432 ds_chomp (&function_name, '.');
434 if (lex_tokid(lexer)[strlen (lex_tokid (lexer)) - 1] == '.')
/* Case-insensitive search of agr_func_tab, stopping at the NULL-named
   entry; the function code is the index of the matching entry. */
437 for (function = agr_func_tab; function->name; function++)
438 if (!strcasecmp (function->name, ds_cstr (&function_name)))
440 if (NULL == function->name)
442 msg (SE, _("Unknown aggregation function %s."),
443 ds_cstr (&function_name));
446 ds_destroy (&function_name);
447 func_index = function - agr_func_tab;
450 /* Check for leading lparen. */
451 if (!lex_match (lexer, '('))
454 func_index = N_NO_VARS;
455 else if (func_index == NU)
456 func_index = NU_NO_VARS;
459 lex_error (lexer, _("expecting `('"));
465 /* Parse list of source variables. */
467 int pv_opts = PV_NO_SCRATCH;
/* SUM, MEAN and SD take numeric sources only; functions with value
   arguments require all sources to be of one type. */
469 if (func_index == SUM || func_index == MEAN || func_index == SD)
470 pv_opts |= PV_NUMERIC;
471 else if (function->n_args)
472 pv_opts |= PV_SAME_TYPE;
474 if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
478 /* Parse function arguments, for those functions that
479 require arguments. */
480 if (function->n_args != 0)
481 for (i = 0; i < function->n_args; i++)
485 lex_match (lexer, ',');
486 if (lex_token (lexer) == T_STRING)
488 arg[i].c = ds_xstrdup (lex_tokstr (lexer));
491 else if (lex_is_number (lexer))
493 arg[i].f = lex_tokval (lexer);
498 msg (SE, _("Missing argument %zu to %s."),
499 i + 1, function->name);
/* The argument's type must agree with the type of the first source
   variable. */
505 if (type != var_get_type (src[0]))
507 msg (SE, _("Arguments to %s must be of same type as "
508 "source variables."),
514 /* Trailing rparen. */
515 if (!lex_match (lexer, ')'))
517 lex_error (lexer, _("expecting `)'"));
521 /* Now check that the number of source variables match
522 the number of target variables. If we check earlier
523 than this, the user can get very misleading error
524 message, i.e. `AGGREGATE x=SUM(y t).' will get this
525 error message when a proper message would be more
526 like `unknown variable t'. */
529 msg (SE, _("Number of source variables (%zu) does not match "
530 "number of target variables (%zu)."),
/* For the two-argument range functions, arguments given in descending
   order are swapped (numeric comparison, or right-padded string
   comparison) and a warning is issued. */
535 if ((func_index == PIN || func_index == POUT
536 || func_index == FIN || func_index == FOUT)
537 && (var_is_numeric (src[0])
538 ? arg[0].f > arg[1].f
539 : str_compare_rpad (arg[0].c, arg[1].c) > 0))
541 union agr_argument t = arg[0];
545 msg (SW, _("The value arguments passed to the %s function "
546 "are out-of-order. They will be treated as if "
547 "they had been specified in the correct order."),
552 /* Finally add these to the linked list of aggregation
554 for (i = 0; i < n_dest; i++)
556 struct agr_var *v = xzalloc (sizeof *v);
558 /* Add variable to chain. */
559 if (agr->agr_vars != NULL)
567 /* Create the target variable in the aggregate
570 struct variable *destvar;
572 v->function = func_index;
578 if (var_is_alpha (src[i]))
580 v->function |= FSTRING;
581 v->string = xmalloc (var_get_width (src[i]));
584 if (function->alpha_type == VAL_STRING)
585 destvar = dict_clone_var (agr->dict, v->src, dest[i]);
588 assert (var_is_numeric (v->src)
589 || function->alpha_type == VAL_NUMERIC);
590 destvar = dict_create_var (agr->dict, dest[i], 0);
594 if ((func_index == N || func_index == NMISS)
595 && dict_get_weight (dict) != NULL)
596 f = fmt_for_output (FMT_F, 8, 2);
598 f = function->format;
599 var_set_both_formats (destvar, &f);
605 destvar = dict_create_var (agr->dict, dest[i], 0);
606 if (func_index == N_NO_VARS && dict_get_weight (dict) != NULL)
607 f = fmt_for_output (FMT_F, 8, 2);
609 f = function->format;
610 var_set_both_formats (destvar, &f);
615 msg (SE, _("Variable name %s is not unique within the "
616 "aggregate file dictionary, which contains "
617 "the aggregate variables and the break "
625 var_set_label (destvar, dest_label[i]);
630 v->exclude = exclude;
636 if (var_is_numeric (v->src))
637 for (j = 0; j < function->n_args; j++)
638 v->arg[j].f = arg[j].f;
640 for (j = 0; j < function->n_args; j++)
641 v->arg[j].c = xstrdup (arg[j].c);
645 if (src != NULL && var_is_alpha (src[0]))
646 for (i = 0; i < function->n_args; i++)
656 if (!lex_match (lexer, '/'))
658 if (lex_token (lexer) == '.')
661 lex_error (lexer, "expecting end of command");
667 ds_destroy (&function_name);
668 for (i = 0; i < n_dest; i++)
671 free (dest_label[i]);
677 if (src && n_src && var_is_alpha (src[0]))
678 for (i = 0; i < function->n_args; i++)
/* Releases what AGR owns: the sort criteria, the break variable array,
   the saved break case, the per-variable state of every aggregate
   variable (string arguments for functions flagged FSTRING, the moments
   accumulator for SD, and the internal subject and weight variables),
   and the aggregate dictionary if one exists.
   NOTE(review): some statements of this function are not visible in
   this listing; this summary covers only the visible calls. */
691 agr_destroy (struct agr_proc *agr)
693 struct agr_var *iter, *next;
695 subcase_destroy (&agr->sort);
696 free (agr->break_vars);
697 case_destroy (&agr->break_case);
698 for (iter = agr->agr_vars; iter; iter = next)
702 if (iter->function & FSTRING)
707 n_args = agr_func_tab[iter->function & FUNC].n_args;
708 for (i = 0; i < n_args; i++)
709 free (iter->arg[i].c);
712 else if (iter->function == SD)
713 moments1_destroy (iter->moments);
715 var_destroy (iter->subject);
716 var_destroy (iter->weight);
720 if (agr->dict != NULL)
721 dict_destroy (agr->dict);
726 /* Accumulates aggregation data from the case INPUT.

   The case weight is looked up once in AGR->src_dict and then every
   aggregate variable of AGR is visited.  When the source value is
   missing in the variable's exclude classes, the visible code updates
   the missing-value tallies and sets saw_missing.  Otherwise the
   accumulators that belong to the variable's function are updated.

   NOTE(review): most case labels of the switch statements, and some
   other lines, are not visible in this listing; the comments on the
   statements below describe only what each visible statement does. */
728 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
730 struct agr_var *iter;
732 bool bad_warn = true;
734 weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
736 for (iter = agr->agr_vars; iter; iter = iter->next)
739 const union value *v = case_data (input, iter->src);
740 int src_width = var_get_width (iter->src);
/* Missing source value: tally it and remember that one was seen. */
742 if (var_is_value_missing (iter->src, v, iter->exclude))
744 switch (iter->function)
747 case NMISS | FSTRING:
748 iter->dbl[0] += weight;
751 case NUMISS | FSTRING:
755 iter->saw_missing = true;
759 /* This is horrible. There are too many possibilities. */
760 switch (iter->function)
/* Weighted sum of the values. */
763 iter->dbl[0] += v->f * weight;
/* Weighted sum in dbl[0] and total weight in dbl[1]. */
767 iter->dbl[0] += v->f * weight;
768 iter->dbl[1] += weight;
/* Write a two-value case holding the source value and the case weight
   (looked up again, this time without a warning flag) to the sorting
   writer. */
774 case_create (&cout, 2);
776 case_data_rw (&cout, iter->subject)->f =
777 case_data (input, iter->src)->f;
779 wv = dict_get_case_weight (agr->src_dict, input, NULL);
781 case_data_rw (&cout, iter->weight)->f = wv;
785 casewriter_write (iter->writer, &cout);
786 case_destroy (&cout);
/* Feed the value and its weight to the moments accumulator. */
790 moments1_add (iter->moments, v->f, weight);
/* Running numeric maximum. */
793 iter->dbl[0] = MAX (iter->dbl[0], v->f);
/* Keep the string that compares greatest bytewise. */
797 if (memcmp (iter->string, v->s, src_width) < 0)
798 memcpy (iter->string, v->s, src_width);
/* Running numeric minimum. */
802 iter->dbl[0] = MIN (iter->dbl[0], v->f);
/* Keep the string that compares least bytewise. */
806 if (memcmp (iter->string, v->s, src_width) > 0)
807 memcpy (iter->string, v->s, src_width);
/* In each of the following eight groups dbl[0] collects the weight of
   the cases that satisfy the test and dbl[1] the weight of all cases.
   Tests, in order: value greater than arg[0] (numeric, then string);
   value less than arg[0] (numeric, then string); value within
   [arg[0], arg[1]] (numeric, then string); value outside that range
   (numeric, then string). */
812 if (v->f > iter->arg[0].f)
813 iter->dbl[0] += weight;
814 iter->dbl[1] += weight;
818 if (memcmp (iter->arg[0].c, v->s, src_width) < 0)
819 iter->dbl[0] += weight;
820 iter->dbl[1] += weight;
824 if (v->f < iter->arg[0].f)
825 iter->dbl[0] += weight;
826 iter->dbl[1] += weight;
830 if (memcmp (iter->arg[0].c, v->s, src_width) > 0)
831 iter->dbl[0] += weight;
832 iter->dbl[1] += weight;
836 if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
837 iter->dbl[0] += weight;
838 iter->dbl[1] += weight;
842 if (memcmp (iter->arg[0].c, v->s, src_width) <= 0
843 && memcmp (iter->arg[1].c, v->s, src_width) >= 0)
844 iter->dbl[0] += weight;
845 iter->dbl[1] += weight;
849 if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
850 iter->dbl[0] += weight;
851 iter->dbl[1] += weight;
855 if (memcmp (iter->arg[0].c, v->s, src_width) > 0
856 || memcmp (iter->arg[1].c, v->s, src_width) < 0)
857 iter->dbl[0] += weight;
858 iter->dbl[1] += weight;
/* Total weight of the cases. */
862 iter->dbl[0] += weight;
875 case FIRST | FSTRING:
878 memcpy (iter->string, v->s, src_width);
887 memcpy (iter->string, v->s, src_width);
891 case NMISS | FSTRING:
893 case NUMISS | FSTRING:
894 /* Our value is not missing or it would have been
895 caught earlier. Nothing to do. */
/* NOTE(review): the branch introducing this final switch is not visible
   in this listing; presumably it handles aggregate variables that have
   no source variable (TODO confirm). */
901 switch (iter->function)
904 iter->dbl[0] += weight;
915 /* Writes an aggregated record to OUTPUT.

   A case sized for AGR->dict is created.  The values of the break
   variables are copied into its leading value slots from
   AGR->break_case, then each aggregate variable's result is stored in
   its destination variable, and the case is written to OUTPUT.

   NOTE(review): most case labels of the result switch, and some other
   lines, are not visible in this listing; the comments below describe
   only what the visible statements do. */
917 dump_aggregate_info (struct agr_proc *agr, struct casewriter *output)
921 case_create (&c, dict_get_next_value_idx (agr->dict));
/* Copy the break variable values, packing them one after another. */
927 for (i = 0; i < agr->break_var_cnt; i++)
929 const struct variable *v = agr->break_vars[i];
930 size_t value_cnt = var_get_value_cnt (v);
931 memcpy (case_data_rw_idx (&c, value_idx),
932 case_data (&agr->break_case, v),
933 sizeof (union value) * value_cnt);
934 value_idx += value_cnt;
941 for (i = agr->agr_vars; i; i = i->next)
943 union value *v = case_data_rw (&c, i->dest);
/* COLUMNWISE handling: if any source value was missing, the result is
   not computed (string targets are filled with spaces), except for the
   counting functions N, NU, NMISS and NUMISS.  The sorting writer is
   discarded on this path. */
946 if (agr->missing == COLUMNWISE && i->saw_missing
947 && (i->function & FUNC) != N && (i->function & FUNC) != NU
948 && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
950 if (var_is_alpha (i->dest))
951 memset (v->s, ' ', var_get_width (i->dest));
955 casewriter_destroy (i->writer);
/* dbl[0] if int1 is nonzero, otherwise system-missing. */
963 v->f = i->int1 ? i->dbl[0] : SYSMIS;
/* Weighted sum divided by total weight; system-missing for zero weight. */
966 v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
/* Read back the sorted (value, weight) cases and compute the 0.5
   percentile over them, using i->cc as the total for percentile_create. */
970 struct casereader *sorted_reader;
971 struct order_stats *median = percentile_create (0.5, i->cc);
973 sorted_reader = casewriter_make_reader (i->writer);
975 order_stats_accumulate (&median, 1,
981 v->f = percentile_calculate ((struct percentile *) median,
984 statistic_destroy ((struct statistic *) median);
991 /* FIXME: we should use two passes. */
992 moments1_calculate (i->moments, NULL, NULL, &variance,
/* The result is the square root of the variance when one is available. */
994 if (variance != SYSMIS)
995 v->f = sqrt (variance);
1002 v->f = i->int1 ? i->dbl[0] : SYSMIS;
/* String result: the accumulated string, or all spaces. */
1007 memcpy (v->s, i->string, var_get_width (i->dest));
1009 memset (v->s, ' ', var_get_width (i->dest));
/* Fraction of the total weight that satisfied the test. */
1018 case FOUT | FSTRING:
1019 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
/* Same ratio expressed as a percentage. */
1028 case POUT | FSTRING:
1029 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
1041 v->f = i->int1 ? i->dbl[0] : SYSMIS;
1043 case FIRST | FSTRING:
1044 case LAST | FSTRING:
1046 memcpy (v->s, i->string, var_get_width (i->dest));
1048 memset (v->s, ' ', var_get_width (i->dest));
1057 case NMISS | FSTRING:
1061 case NUMISS | FSTRING:
1070 casewriter_write (output, &c);
1073 /* Resets the state for all the aggregate functions.

   INPUT is cloned into AGR->break_case (replacing the previous clone)
   so that the break values of the group are available when the record
   is dumped.  Every aggregate variable then has its flags and numeric
   accumulators zeroed, followed by function-specific setup.

   NOTE(review): the case labels of the switch, and the end of this
   function, are not visible in this listing. */
1075 initialize_aggregate_info (struct agr_proc *agr, const struct ccase *input)
1077 struct agr_var *iter;
1079 case_destroy (&agr->break_case);
1080 case_clone (&agr->break_case, input);
1082 for (iter = agr->agr_vars; iter; iter = iter->next)
1084 iter->saw_missing = false;
1085 iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1086 iter->int1 = iter->int2 = 0;
1087 switch (iter->function)
/* Seed values for running extremes: DBL_MAX or all-255 bytes so that any
   value compares lower (presumably the MIN cases), and -DBL_MAX or
   all-zero bytes so that any value compares higher (presumably the MAX
   cases); TODO confirm against the missing case labels. */
1090 iter->dbl[0] = DBL_MAX;
1093 memset (iter->string, 255, var_get_width (iter->src));
1096 iter->dbl[0] = -DBL_MAX;
1099 memset (iter->string, 0, var_get_width (iter->src));
/* The internal subject and weight variables are created once and
   reused; a fresh sorting writer, ordered ascending on subject and
   holding two values per case, is created on every reset. */
1103 struct subcase ordering;
1105 if ( ! iter->subject)
1106 iter->subject = var_create_internal (0);
1108 if ( ! iter->weight)
1109 iter->weight = var_create_internal (1);
1111 subcase_init_var (&ordering, iter->subject, SC_ASCEND);
1112 iter->writer = sort_create_writer (&ordering, 2);
1113 subcase_destroy (&ordering);
/* The moments accumulator is created on first use and cleared on
   later resets. */
1119 if (iter->moments == NULL)
1120 iter->moments = moments1_create (MOMENT_VARIANCE);
1122 moments1_clear (iter->moments);