treewide: Replace <name>_cnt by n_<name>s and <name>_cap by allocated_<name>.
[pspp] / src / language / stats / aggregate.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 1997-9, 2000, 2006, 2008, 2009, 2010, 2011, 2012, 2014 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "language/stats/aggregate.h"
20
21 #include <stdlib.h>
22
23 #include "data/any-writer.h"
24 #include "data/case.h"
25 #include "data/casegrouper.h"
26 #include "data/casereader.h"
27 #include "data/casewriter.h"
28 #include "data/dataset.h"
29 #include "data/dictionary.h"
30 #include "data/file-handle-def.h"
31 #include "data/format.h"
32 #include "data/settings.h"
33 #include "data/subcase.h"
34 #include "data/sys-file-writer.h"
35 #include "data/variable.h"
36 #include "language/command.h"
37 #include "language/data-io/file-handle.h"
38 #include "language/lexer/lexer.h"
39 #include "language/lexer/variable-parser.h"
40 #include "language/stats/sort-criteria.h"
41 #include "libpspp/assertion.h"
42 #include "libpspp/i18n.h"
43 #include "libpspp/message.h"
44 #include "libpspp/misc.h"
45 #include "libpspp/pool.h"
46 #include "libpspp/str.h"
47 #include "math/moments.h"
48 #include "math/percentiles.h"
49 #include "math/sort.h"
50 #include "math/statistic.h"
51
52 #include "gl/c-strcase.h"
53 #include "gl/minmax.h"
54 #include "gl/xalloc.h"
55
56 #include "gettext.h"
57 #define _(msgid) gettext (msgid)
58 #define N_(msgid) msgid
59
60 /* Argument for AGGREGATE function. */
61 union agr_argument
62   {
63     double f;                           /* Numeric. */
64     char *c;                            /* Short or long string. */
65   };
66
67 /* Specifies how to make an aggregate variable. */
68 struct agr_var
69   {
70     struct agr_var *next;               /* Next in list. */
71
72     /* Collected during parsing. */
73     const struct variable *src; /* Source variable. */
74     struct variable *dest;      /* Target variable. */
75     int function;               /* Function. */
76     enum mv_class exclude;      /* Classes of missing values to exclude. */
77     union agr_argument arg[2];  /* Arguments. */
78
79     /* Accumulated during AGGREGATE execution. */
80     double dbl[3];
81     int int1, int2;
82     char *string;
83     bool saw_missing;
84     struct moments1 *moments;
85     double cc;
86
87     struct variable *subject;
88     struct variable *weight;
89     struct casewriter *writer;
90   };
91
92
93 /* Attributes of aggregation functions. */
94 const struct agr_func agr_func_tab[] =
95   {
96     {"SUM",     N_("Sum of values"),                         AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
97     {"MEAN",    N_("Mean average"),                          AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
98     {"MEDIAN",  N_("Median average"),                        AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
99     {"SD",      N_("Standard deviation"),                    AGR_SV_YES, 0, -1,          { .type = FMT_F, .w = 8, .d = 2 }},
100     {"MAX",     N_("Maximum value"),                         AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
101     {"MIN",     N_("Minimum value"),                         AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
102     {"PGT",     N_("Percentage greater than"),               AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
103     {"PLT",     N_("Percentage less than"),                  AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
104     {"PIN",     N_("Percentage included in range"),          AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
105     {"POUT",    N_("Percentage excluded from range"),        AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 1 }},
106     {"FGT",     N_("Fraction greater than"),                 AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
107     {"FLT",     N_("Fraction less than"),                    AGR_SV_YES, 1, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
108     {"FIN",     N_("Fraction included in range"),            AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
109     {"FOUT",    N_("Fraction excluded from range"),          AGR_SV_YES, 2, VAL_NUMERIC, { .type = FMT_F, .w = 5, .d = 3 }},
110     {"N",       N_("Number of cases"),                       AGR_SV_NO,  0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
111     {"NU",      N_("Number of cases (unweighted)"),          AGR_SV_OPT, 0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
112     {"NMISS",   N_("Number of missing values"),              AGR_SV_YES, 0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
113     {"NUMISS",  N_("Number of missing values (unweighted)"), AGR_SV_YES, 0, VAL_NUMERIC, { .type = FMT_F, .w = 7, .d = 0 }},
114     {"FIRST",   N_("First non-missing value"),               AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
115     {"LAST",    N_("Last non-missing value"),                AGR_SV_YES, 0, VAL_STRING,  {-1, -1, -1}},
116     {NULL,      NULL,                                        AGR_SV_NO,  0, -1,          {-1, -1, -1}},
117   };
118
119 /* Missing value types. */
120 enum missing_treatment
121   {
122     ITEMWISE,           /* Missing values item by item. */
123     COLUMNWISE          /* Missing values column by column. */
124   };
125
126 /* An entire AGGREGATE procedure. */
127 struct agr_proc
128   {
129     /* Break variables. */
130     struct subcase sort;                /* Sort criteria (break variables). */
131     const struct variable **break_vars;       /* Break variables. */
132     size_t break_n_vars;                /* Number of break variables. */
133
134     enum missing_treatment missing;     /* How to treat missing values. */
135     struct agr_var *agr_vars;           /* First aggregate variable. */
136     struct dictionary *dict;            /* Aggregate dictionary. */
137     const struct dictionary *src_dict;  /* Dict of the source */
138     int n_cases;                        /* Counts aggregated cases. */
139
140     bool add_variables;                 /* True iff the aggregated variables should
141                                            be appended to the existing dictionary */
142   };
143
144 static void initialize_aggregate_info (struct agr_proc *);
145
146 static void accumulate_aggregate_info (struct agr_proc *,
147                                        const struct ccase *);
148 /* Prototypes. */
149 static bool parse_aggregate_functions (struct lexer *, const struct dictionary *,
150                                        struct agr_proc *);
151 static void agr_destroy (struct agr_proc *);
152 static void dump_aggregate_info (const struct agr_proc *agr,
153                                  struct casewriter *output,
154                                  const struct ccase *break_case);
155 \f
156 /* Parsing. */
157
158 /* Parses and executes the AGGREGATE procedure. */
159 int
160 cmd_aggregate (struct lexer *lexer, struct dataset *ds)
161 {
162   struct dictionary *dict = dataset_dict (ds);
163   struct agr_proc agr;
164   struct file_handle *out_file = NULL;
165   struct casereader *input = NULL, *group;
166   struct casegrouper *grouper;
167   struct casewriter *output = NULL;
168
169   bool copy_documents = false;
170   bool presorted = false;
171   bool saw_direction;
172   bool ok;
173
174   memset(&agr, 0 , sizeof (agr));
175   agr.missing = ITEMWISE;
176   agr.src_dict = dict;
177   subcase_init_empty (&agr.sort);
178
179   /* OUTFILE subcommand must be first. */
180   lex_match (lexer, T_SLASH);
181   if (!lex_force_match_id (lexer, "OUTFILE"))
182     goto error;
183   lex_match (lexer, T_EQUALS);
184   if (!lex_match (lexer, T_ASTERISK))
185     {
186       out_file = fh_parse (lexer, FH_REF_FILE, dataset_session (ds));
187       if (out_file == NULL)
188         goto error;
189     }
190
191   if (out_file == NULL && lex_match_id (lexer, "MODE"))
192     {
193       lex_match (lexer, T_EQUALS);
194       if (lex_match_id (lexer, "ADDVARIABLES"))
195         {
196           agr.add_variables = true;
197
198           /* presorted is assumed in ADDVARIABLES mode */
199           presorted = true;
200         }
201       else if (lex_match_id (lexer, "REPLACE"))
202         {
203           agr.add_variables = false;
204         }
205       else
206         goto error;
207     }
208
209   if (agr.add_variables)
210     agr.dict = dict_clone (dict);
211   else
212     agr.dict = dict_create (dict_get_encoding (dict));
213
214   dict_set_label (agr.dict, dict_get_label (dict));
215   dict_set_documents (agr.dict, dict_get_documents (dict));
216
217   /* Read most of the subcommands. */
218   for (;;)
219     {
220       lex_match (lexer, T_SLASH);
221
222       if (lex_match_id (lexer, "MISSING"))
223         {
224           lex_match (lexer, T_EQUALS);
225           if (!lex_match_id (lexer, "COLUMNWISE"))
226             {
227               lex_error_expecting (lexer, "COLUMNWISE");
228               goto error;
229             }
230           agr.missing = COLUMNWISE;
231         }
232       else if (lex_match_id (lexer, "DOCUMENT"))
233         copy_documents = true;
234       else if (lex_match_id (lexer, "PRESORTED"))
235         presorted = true;
236       else if (lex_force_match_id (lexer, "BREAK"))
237         {
238           int i;
239
240           lex_match (lexer, T_EQUALS);
241           if (!parse_sort_criteria (lexer, dict, &agr.sort, &agr.break_vars,
242                                     &saw_direction))
243             goto error;
244           agr.break_n_vars = subcase_get_n_fields (&agr.sort);
245
246           if  (! agr.add_variables)
247             for (i = 0; i < agr.break_n_vars; i++)
248               dict_clone_var_assert (agr.dict, agr.break_vars[i]);
249
250           /* BREAK must follow the options. */
251           break;
252         }
253       else
254         goto error;
255
256     }
257   if (presorted && saw_direction)
258     msg (SW, _("When PRESORTED is specified, specifying sorting directions "
259                "with (A) or (D) has no effect.  Output data will be sorted "
260                "the same way as the input data."));
261
262   /* Read in the aggregate functions. */
263   lex_match (lexer, T_SLASH);
264   if (!parse_aggregate_functions (lexer, dict, &agr))
265     goto error;
266
267   /* Delete documents. */
268   if (!copy_documents)
269     dict_clear_documents (agr.dict);
270
271   /* Cancel SPLIT FILE. */
272   dict_set_split_vars (agr.dict, NULL, 0);
273
274   /* Initialize. */
275   agr.n_cases = 0;
276
277   if (out_file == NULL)
278     {
279       /* The active dataset will be replaced by the aggregated data,
280          so TEMPORARY is moot. */
281       proc_cancel_temporary_transformations (ds);
282       proc_discard_output (ds);
283       output = autopaging_writer_create (dict_get_proto (agr.dict));
284     }
285   else
286     {
287       output = any_writer_open (out_file, agr.dict);
288       if (output == NULL)
289         goto error;
290     }
291
292   input = proc_open (ds);
293   if (!subcase_is_empty (&agr.sort) && !presorted)
294     {
295       input = sort_execute (input, &agr.sort);
296       subcase_clear (&agr.sort);
297     }
298
299   for (grouper = casegrouper_create_vars (input, agr.break_vars,
300                                           agr.break_n_vars);
301        casegrouper_get_next_group (grouper, &group);
302        casereader_destroy (group))
303     {
304       struct casereader *placeholder = NULL;
305       struct ccase *c = casereader_peek (group, 0);
306
307       if (c == NULL)
308         {
309           casereader_destroy (group);
310           continue;
311         }
312
313       initialize_aggregate_info (&agr);
314
315       if (agr.add_variables)
316         placeholder = casereader_clone (group);
317
318       {
319         struct ccase *cg;
320         for (; (cg = casereader_read (group)) != NULL; case_unref (cg))
321           accumulate_aggregate_info (&agr, cg);
322       }
323
324
325       if  (agr.add_variables)
326         {
327           struct ccase *cg;
328           for (; (cg = casereader_read (placeholder)) != NULL; case_unref (cg))
329             dump_aggregate_info (&agr, output, cg);
330
331           casereader_destroy (placeholder);
332         }
333       else
334         {
335           dump_aggregate_info (&agr, output, c);
336         }
337       case_unref (c);
338     }
339   if (!casegrouper_destroy (grouper))
340     goto error;
341
342   if (!proc_commit (ds))
343     {
344       input = NULL;
345       goto error;
346     }
347   input = NULL;
348
349   if (out_file == NULL)
350     {
351       struct casereader *next_input = casewriter_make_reader (output);
352       if (next_input == NULL)
353         goto error;
354
355       dataset_set_dict (ds, agr.dict);
356       dataset_set_source (ds, next_input);
357       agr.dict = NULL;
358     }
359   else
360     {
361       ok = casewriter_destroy (output);
362       output = NULL;
363       if (!ok)
364         goto error;
365     }
366
367   agr_destroy (&agr);
368   fh_unref (out_file);
369   return CMD_SUCCESS;
370
371 error:
372   if (input != NULL)
373     proc_commit (ds);
374   casewriter_destroy (output);
375   agr_destroy (&agr);
376   fh_unref (out_file);
377   return CMD_CASCADING_FAILURE;
378 }
379
380 /* Parse all the aggregate functions. */
381 static bool
382 parse_aggregate_functions (struct lexer *lexer, const struct dictionary *dict,
383                            struct agr_proc *agr)
384 {
385   struct agr_var *tail; /* Tail of linked list starting at agr->vars. */
386
387   /* Parse everything. */
388   tail = NULL;
389   for (;;)
390     {
391       char **dest;
392       char **dest_label;
393       size_t n_dest;
394       struct string function_name;
395
396       enum mv_class exclude;
397       const struct agr_func *function;
398       int func_index;
399
400       union agr_argument arg[2];
401
402       const struct variable **src;
403       size_t n_src;
404
405       size_t i;
406
407       dest = NULL;
408       dest_label = NULL;
409       n_dest = 0;
410       src = NULL;
411       function = NULL;
412       n_src = 0;
413       arg[0].c = NULL;
414       arg[1].c = NULL;
415       ds_init_empty (&function_name);
416
417       /* Parse the list of target variables. */
418       while (!lex_match (lexer, T_EQUALS))
419         {
420           size_t n_dest_prev = n_dest;
421
422           if (!parse_DATA_LIST_vars (lexer, dict, &dest, &n_dest,
423                                      (PV_APPEND | PV_SINGLE | PV_NO_SCRATCH
424                                       | PV_NO_DUPLICATE)))
425             goto error;
426
427           /* Assign empty labels. */
428           {
429             int j;
430
431             dest_label = xnrealloc (dest_label, n_dest, sizeof *dest_label);
432             for (j = n_dest_prev; j < n_dest; j++)
433               dest_label[j] = NULL;
434           }
435
436
437
438           if (lex_is_string (lexer))
439             {
440               dest_label[n_dest - 1] = xstrdup (lex_tokcstr (lexer));
441               lex_get (lexer);
442             }
443         }
444
445       /* Get the name of the aggregation function. */
446       if (lex_token (lexer) != T_ID)
447         {
448           lex_error (lexer, _("expecting aggregation function"));
449           goto error;
450         }
451
452       ds_assign_substring (&function_name, lex_tokss (lexer));
453       exclude = ds_chomp_byte (&function_name, '.') ? MV_SYSTEM : MV_ANY;
454
455       for (function = agr_func_tab; function->name; function++)
456         if (!c_strcasecmp (function->name, ds_cstr (&function_name)))
457           break;
458       if (NULL == function->name)
459         {
460           msg (SE, _("Unknown aggregation function %s."),
461                ds_cstr (&function_name));
462           goto error;
463         }
464       ds_destroy (&function_name);
465       func_index = function - agr_func_tab;
466       lex_get (lexer);
467
468       /* Check for leading lparen. */
469       if (!lex_match (lexer, T_LPAREN))
470         {
471           if (function->src_vars == AGR_SV_YES)
472             {
473               goto error;
474             }
475         }
476       else
477         {
478           /* Parse list of source variables. */
479           {
480             int pv_opts = PV_NO_SCRATCH;
481
482             if (func_index == SUM || func_index == MEAN || func_index == SD)
483               pv_opts |= PV_NUMERIC;
484             else if (function->n_args)
485               pv_opts |= PV_SAME_TYPE;
486
487             if (!parse_variables_const (lexer, dict, &src, &n_src, pv_opts))
488               goto error;
489           }
490
491           /* Parse function arguments, for those functions that
492              require arguments. */
493           if (function->n_args != 0)
494             for (i = 0; i < function->n_args; i++)
495               {
496                 int type;
497
498                 lex_match (lexer, T_COMMA);
499                 if (lex_is_string (lexer))
500                   {
501                     arg[i].c = recode_string (dict_get_encoding (agr->dict),
502                                               "UTF-8", lex_tokcstr (lexer),
503                                               -1);
504                     type = VAL_STRING;
505                   }
506                 else if (lex_is_number (lexer))
507                   {
508                     arg[i].f = lex_tokval (lexer);
509                     type = VAL_NUMERIC;
510                   }
511                 else
512                   {
513                     msg (SE, _("Missing argument %zu to %s."),
514                          i + 1, function->name);
515                     goto error;
516                   }
517
518                 lex_get (lexer);
519
520                 if (type != var_get_type (src[0]))
521                   {
522                     msg (SE, _("Arguments to %s must be of same type as "
523                                "source variables."),
524                          function->name);
525                     goto error;
526                   }
527               }
528
529           /* Trailing rparen. */
530           if (!lex_force_match (lexer, T_RPAREN))
531             goto error;
532
533           /* Now check that the number of source variables match
534              the number of target variables.  If we check earlier
535              than this, the user can get very misleading error
536              message, i.e. `AGGREGATE x=SUM(y t).' will get this
537              error message when a proper message would be more
538              like `unknown variable t'. */
539           if (n_src != n_dest)
540             {
541               msg (SE, _("Number of source variables (%zu) does not match "
542                          "number of target variables (%zu)."),
543                     n_src, n_dest);
544               goto error;
545             }
546
547           if ((func_index == PIN || func_index == POUT
548               || func_index == FIN || func_index == FOUT)
549               && (var_is_numeric (src[0])
550                   ? arg[0].f > arg[1].f
551                   : str_compare_rpad (arg[0].c, arg[1].c) > 0))
552             {
553               union agr_argument t = arg[0];
554               arg[0] = arg[1];
555               arg[1] = t;
556
557               msg (SW, _("The value arguments passed to the %s function "
558                          "are out-of-order.  They will be treated as if "
559                          "they had been specified in the correct order."),
560                    function->name);
561             }
562         }
563
564       /* Finally add these to the linked list of aggregation
565          variables. */
566       for (i = 0; i < n_dest; i++)
567         {
568           struct agr_var *v = XZALLOC (struct agr_var);
569
570           /* Add variable to chain. */
571           if (agr->agr_vars != NULL)
572             tail->next = v;
573           else
574             agr->agr_vars = v;
575           tail = v;
576           tail->next = NULL;
577           v->moments = NULL;
578
579           /* Create the target variable in the aggregate
580              dictionary. */
581           {
582             struct variable *destvar;
583
584             v->function = func_index;
585
586             if (src)
587               {
588                 v->src = src[i];
589
590                 if (var_is_alpha (src[i]))
591                   {
592                     v->function |= FSTRING;
593                     v->string = xmalloc (var_get_width (src[i]));
594                   }
595
596                 if (function->alpha_type == VAL_STRING)
597                   destvar = dict_clone_var_as (agr->dict, v->src, dest[i]);
598                 else
599                   {
600                     assert (var_is_numeric (v->src)
601                             || function->alpha_type == VAL_NUMERIC);
602                     destvar = dict_create_var (agr->dict, dest[i], 0);
603                     if (destvar != NULL)
604                       {
605                         struct fmt_spec f;
606                         if ((func_index == N || func_index == NMISS)
607                             && dict_get_weight (dict) != NULL)
608                           f = fmt_for_output (FMT_F, 8, 2);
609                         else
610                           f = function->format;
611                         var_set_both_formats (destvar, &f);
612                       }
613                   }
614               } else {
615                 struct fmt_spec f;
616                 v->src = NULL;
617                 destvar = dict_create_var (agr->dict, dest[i], 0);
618                 if (destvar != NULL)
619                   {
620                     if ((func_index == N || func_index == NMISS)
621                         && dict_get_weight (dict) != NULL)
622                       f = fmt_for_output (FMT_F, 8, 2);
623                     else
624                       f = function->format;
625                     var_set_both_formats (destvar, &f);
626                   }
627             }
628
629             if (!destvar)
630               {
631                 msg (SE, _("Variable name %s is not unique within the "
632                            "aggregate file dictionary, which contains "
633                            "the aggregate variables and the break "
634                            "variables."),
635                      dest[i]);
636                 goto error;
637               }
638
639             free (dest[i]);
640             if (dest_label[i])
641               var_set_label (destvar, dest_label[i]);
642
643             v->dest = destvar;
644           }
645
646           v->exclude = exclude;
647
648           if (v->src != NULL)
649             {
650               int j;
651
652               if (var_is_numeric (v->src))
653                 for (j = 0; j < function->n_args; j++)
654                   v->arg[j].f = arg[j].f;
655               else
656                 for (j = 0; j < function->n_args; j++)
657                   v->arg[j].c = xstrdup (arg[j].c);
658             }
659         }
660
661       if (src != NULL && var_is_alpha (src[0]))
662         for (i = 0; i < function->n_args; i++)
663           {
664             free (arg[i].c);
665             arg[i].c = NULL;
666           }
667
668       free (src);
669       free (dest);
670       free (dest_label);
671
672       if (!lex_match (lexer, T_SLASH))
673         {
674           if (lex_token (lexer) == T_ENDCMD)
675             return true;
676
677           lex_error (lexer, "expecting end of command");
678           return false;
679         }
680       continue;
681
682     error:
683       ds_destroy (&function_name);
684       for (i = 0; i < n_dest; i++)
685         {
686           free (dest[i]);
687           free (dest_label[i]);
688         }
689       free (dest);
690       free (dest_label);
691       free (arg[0].c);
692       free (arg[1].c);
693       if (src && n_src && var_is_alpha (src[0]))
694         for (i = 0; i < function->n_args; i++)
695           {
696             free (arg[i].c);
697             arg[i].c = NULL;
698           }
699       free (src);
700
701       return false;
702     }
703 }
704
705 /* Destroys AGR. */
706 static void
707 agr_destroy (struct agr_proc *agr)
708 {
709   struct agr_var *iter, *next;
710
711   subcase_destroy (&agr->sort);
712   free (agr->break_vars);
713   for (iter = agr->agr_vars; iter; iter = next)
714     {
715       next = iter->next;
716
717       if (iter->function & FSTRING)
718         {
719           size_t n_args;
720           size_t i;
721
722           n_args = agr_func_tab[iter->function & FUNC].n_args;
723           for (i = 0; i < n_args; i++)
724             free (iter->arg[i].c);
725           free (iter->string);
726         }
727       else if (iter->function == SD)
728         moments1_destroy (iter->moments);
729
730       dict_destroy_internal_var (iter->subject);
731       dict_destroy_internal_var (iter->weight);
732
733       free (iter);
734     }
735   if (agr->dict != NULL)
736     dict_unref (agr->dict);
737 }
738 \f
739 /* Execution. */
740
741 /* Accumulates aggregation data from the case INPUT. */
742 static void
743 accumulate_aggregate_info (struct agr_proc *agr, const struct ccase *input)
744 {
745   struct agr_var *iter;
746   double weight;
747   bool bad_warn = true;
748
749   weight = dict_get_case_weight (agr->src_dict, input, &bad_warn);
750
751   for (iter = agr->agr_vars; iter; iter = iter->next)
752     if (iter->src)
753       {
754         const union value *v = case_data (input, iter->src);
755         int src_width = var_get_width (iter->src);
756
757         if (var_is_value_missing (iter->src, v, iter->exclude))
758           {
759             switch (iter->function)
760               {
761               case NMISS:
762               case NMISS | FSTRING:
763                 iter->dbl[0] += weight;
764                 break;
765               case NUMISS:
766               case NUMISS | FSTRING:
767                 iter->int1++;
768                 break;
769               }
770             iter->saw_missing = true;
771             continue;
772           }
773
774         /* This is horrible.  There are too many possibilities. */
775         switch (iter->function)
776           {
777           case SUM:
778             iter->dbl[0] += v->f * weight;
779             iter->int1 = 1;
780             break;
781           case MEAN:
782             iter->dbl[0] += v->f * weight;
783             iter->dbl[1] += weight;
784             break;
785           case MEDIAN:
786             {
787               double wv ;
788               struct ccase *cout;
789
790               cout = case_create (casewriter_get_proto (iter->writer));
791
792               *case_num_rw (cout, iter->subject) = case_num (input, iter->src);
793
794               wv = dict_get_case_weight (agr->src_dict, input, NULL);
795
796               *case_num_rw (cout, iter->weight) = wv;
797
798               iter->cc += wv;
799
800               casewriter_write (iter->writer, cout);
801             }
802             break;
803           case SD:
804             moments1_add (iter->moments, v->f, weight);
805             break;
806           case MAX:
807             iter->dbl[0] = MAX (iter->dbl[0], v->f);
808             iter->int1 = 1;
809             break;
810           case MAX | FSTRING:
811             /* Need to do some kind of Unicode collation thingy here */
812             if (memcmp (iter->string, v->s, src_width) < 0)
813               memcpy (iter->string, v->s, src_width);
814             iter->int1 = 1;
815             break;
816           case MIN:
817             iter->dbl[0] = MIN (iter->dbl[0], v->f);
818             iter->int1 = 1;
819             break;
820           case MIN | FSTRING:
821             if (memcmp (iter->string, v->s, src_width) > 0)
822               memcpy (iter->string, v->s, src_width);
823             iter->int1 = 1;
824             break;
825           case FGT:
826           case PGT:
827             if (v->f > iter->arg[0].f)
828               iter->dbl[0] += weight;
829             iter->dbl[1] += weight;
830             break;
831           case FGT | FSTRING:
832           case PGT | FSTRING:
833             if (memcmp (iter->arg[0].c, v->s, src_width) < 0)
834               iter->dbl[0] += weight;
835             iter->dbl[1] += weight;
836             break;
837           case FLT:
838           case PLT:
839             if (v->f < iter->arg[0].f)
840               iter->dbl[0] += weight;
841             iter->dbl[1] += weight;
842             break;
843           case FLT | FSTRING:
844           case PLT | FSTRING:
845             if (memcmp (iter->arg[0].c, v->s, src_width) > 0)
846               iter->dbl[0] += weight;
847             iter->dbl[1] += weight;
848             break;
849           case FIN:
850           case PIN:
851             if (iter->arg[0].f <= v->f && v->f <= iter->arg[1].f)
852               iter->dbl[0] += weight;
853             iter->dbl[1] += weight;
854             break;
855           case FIN | FSTRING:
856           case PIN | FSTRING:
857             if (memcmp (iter->arg[0].c, v->s, src_width) <= 0
858                 && memcmp (iter->arg[1].c, v->s, src_width) >= 0)
859               iter->dbl[0] += weight;
860             iter->dbl[1] += weight;
861             break;
862           case FOUT:
863           case POUT:
864             if (iter->arg[0].f > v->f || v->f > iter->arg[1].f)
865               iter->dbl[0] += weight;
866             iter->dbl[1] += weight;
867             break;
868           case FOUT | FSTRING:
869           case POUT | FSTRING:
870             if (memcmp (iter->arg[0].c, v->s, src_width) > 0
871                 || memcmp (iter->arg[1].c, v->s, src_width) < 0)
872               iter->dbl[0] += weight;
873             iter->dbl[1] += weight;
874             break;
875           case N:
876           case N | FSTRING:
877             iter->dbl[0] += weight;
878             break;
879           case NU:
880           case NU | FSTRING:
881             iter->int1++;
882             break;
883           case FIRST:
884             if (iter->int1 == 0)
885               {
886                 iter->dbl[0] = v->f;
887                 iter->int1 = 1;
888               }
889             break;
890           case FIRST | FSTRING:
891             if (iter->int1 == 0)
892               {
893                 memcpy (iter->string, v->s, src_width);
894                 iter->int1 = 1;
895               }
896             break;
897           case LAST:
898             iter->dbl[0] = v->f;
899             iter->int1 = 1;
900             break;
901           case LAST | FSTRING:
902             memcpy (iter->string, v->s, src_width);
903             iter->int1 = 1;
904             break;
905           case NMISS:
906           case NMISS | FSTRING:
907           case NUMISS:
908           case NUMISS | FSTRING:
909             /* Our value is not missing or it would have been
910                caught earlier.  Nothing to do. */
911             break;
912           default:
913             NOT_REACHED ();
914           }
915       } else {
916       switch (iter->function)
917         {
918         case N:
919           iter->dbl[0] += weight;
920           break;
921         case NU:
922           iter->int1++;
923           break;
924         default:
925           NOT_REACHED ();
926         }
927     }
928 }
929
930 /* Writes an aggregated record to OUTPUT. */
931 static void
932 dump_aggregate_info (const struct agr_proc *agr, struct casewriter *output, const struct ccase *break_case)
933 {
934   struct ccase *c = case_create (dict_get_proto (agr->dict));
935
936   if (agr->add_variables)
937     {
938       case_copy (c, 0, break_case, 0, dict_get_n_vars (agr->src_dict));
939     }
940   else
941     {
942       int value_idx = 0;
943       int i;
944
945       for (i = 0; i < agr->break_n_vars; i++)
946         {
947           const struct variable *v = agr->break_vars[i];
948           value_copy (case_data_rw_idx (c, value_idx),
949                       case_data (break_case, v),
950                       var_get_width (v));
951           value_idx++;
952         }
953     }
954
955   {
956     struct agr_var *i;
957
958     for (i = agr->agr_vars; i; i = i->next)
959       {
960         union value *v = case_data_rw (c, i->dest);
961         int width = var_get_width (i->dest);
962
963         if (agr->missing == COLUMNWISE && i->saw_missing
964             && (i->function & FUNC) != N && (i->function & FUNC) != NU
965             && (i->function & FUNC) != NMISS && (i->function & FUNC) != NUMISS)
966           {
967             value_set_missing (v, width);
968             casewriter_destroy (i->writer);
969             continue;
970           }
971
972         switch (i->function)
973           {
974           case SUM:
975             v->f = i->int1 ? i->dbl[0] : SYSMIS;
976             break;
977           case MEAN:
978             v->f = i->dbl[1] != 0.0 ? i->dbl[0] / i->dbl[1] : SYSMIS;
979             break;
980           case MEDIAN:
981             {
982               if (i->writer)
983                 {
984                   struct percentile *median = percentile_create (0.5, i->cc);
985                   struct order_stats *os = &median->parent;
986                   struct casereader *sorted_reader = casewriter_make_reader (i->writer);
987                   i->writer = NULL;
988
989                   order_stats_accumulate (&os, 1,
990                                           sorted_reader,
991                                           i->weight,
992                                           i->subject,
993                                           i->exclude);
994                   i->dbl[0] = percentile_calculate (median, PC_HAVERAGE);
995                   statistic_destroy (&median->parent.parent);
996                 }
997               v->f = i->dbl[0];
998             }
999             break;
1000           case SD:
1001             {
1002               double variance;
1003
1004               /* FIXME: we should use two passes. */
1005               moments1_calculate (i->moments, NULL, NULL, &variance,
1006                                  NULL, NULL);
1007               if (variance != SYSMIS)
1008                 v->f = sqrt (variance);
1009               else
1010                 v->f = SYSMIS;
1011             }
1012             break;
1013           case MAX:
1014           case MIN:
1015             v->f = i->int1 ? i->dbl[0] : SYSMIS;
1016             break;
1017           case MAX | FSTRING:
1018           case MIN | FSTRING:
1019             if (i->int1)
1020               memcpy (v->s, i->string, width);
1021             else
1022               value_set_missing (v, width);
1023             break;
1024           case FGT:
1025           case FGT | FSTRING:
1026           case FLT:
1027           case FLT | FSTRING:
1028           case FIN:
1029           case FIN | FSTRING:
1030           case FOUT:
1031           case FOUT | FSTRING:
1032             v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] : SYSMIS;
1033             break;
1034           case PGT:
1035           case PGT | FSTRING:
1036           case PLT:
1037           case PLT | FSTRING:
1038           case PIN:
1039           case PIN | FSTRING:
1040           case POUT:
1041           case POUT | FSTRING:
1042             v->f = i->dbl[1] ? i->dbl[0] / i->dbl[1] * 100.0 : SYSMIS;
1043             break;
1044           case N:
1045           case N | FSTRING:
1046               v->f = i->dbl[0];
1047             break;
1048           case NU:
1049           case NU | FSTRING:
1050             v->f = i->int1;
1051             break;
1052           case FIRST:
1053           case LAST:
1054             v->f = i->int1 ? i->dbl[0] : SYSMIS;
1055             break;
1056           case FIRST | FSTRING:
1057           case LAST | FSTRING:
1058             if (i->int1)
1059               memcpy (v->s, i->string, width);
1060             else
1061               value_set_missing (v, width);
1062             break;
1063           case NMISS:
1064           case NMISS | FSTRING:
1065             v->f = i->dbl[0];
1066             break;
1067           case NUMISS:
1068           case NUMISS | FSTRING:
1069             v->f = i->int1;
1070             break;
1071           default:
1072             NOT_REACHED ();
1073           }
1074       }
1075   }
1076
1077   casewriter_write (output, c);
1078 }
1079
1080 /* Resets the state for all the aggregate functions. */
1081 static void
1082 initialize_aggregate_info (struct agr_proc *agr)
1083 {
1084   struct agr_var *iter;
1085
1086   for (iter = agr->agr_vars; iter; iter = iter->next)
1087     {
1088       iter->saw_missing = false;
1089       iter->dbl[0] = iter->dbl[1] = iter->dbl[2] = 0.0;
1090       iter->int1 = iter->int2 = 0;
1091       switch (iter->function)
1092         {
1093         case MIN:
1094           iter->dbl[0] = DBL_MAX;
1095           break;
1096         case MIN | FSTRING:
1097           memset (iter->string, 255, var_get_width (iter->src));
1098           break;
1099         case MAX:
1100           iter->dbl[0] = -DBL_MAX;
1101           break;
1102         case MAX | FSTRING:
1103           memset (iter->string, 0, var_get_width (iter->src));
1104           break;
1105         case MEDIAN:
1106           {
1107             struct caseproto *proto;
1108             struct subcase ordering;
1109
1110             proto = caseproto_create ();
1111             proto = caseproto_add_width (proto, 0);
1112             proto = caseproto_add_width (proto, 0);
1113
1114             if (! iter->subject)
1115               iter->subject = dict_create_internal_var (0, 0);
1116
1117             if (! iter->weight)
1118               iter->weight = dict_create_internal_var (1, 0);
1119
1120             subcase_init_var (&ordering, iter->subject, SC_ASCEND);
1121             iter->writer = sort_create_writer (&ordering, proto);
1122             subcase_destroy (&ordering);
1123             caseproto_unref (proto);
1124
1125             iter->cc = 0;
1126           }
1127           break;
1128         case SD:
1129           if (iter->moments == NULL)
1130             iter->moments = moments1_create (MOMENT_VARIANCE);
1131           else
1132             moments1_clear (iter->moments);
1133           break;
1134         default:
1135           break;
1136         }
1137     }
1138 }