1 /* PSPP - a program for statistical analysis.
2 Copyright (C) 1997-9, 2000, 2006, 2008 Free Software Foundation, Inc.
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program. If not, see <http://www.gnu.org/licenses/>. */
21 #include <data/any-writer.h>
22 #include <data/case-ordering.h>
23 #include <data/case.h>
24 #include <data/casegrouper.h>
25 #include <data/casereader.h>
26 #include <data/casewriter.h>
27 #include <data/dictionary.h>
28 #include <data/file-handle-def.h>
29 #include <data/format.h>
30 #include <data/procedure.h>
31 #include <data/settings.h>
32 #include <data/sys-file-writer.h>
33 #include <data/variable.h>
34 #include <language/command.h>
35 #include <language/data-io/file-handle.h>
36 #include <language/lexer/lexer.h>
37 #include <language/lexer/variable-parser.h>
38 #include <language/stats/sort-criteria.h>
39 #include <libpspp/assertion.h>
40 #include <libpspp/message.h>
41 #include <libpspp/misc.h>
42 #include <libpspp/pool.h>
43 #include <libpspp/str.h>
44 #include <math/moments.h>
45 #include <math/sort.h>
46 #include <math/statistic.h>
47 #include <math/percentiles.h>
/* Shorthand that routes a message ID through gettext so that user-visible
   strings in this file can be translated. */
53 #define _(msgid) gettext (msgid)
55 /* Argument for AGGREGATE function. */
/* Only one member is meaningful for a given argument: the parser stores
   into F when it reads a number and into C (a freshly duplicated copy of
   the token text) when it reads a string.  agr_destroy frees C for
   aggregate variables flagged FSTRING.
   NOTE(review): the opening line of this declaration is not visible in
   this listing. */
58 double f; /* Numeric. */
59 char *c; /* Short or long string. */
62 /* Specifies how to make an aggregate variable. */
/* Each target variable gets one of these.  They are chained through NEXT
   and walked starting from the agr_vars member of the procedure state. */
65 struct agr_var *next; /* Next in list. */
67 /* Collected during parsing. */
68 const struct variable *src; /* Source variable. */
69 struct variable *dest; /* Target variable. */
70 int function; /* Function code, ORed with FSTRING for string sources. */
71 enum mv_class exclude; /* Classes of missing values to exclude. */
72 union agr_argument arg[2]; /* Arguments. */
74 /* Accumulated during AGGREGATE execution. */
/* NOTE(review): other accumulator members used elsewhere in this file
   (dbl, int1, int2, string, saw_missing, cc) are declared on lines that
   are not visible in this listing. */
79 struct moments1 *moments; /* Created on demand with MOMENT_VARIANCE. */
/* The next three members support a sorted side stream of two-value cases:
   SUBJECT is internal variable 0 and receives the source value, WEIGHT is
   internal variable 1 and receives the case weight, and WRITER is a
   sorting casewriter ordered by SUBJECT ascending.  The writer is later
   turned into a reader for the percentile (0.5) calculation. */
82 struct variable *subject;
83 struct variable *weight;
84 struct casewriter *writer;
87 /* Aggregation functions. */
/* The constants up to LAST index the rows of agr_func_tab (the parser
   computes a function code as a row offset into that table).  N_AGR_FUNCS
   lines up with the NULL-name sentinel row, and N_NO_VARS and NU_NO_VARS
   select the two rows that follow the sentinel.
   NOTE(review): the enum header line is not visible in this listing; the
   row correspondence assumes NONE is the first enumerator. */
90 NONE, SUM, MEAN, MEDIAN, SD, MAX, MIN, PGT, PLT, PIN, POUT, FGT, FLT, FIN,
91 FOUT, N, NU, NMISS, NUMISS, FIRST, LAST,
92 N_AGR_FUNCS, N_NO_VARS, NU_NO_VARS,
93 FUNC = 0x1f, /* Function mask: low five bits carry the function code. */
94 FSTRING = 1<<5, /* String function bit, set when the source is a string. */
97 /* Attributes of an aggregation function. */
/* Describes one row of agr_func_tab. */
100 const char *name; /* Aggregation function name; NULL in the sentinel row. */
101 size_t n_args; /* Number of arguments. */
/* The table stores VAL_STRING in rows whose target is cloned from the
   source variable, VAL_NUMERIC in rows whose target is always numeric,
   and -1 in the rows for NONE, SUM, MEAN, MEDIAN, SD and the sentinel. */
102 enum val_type alpha_type; /* When given ALPHA arguments, output type. */
103 struct fmt_spec format; /* Format for targets created as new numeric variables. */
106 /* Attributes of aggregation functions. */
/* Rows appear in the same order as the aggregation function constants, so
   a row index is also a function code.  Each row lists the name, the
   number of arguments, the alpha_type and the output format.  The
   NULL-name row ends the part of the table that is searched by name; the
   two rows after it describe N and NU written without a variable list
   (N_NO_VARS and NU_NO_VARS). */
107 static const struct agr_func agr_func_tab[] =
109 {"<NONE>", 0, -1, {0, 0, 0}},
110 {"SUM", 0, -1, {FMT_F, 8, 2}},
111 {"MEAN", 0, -1, {FMT_F, 8, 2}},
112 {"MEDIAN", 0, -1, {FMT_F, 8, 2}},
113 {"SD", 0, -1, {FMT_F, 8, 2}},
114 {"MAX", 0, VAL_STRING, {-1, -1, -1}},
115 {"MIN", 0, VAL_STRING, {-1, -1, -1}},
116 {"PGT", 1, VAL_NUMERIC, {FMT_F, 5, 1}},
117 {"PLT", 1, VAL_NUMERIC, {FMT_F, 5, 1}},
118 {"PIN", 2, VAL_NUMERIC, {FMT_F, 5, 1}},
119 {"POUT", 2, VAL_NUMERIC, {FMT_F, 5, 1}},
120 {"FGT", 1, VAL_NUMERIC, {FMT_F, 5, 3}},
121 {"FLT", 1, VAL_NUMERIC, {FMT_F, 5, 3}},
122 {"FIN", 2, VAL_NUMERIC, {FMT_F, 5, 3}},
123 {"FOUT", 2, VAL_NUMERIC, {FMT_F, 5, 3}},
124 {"N", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
125 {"NU", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
126 {"NMISS", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
127 {"NUMISS", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
128 {"FIRST", 0, VAL_STRING, {-1, -1, -1}},
129 {"LAST", 0, VAL_STRING, {-1, -1, -1}},
/* Sentinel: stops the by-name search in parse_aggregate_functions. */
130 {NULL, 0, -1, {-1, -1, -1}},
/* Rows for N_NO_VARS and NU_NO_VARS. */
131 {"N", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
132 {"NU", 0, VAL_NUMERIC, {FMT_F, 7, 0}},
135 /* Missing value types. */
/* Selects how a missing source value affects the output for a group.
   cmd_aggregate starts with ITEMWISE and switches to COLUMNWISE when the
   MISSING subcommand names COLUMNWISE. */
136 enum missing_treatment
138 ITEMWISE, /* Missing values item by item. */
139 COLUMNWISE /* Missing values column by column. */
142 /* An entire AGGREGATE procedure. */
/* State shared by parsing and execution of a single AGGREGATE command.
   cmd_aggregate zero-fills it before use and agr_destroy releases the
   sort criteria, break variable array, break case, aggregate variable
   list and aggregate dictionary. */
145 /* Break variables. */
146 struct case_ordering *sort; /* Sort criteria (break variable). */
147 const struct variable **break_vars; /* Break variables. */
148 size_t break_var_cnt; /* Number of break variables. */
/* break_case is re-cloned from the first case of each group and supplies
   the break values copied into that group's output record. */
149 struct ccase break_case; /* Last values of break variables. */
151 enum missing_treatment missing; /* How to treat missing values. */
152 struct agr_var *agr_vars; /* First aggregate variable. */
153 struct dictionary *dict; /* Aggregate dictionary. */
154 const struct dictionary *src_dict; /* Source dictionary; consulted for case weights. */
155 int case_cnt; /* Counts aggregated cases. */
/* Forward declarations of the static helpers defined later in this file. */
158 static void initialize_aggregate_info (struct agr_proc *,
159 const struct ccase *);
161 static void accumulate_aggregate_info (struct agr_proc *,
162 const struct ccase *);
/* NOTE(review): the final parameter line of the next prototype is not
   visible in this listing; the definition takes a struct agr_proc pointer
   as its third parameter. */
164 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
166 static void agr_destroy (struct agr_proc *);
167 static void dump_aggregate_info (struct agr_proc *agr,
168 struct casewriter *output);
172 /* Parses and executes the AGGREGATE procedure. */
/* LEXER supplies the command tokens and DS is the dataset whose dictionary
   and cases are aggregated.  The visible failure path destroys the output
   writer and returns CMD_CASCADING_FAILURE; the success return is not
   visible in this listing.
   NOTE(review): several declarations and control-flow lines of this
   function (for example the declarations of agr, ok, i, c and
   saw_direction) are not visible here. */
174 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
176 struct dictionary *dict = dataset_dict (ds);
178 struct file_handle *out_file = NULL;
179 struct casereader *input = NULL, *group;
180 struct casegrouper *grouper;
181 struct casewriter *output = NULL;
183 bool copy_documents = false;
184 bool presorted = false;
/* Start from an all-zero procedure state with item-by-item missing value
   handling and a null break case. */
188 memset(&agr, 0 , sizeof (agr));
189 agr.missing = ITEMWISE;
190 case_nullify (&agr.break_case);
/* The aggregate dictionary starts out with the label and documents of the
   source dictionary. */
192 agr.dict = dict_create ();
194 dict_set_label (agr.dict, dict_get_label (dict));
195 dict_set_documents (agr.dict, dict_get_documents (dict));
197 /* OUTFILE subcommand must be first. */
198 if (!lex_force_match_id (lexer, "OUTFILE"))
200 lex_match (lexer, '=');
/* An asterisk leaves out_file NULL; anything else is parsed as a file or
   scratch file handle. */
201 if (!lex_match (lexer, '*'))
203 out_file = fh_parse (lexer, FH_REF_FILE | FH_REF_SCRATCH);
204 if (out_file == NULL)
208 /* Read most of the subcommands. */
211 lex_match (lexer, '/');
213 if (lex_match_id (lexer, "MISSING"))
215 lex_match (lexer, '=');
216 if (!lex_match_id (lexer, "COLUMNWISE"))
218 lex_error (lexer, _("while expecting COLUMNWISE"));
221 agr.missing = COLUMNWISE;
223 else if (lex_match_id (lexer, "DOCUMENT"))
224 copy_documents = true;
225 else if (lex_match_id (lexer, "PRESORTED"))
227 else if (lex_match_id (lexer, "BREAK"))
231 lex_match (lexer, '=');
232 agr.sort = parse_case_ordering (lexer, dict,
235 if (agr.sort == NULL)
237 case_ordering_get_vars (agr.sort,
238 &agr.break_vars, &agr.break_var_cnt);
/* Every break variable is cloned into the aggregate dictionary under its
   own name. */
240 for (i = 0; i < agr.break_var_cnt; i++)
241 dict_clone_var_assert (agr.dict, agr.break_vars[i],
242 var_get_name (agr.break_vars[i]));
244 /* BREAK must follow the options. */
249 lex_error (lexer, _("expecting BREAK"));
253 if (presorted && saw_direction)
254 msg (SW, _("When PRESORTED is specified, specifying sorting directions "
255 "with (A) or (D) has no effect. Output data will be sorted "
256 "the same way as the input data."));
258 /* Read in the aggregate functions. */
259 lex_match (lexer, '/');
260 if (!parse_aggregate_functions (lexer, dict, &agr))
263 /* Delete documents. */
/* NOTE(review): the condition guarding the next call is not visible here;
   presumably it depends on copy_documents.  Confirm. */
265 dict_clear_documents (agr.dict);
267 /* Cancel SPLIT FILE. */
268 dict_set_split_vars (agr.dict, NULL, 0);
/* A NULL out_file means the aggregated data replaces the active file, so
   results are collected in an autopaging writer; otherwise they go to a
   writer opened on the named file. */
273 if (out_file == NULL)
275 /* The active file will be replaced by the aggregated data,
276 so TEMPORARY is moot. */
277 proc_cancel_temporary_transformations (ds);
278 proc_discard_output (ds);
279 output = autopaging_writer_create (dict_get_next_value_idx (agr.dict));
283 output = any_writer_open (out_file, agr.dict);
/* Read the active file, sorting it by the break criteria unless presorted
   is set. */
288 input = proc_open (ds);
289 if (agr.sort != NULL && !presorted)
291 input = sort_execute (input, agr.sort);
/* Groups come from a casegrouper over the break variables.  For each
   group: reset the state from its first case, feed every case to the
   accumulators, then emit one output record. */
295 for (grouper = casegrouper_create_vars (input, agr.break_vars,
297 casegrouper_get_next_group (grouper, &group);
298 casereader_destroy (group))
302 if (!casereader_peek (group, 0, &c))
304 casereader_destroy (group);
307 initialize_aggregate_info (&agr, &c);
310 for (; casereader_read (group, &c); case_destroy (&c))
311 accumulate_aggregate_info (&agr, &c);
312 dump_aggregate_info (&agr, output);
314 if (!casegrouper_destroy (grouper))
317 if (!proc_commit (ds))
/* With no output file, the collected cases become the new active file
   together with the aggregate dictionary; otherwise the file writer is
   closed and its status kept in ok. */
324 if (out_file == NULL)
326 struct casereader *next_input = casewriter_make_reader (output);
327 if (next_input == NULL)
330 proc_set_active_file (ds, next_input, agr.dict);
335 ok = casewriter_destroy (output);
/* Failure path. */
348 casewriter_destroy (output);
351 return CMD_CASCADING_FAILURE;
354 /* Parse all the aggregate functions. */
/* Reads the aggregate function specifications from LEXER, resolving source
   variables against DICT, and records one agr_var per target variable in
   AGR.  The caller treats a false result as failure.
   NOTE(review): many lines of this function, including its return
   statements, error jumps and several declarations, are not visible in
   this listing.  Two multi-line comments near the end have also lost
   their closing lines, so no commentary is added after that point. */
356 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict,
357 struct agr_proc *agr)
359 struct agr_var *tail; /* Tail of linked list starting at agr->vars. */
361 /* Parse everything. */
368 struct string function_name;
370 enum mv_class exclude;
371 const struct agr_func *function;
374 union agr_argument arg[2];
376 const struct variable **src;
389 ds_init_empty (&function_name);
391 /* Parse the list of target variables. */
392 while (!lex_match (lexer, '='))
394 size_t n_dest_prev = n_dest;
396 if (!parse_DATA_LIST_vars (lexer, &dest, &n_dest,
397 PV_APPEND | PV_SINGLE | PV_NO_SCRATCH))
400 /* Assign empty labels. */
404 dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
405 for (j = n_dest_prev; j < n_dest; j++)
406 dest_label[j] = NULL;
/* A string token after a target name becomes the label of the most
   recently parsed target, truncated to length 255. */
411 if (lex_token (lexer) == T_STRING)
414 ds_init_string (&label, lex_tokstr (lexer));
416 ds_truncate (&label, 255);
417 dest_label[n_dest - 1] = ds_xstrdup (&label);
423 /* Get the name of the aggregation function. */
424 if (lex_token (lexer) != T_ID)
426 lex_error (lexer, _("expecting aggregation function"));
/* A trailing period is chomped from the name used for lookup.  The raw
   identifier is then tested for that period; the statement controlled by
   the test is not visible in this listing. */
432 ds_assign_string (&function_name, lex_tokstr (lexer));
434 ds_chomp (&function_name, '.');
436 if (lex_tokid(lexer)[strlen (lex_tokid (lexer)) - 1] == '.')
/* Case-insensitive lookup by name; the walk stops at the NULL-name
   sentinel row, which signals an unknown function. */
439 for (function = agr_func_tab; function->name; function++)
440 if (!strcasecmp (function->name, ds_cstr (&function_name)))
442 if (NULL == function->name)
444 msg (SE, _("Unknown aggregation function %s."),
445 ds_cstr (&function_name));
448 ds_destroy (&function_name);
/* The row offset within agr_func_tab is the function code. */
449 func_index = function - agr_func_tab;
452 /* Check for leading lparen. */
/* Without a left parenthesis the code is remapped to N_NO_VARS or
   NU_NO_VARS; the visible error message covers the remaining case. */
453 if (!lex_match (lexer, '('))
456 func_index = N_NO_VARS;
457 else if (func_index == NU)
458 func_index = NU_NO_VARS;
461 lex_error (lexer, _("expecting `('"));
467 /* Parse list of source variables. */
469 int pv_opts = PV_NO_SCRATCH;
/* NOTE(review): MEDIAN also has alpha_type -1 in agr_func_tab but is not
   listed in the numeric-only test below.  Confirm that string sources for
   MEDIAN are rejected somewhere before the assertion on the source type
   later in this function. */
471 if (func_index == SUM || func_index == MEAN || func_index == SD)
472 pv_opts |= PV_NUMERIC;
473 else if (function->n_args)
474 pv_opts |= PV_SAME_TYPE;
476 if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
480 /* Parse function arguments, for those functions that
481 require arguments. */
482 if (function->n_args != 0)
483 for (i = 0; i < function->n_args; i++)
487 lex_match (lexer, ',');
/* String arguments are stored as a duplicated copy of the token text,
   numeric arguments as the token value. */
488 if (lex_token (lexer) == T_STRING)
490 arg[i].c = ds_xstrdup (lex_tokstr (lexer));
493 else if (lex_is_number (lexer))
495 arg[i].f = lex_tokval (lexer);
500 msg (SE, _("Missing argument %zu to %s."),
501 i + 1, function->name);
/* The argument type is compared against the type of the first source
   variable only. */
507 if (type != var_get_type (src[0]))
509 msg (SE, _("Arguments to %s must be of same type as "
510 "source variables."),
516 /* Trailing rparen. */
517 if (!lex_match (lexer, ')'))
519 lex_error (lexer, _("expecting `)'"));
523 /* Now check that the number of source variables match
524 the number of target variables. If we check earlier
525 than this, the user can get very misleading error
526 message, i.e. `AGGREGATE x=SUM(y t).' will get this
527 error message when a proper message would be more
528 like `unknown variable t'. */
531 msg (SE, _("Number of source variables (%zu) does not match "
532 "number of target variables (%zu)."),
/* For the two-argument range functions, arguments given high-to-low are
   detected here (numerically, or with a right-padded string compare) and
   a warning is issued. */
537 if ((func_index == PIN || func_index == POUT
538 || func_index == FIN || func_index == FOUT)
539 && (var_is_numeric (src[0])
540 ? arg[0].f > arg[1].f
541 : str_compare_rpad (arg[0].c, arg[1].c) > 0))
543 union agr_argument t = arg[0];
547 msg (SW, _("The value arguments passed to the %s function "
548 "are out-of-order. They will be treated as if "
549 "they had been specified in the correct order."),
554 /* Finally add these to the linked list of aggregation
556 for (i = 0; i < n_dest; i++)
558 struct agr_var *v = xzalloc (sizeof *v);
560 /* Add variable to chain. */
561 if (agr->agr_vars != NULL)
569 /* Create the target variable in the aggregate
572 struct variable *destvar;
574 v->function = func_index;
580 if (var_is_alpha (src[i]))
582 v->function |= FSTRING;
583 v->string = xmalloc (var_get_width (src[i]));
586 if (function->alpha_type == VAL_STRING)
587 destvar = dict_clone_var (agr->dict, v->src, dest[i]);
590 assert (var_is_numeric (v->src)
591 || function->alpha_type == VAL_NUMERIC);
592 destvar = dict_create_var (agr->dict, dest[i], 0);
596 if ((func_index == N || func_index == NMISS)
597 && dict_get_weight (dict) != NULL)
598 f = fmt_for_output (FMT_F, 8, 2);
600 f = function->format;
601 var_set_both_formats (destvar, &f);
607 destvar = dict_create_var (agr->dict, dest[i], 0);
608 if (func_index == N_NO_VARS && dict_get_weight (dict) != NULL)
609 f = fmt_for_output (FMT_F, 8, 2);
611 f = function->format;
612 var_set_both_formats (destvar, &f);
617 msg (SE, _("Variable name %s is not unique within the "
618 "aggregate file dictionary, which contains "
619 "the aggregate variables and the break "
627 var_set_label (destvar, dest_label[i]);
632 v->exclude = exclude;
638 if (var_is_numeric (v->src))
639 for (j = 0; j < function->n_args; j++)
640 v->arg[j].f = arg[j].f;
642 for (j = 0; j < function->n_args; j++)
643 v->arg[j].c = xstrdup (arg[j].c);
647 if (src != NULL && var_is_alpha (src[0]))
648 for (i = 0; i < function->n_args; i++)
658 if (!lex_match (lexer, '/'))
660 if (lex_token (lexer) == '.')
663 lex_error (lexer, "expecting end of command");
669 ds_destroy (&function_name);
670 for (i = 0; i < n_dest; i++)
673 free (dest_label[i]);
679 if (src && n_src && var_is_alpha (src[0]))
680 for (i = 0; i < function->n_args; i++)
/* Releases the resources held by AGR: the sort criteria, the break
   variable array, the break case, per-variable state of each aggregate
   variable, and finally the aggregate dictionary when one exists.
   NOTE(review): the lines that advance NEXT and free each list node are
   not visible in this listing. */
693 agr_destroy (struct agr_proc *agr)
695 struct agr_var *iter, *next;
697 case_ordering_destroy (agr->sort);
698 free (agr->break_vars);
699 case_destroy (&agr->break_case);
700 for (iter = agr->agr_vars; iter; iter = next)
/* String-flagged variables own copies of their string arguments; the
   argument count comes from the table row selected by the masked code. */
704 if (iter->function & FSTRING)
709 n_args = agr_func_tab[iter->function & FUNC].n_args;
710 for (i = 0; i < n_args; i++)
711 free (iter->arg[i].c);
/* Only a plain SD code (no FSTRING bit) reaches the moments cleanup. */
714 else if (iter->function == SD)
715 moments1_destroy (iter->moments);
717 var_destroy (iter->subject);
718 var_destroy (iter->weight);
722 if (agr->dict != NULL)
723 dict_destroy (agr->dict);
728 /* Accumulates aggregation data from the case INPUT. */
/* Folds INPUT into the running state of every aggregate variable in AGR.
   The case weight is read once from the source dictionary and applied to
   the weighted accumulators.
   NOTE(review): most case labels, break statements and some declarations
   of this function are not visible in this listing, so the statements
   below cannot all be matched to function codes from this text alone. */
730 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
732 struct agr_var *iter;
734 bool bad_warn = true;
736 weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
738 for (iter = agr->agr_vars; iter; iter = iter->next)
741 const union value *v = case_data (input, iter->src);
742 int src_width = var_get_width (iter->src);
/* A source value that is missing under the excluded classes only updates
   the missing-value bookkeeping: a weighted count for the visible NMISS
   label and the saw_missing flag. */
744 if (var_is_value_missing (iter->src, v, iter->exclude))
746 switch (iter->function)
749 case NMISS | FSTRING:
750 iter->dbl[0] += weight;
753 case NUMISS | FSTRING:
757 iter->saw_missing = true;
761 /* This is horrible. There are too many possibilities. */
762 switch (iter->function)
765 iter->dbl[0] += v->f * weight;
769 iter->dbl[0] += v->f * weight;
770 iter->dbl[1] += weight;
/* Write a two-value case holding the source value and the case weight to
   the sorting writer of this aggregate variable. */
776 case_create (&cout, 2);
778 case_data_rw (&cout, iter->subject)->f =
779 case_data (input, iter->src)->f;
781 wv = dict_get_case_weight (agr->src_dict, input, NULL);
783 case_data_rw (&cout, iter->weight)->f = wv;
787 casewriter_write (iter->writer, &cout);
788 case_destroy (&cout);
792 moments1_add (iter->moments, v->f, weight);
/* Running maximum and minimum, numeric and then bytewise string forms. */
795 iter->dbl[0] = MAX (iter->dbl[0], v->f);
799 if (memcmp (iter->string, v->s, src_width) < 0)
800 memcpy (iter->string, v->s, src_width);
804 iter->dbl[0] = MIN (iter->dbl[0], v->f);
808 if (memcmp (iter->string, v->s, src_width) > 0)
809 memcpy (iter->string, v->s, src_width);
/* Threshold and range tests: dbl[0] collects the weight of matching cases
   and dbl[1] the weight of all non-missing cases.
   NOTE(review): the string forms compare src_width bytes of the argument
   copies.  Nothing visible here pads an argument to src_width, so a
   shorter argument may be read past its end.  Confirm against the code
   that stores the arguments. */
814 if (v->f > iter->arg[0].f)
815 iter->dbl[0] += weight;
816 iter->dbl[1] += weight;
820 if (memcmp (iter->arg[0].c, v->s, src_width) < 0)
821 iter->dbl[0] += weight;
822 iter->dbl[1] += weight;
826 if (v->f < iter->arg[0].f)
827 iter->dbl[0] += weight;
828 iter->dbl[1] += weight;
832 if (memcmp (iter->arg[0].c, v->s, src_width) > 0)
833 iter->dbl[0] += weight;
834 iter->dbl[1] += weight;
838 if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
839 iter->dbl[0] += weight;
840 iter->dbl[1] += weight;
844 if (memcmp (iter->arg[0].c, v->s, src_width) <= 0
845 && memcmp (iter->arg[1].c, v->s, src_width) >= 0)
846 iter->dbl[0] += weight;
847 iter->dbl[1] += weight;
851 if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
852 iter->dbl[0] += weight;
853 iter->dbl[1] += weight;
857 if (memcmp (iter->arg[0].c, v->s, src_width) > 0
858 || memcmp (iter->arg[1].c, v->s, src_width) < 0)
859 iter->dbl[0] += weight;
860 iter->dbl[1] += weight;
864 iter->dbl[0] += weight;
877 case FIRST | FSTRING:
880 memcpy (iter->string, v->s, src_width);
889 memcpy (iter->string, v->s, src_width);
893 case NMISS | FSTRING:
895 case NUMISS | FSTRING:
896 /* Our value is not missing or it would have been
897 caught earlier. Nothing to do. */
/* NOTE(review): the condition selecting this final switch is not visible;
   presumably it handles aggregate variables without a source variable. */
903 switch (iter->function)
906 iter->dbl[0] += weight;
917 /* Writes an aggregated record to OUTPUT. */
/* Builds one case sized for the aggregate dictionary of AGR, fills in the
   break values followed by the result of every aggregate variable, and
   writes it to OUTPUT.
   NOTE(review): most case labels, break statements and some declarations
   of this function are not visible in this listing. */
919 dump_aggregate_info (struct agr_proc *agr, struct casewriter *output)
923 case_create (&c, dict_get_next_value_idx (agr->dict));
/* Copy the values of each break variable from the saved break case into
   consecutive value slots of the output case. */
929 for (i = 0; i < agr->break_var_cnt; i++)
931 const struct variable *v = agr->break_vars[i];
932 size_t value_cnt = var_get_value_cnt (v);
933 memcpy (case_data_rw_idx (&c, value_idx),
934 case_data (&agr->break_case, v),
935 sizeof (union value) * value_cnt);
936 value_idx += value_cnt;
943 for (i = agr->agr_vars; i; i = i->next)
945 union value *v = case_data_rw (&c, i->dest);
/* With COLUMNWISE handling, a group that saw a missing value yields no
   computed result, except for the four counting functions; string targets
   are filled with spaces. */
948 if (agr->missing == COLUMNWISE && i->saw_missing
949 && (i->function & FUNC) != N && (i->function & FUNC) != NU
950 && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
952 if (var_is_alpha (i->dest))
953 memset (v->s, ' ', var_get_width (i->dest));
957 casewriter_destroy (i->writer);
965 v->f = i->int1 ? i->dbl[0] : SYSMIS;
/* Ratio of the two accumulators, or SYSMIS when the denominator is 0. */
968 v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
/* The sorted side stream is read back and fed to a 0.5 percentile order
   statistic to obtain the result. */
972 struct casereader *sorted_reader;
973 struct order_stats *median = percentile_create (0.5, i->cc);
975 sorted_reader = casewriter_make_reader (i->writer);
977 order_stats_accumulate (&median, 1,
983 v->f = percentile_calculate ((struct percentile *) median,
986 statistic_destroy ((struct statistic *) median);
993 /* FIXME: we should use two passes. */
/* The result is the square root of the variance from the moments
   accumulator whenever that variance is not SYSMIS. */
994 moments1_calculate (i->moments, NULL, NULL, &variance,
996 if (variance != SYSMIS)
997 v->f = sqrt (variance);
1004 v->f = i->int1 ? i->dbl[0] : SYSMIS;
1009 memcpy (v->s, i->string, var_get_width (i->dest));
1011 memset (v->s, ' ', var_get_width (i->dest));
/* Fraction results: matching weight over total weight. */
1020 case FOUT | FSTRING:
1021 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
/* Percentage results: the same ratio scaled by 100. */
1030 case POUT | FSTRING:
1031 v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
1043 v->f = i->int1 ? i->dbl[0] : SYSMIS;
1045 case FIRST | FSTRING:
1046 case LAST | FSTRING:
1048 memcpy (v->s, i->string, var_get_width (i->dest));
1050 memset (v->s, ' ', var_get_width (i->dest));
1059 case NMISS | FSTRING:
1063 case NUMISS | FSTRING:
1072 casewriter_write (output, &c);
1075 /* Resets the state for all the aggregate functions. */
/* Called with the first case of a group as INPUT.  Replaces the saved
   break case in AGR with a clone of INPUT, then clears the accumulators of
   every aggregate variable and applies function-specific starting values.
   NOTE(review): the case labels of the switch and the end of this function
   are not visible in this listing. */
1077 initialize_aggregate_info (struct agr_proc *agr, const struct ccase *input)
1079 struct agr_var *iter;
1081 case_destroy (&agr->break_case);
1082 case_clone (&agr->break_case, input);
1084 for (iter = agr->agr_vars; iter; iter = iter->next)
1086 iter->saw_missing = false;
1087 iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1088 iter->int1 = iter->int2 = 0;
1089 switch (iter->function)
/* Extreme starting values: the largest double or all-ones bytes, and the
   most negative double or all-zero bytes, so that the first real value
   replaces them in the running comparisons. */
1092 iter->dbl[0] = DBL_MAX;
1095 memset (iter->string, 255, var_get_width (iter->src));
1098 iter->dbl[0] = -DBL_MAX;
1101 memset (iter->string, 0, var_get_width (iter->src));
/* Set up the sorted side stream: the two internal variables are created
   once and reused, while a new sorting writer ordered by SUBJECT ascending
   is created on each call. */
1105 struct case_ordering *ordering = case_ordering_create ();
1107 if ( ! iter->subject)
1108 iter->subject = var_create_internal (0);
1110 if ( ! iter->weight)
1111 iter->weight = var_create_internal (1);
1113 case_ordering_add_var (ordering, iter->subject, SRT_ASCEND);
1115 iter->writer = sort_create_writer (ordering, 2);
/* The moments accumulator is created on first use and cleared afterwards. */
1120 if (iter->moments == NULL)
1121 iter->moments = moments1_create (MOMENT_VARIANCE);
1123 moments1_clear (iter->moments);