Fixed the categoricals such that now both GLM and ONEWAY work
[pspp] / src / language / xforms / compute.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 1997-9, 2000, 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <stdint.h>
20 #include <stdlib.h>
21
22 #include "data/case.h"
23 #include "data/dataset.h"
24 #include "data/dictionary.h"
25 #include "data/transformations.h"
26 #include "data/variable.h"
27 #include "data/vector.h"
28 #include "language/command.h"
29 #include "language/expressions/public.h"
30 #include "language/lexer/lexer.h"
31 #include "libpspp/message.h"
32 #include "libpspp/misc.h"
33 #include "libpspp/str.h"
34
35 #include "gl/xalloc.h"
36
37 #include "gettext.h"
38 #define _(msgid) gettext (msgid)
39
40 struct compute_trns;
41 struct lvalue;
42
43 /* Target of a COMPUTE or IF assignment, either a variable or a
44    vector element. */
45 static struct lvalue *lvalue_parse (struct lexer *lexer, struct dataset *);
46 static int lvalue_get_type (const struct lvalue *);
47 static bool lvalue_is_vector (const struct lvalue *);
48 static void lvalue_finalize (struct lvalue *,
49                              struct compute_trns *, struct dictionary *);
50 static void lvalue_destroy (struct lvalue *, struct dictionary *);
51
52 /* COMPUTE and IF transformation. */
53 struct compute_trns
54   {
55     /* Test expression (IF only). */
56     struct expression *test;     /* Test expression. */
57
58     /* Variable lvalue, if variable != NULL. */
59     struct variable *variable;   /* Destination variable, if any. */
60     int width;                   /* Lvalue string width; 0=numeric. */
61
62     /* Vector lvalue, if vector != NULL. */
63     const struct vector *vector; /* Destination vector, if any. */
64     struct expression *element;  /* Destination vector element expr. */
65
66     /* Rvalue. */
67     struct expression *rvalue;   /* Rvalue expression. */
68   };
69
70 static struct expression *parse_rvalue (struct lexer *lexer,
71                                         const struct lvalue *,
72                                         struct dataset *);
73
74 static struct compute_trns *compute_trns_create (void);
75 static trns_proc_func *get_proc_func (const struct lvalue *);
76 static trns_free_func compute_trns_free;
77 \f
78 /* COMPUTE. */
79
80 int
81 cmd_compute (struct lexer *lexer, struct dataset *ds)
82 {
83   struct dictionary *dict = dataset_dict (ds);
84   struct lvalue *lvalue = NULL;
85   struct compute_trns *compute = NULL;
86
87   compute = compute_trns_create ();
88
89   lvalue = lvalue_parse (lexer, ds);
90   if (lvalue == NULL)
91     goto fail;
92
93   if (!lex_force_match (lexer, T_EQUALS))
94     goto fail;
95   compute->rvalue = parse_rvalue (lexer, lvalue, ds);
96   if (compute->rvalue == NULL)
97     goto fail;
98
99   add_transformation (ds, get_proc_func (lvalue), compute_trns_free, compute);
100
101   lvalue_finalize (lvalue, compute, dict);
102
103   return CMD_SUCCESS;
104
105  fail:
106   lvalue_destroy (lvalue, dict);
107   compute_trns_free (compute);
108   return CMD_CASCADING_FAILURE;
109 }
110 \f
111 /* Transformation functions. */
112
113 /* Handle COMPUTE or IF with numeric target variable. */
114 static int
115 compute_num (void *compute_, struct ccase **c, casenumber case_num)
116 {
117   struct compute_trns *compute = compute_;
118
119   if (compute->test == NULL
120       || expr_evaluate_num (compute->test, *c, case_num) == 1.0)
121     {
122       *c = case_unshare (*c);
123       case_data_rw (*c, compute->variable)->f
124         = expr_evaluate_num (compute->rvalue, *c, case_num);
125     }
126
127   return TRNS_CONTINUE;
128 }
129
130 /* Handle COMPUTE or IF with numeric vector element target
131    variable. */
132 static int
133 compute_num_vec (void *compute_, struct ccase **c, casenumber case_num)
134 {
135   struct compute_trns *compute = compute_;
136
137   if (compute->test == NULL
138       || expr_evaluate_num (compute->test, *c, case_num) == 1.0)
139     {
140       double index;     /* Index into the vector. */
141       int rindx;        /* Rounded index value. */
142
143       index = expr_evaluate_num (compute->element, *c, case_num);
144       rindx = floor (index + EPSILON);
145       if (index == SYSMIS
146           || rindx < 1 || rindx > vector_get_var_cnt (compute->vector))
147         {
148           if (index == SYSMIS)
149             msg (SW, _("When executing COMPUTE: SYSMIS is not a valid value "
150                        "as an index into vector %s."),
151                  vector_get_name (compute->vector));
152           else
153             msg (SW, _("When executing COMPUTE: %g is not a valid value as "
154                        "an index into vector %s."),
155                  index, vector_get_name (compute->vector));
156           return TRNS_CONTINUE;
157         }
158
159       *c = case_unshare (*c);
160       case_data_rw (*c, vector_get_var (compute->vector, rindx - 1))->f
161         = expr_evaluate_num (compute->rvalue, *c, case_num);
162     }
163
164   return TRNS_CONTINUE;
165 }
166
167 /* Handle COMPUTE or IF with string target variable. */
168 static int
169 compute_str (void *compute_, struct ccase **c, casenumber case_num)
170 {
171   struct compute_trns *compute = compute_;
172
173   if (compute->test == NULL
174       || expr_evaluate_num (compute->test, *c, case_num) == 1.0)
175     {
176       char *s;
177
178       *c = case_unshare (*c);
179       s = CHAR_CAST_BUG (char *, case_str_rw (*c, compute->variable));
180       expr_evaluate_str (compute->rvalue, *c, case_num, s, compute->width);
181     }
182
183   return TRNS_CONTINUE;
184 }
185
186 /* Handle COMPUTE or IF with string vector element target
187    variable. */
188 static int
189 compute_str_vec (void *compute_, struct ccase **c, casenumber case_num)
190 {
191   struct compute_trns *compute = compute_;
192
193   if (compute->test == NULL
194       || expr_evaluate_num (compute->test, *c, case_num) == 1.0)
195     {
196       double index;             /* Index into the vector. */
197       int rindx;                /* Rounded index value. */
198       struct variable *vr;      /* Variable reference by indexed vector. */
199
200       index = expr_evaluate_num (compute->element, *c, case_num);
201       rindx = floor (index + EPSILON);
202       if (index == SYSMIS)
203         {
204           msg (SW, _("When executing COMPUTE: SYSMIS is not a valid "
205                      "value as an index into vector %s."),
206                vector_get_name (compute->vector));
207           return TRNS_CONTINUE;
208         }
209       else if (rindx < 1 || rindx > vector_get_var_cnt (compute->vector))
210         {
211           msg (SW, _("When executing COMPUTE: %g is not a valid value as "
212                      "an index into vector %s."),
213                index, vector_get_name (compute->vector));
214           return TRNS_CONTINUE;
215         }
216
217       vr = vector_get_var (compute->vector, rindx - 1);
218       *c = case_unshare (*c);
219       expr_evaluate_str (compute->rvalue, *c, case_num,
220                          CHAR_CAST_BUG (char *, case_str_rw (*c, vr)),
221                          var_get_width (vr));
222     }
223
224   return TRNS_CONTINUE;
225 }
226 \f
227 /* IF. */
228
229 int
230 cmd_if (struct lexer *lexer, struct dataset *ds)
231 {
232   struct dictionary *dict = dataset_dict (ds);
233   struct compute_trns *compute = NULL;
234   struct lvalue *lvalue = NULL;
235
236   compute = compute_trns_create ();
237
238   /* Test expression. */
239   compute->test = expr_parse (lexer, ds, EXPR_BOOLEAN);
240   if (compute->test == NULL)
241     goto fail;
242
243   /* Lvalue variable. */
244   lvalue = lvalue_parse (lexer, ds);
245   if (lvalue == NULL)
246     goto fail;
247
248   /* Rvalue expression. */
249   if (!lex_force_match (lexer, T_EQUALS))
250     goto fail;
251   compute->rvalue = parse_rvalue (lexer, lvalue, ds);
252   if (compute->rvalue == NULL)
253     goto fail;
254
255   add_transformation (ds, get_proc_func (lvalue), compute_trns_free, compute);
256
257   lvalue_finalize (lvalue, compute, dict);
258
259   return CMD_SUCCESS;
260
261  fail:
262   lvalue_destroy (lvalue, dict);
263   compute_trns_free (compute);
264   return CMD_CASCADING_FAILURE;
265 }
266 \f
267 /* Code common to COMPUTE and IF. */
268
269 static trns_proc_func *
270 get_proc_func (const struct lvalue *lvalue)
271 {
272   bool is_numeric = lvalue_get_type (lvalue) == VAL_NUMERIC;
273   bool is_vector = lvalue_is_vector (lvalue);
274
275   return (is_numeric
276           ? (is_vector ? compute_num_vec : compute_num)
277           : (is_vector ? compute_str_vec : compute_str));
278 }
279
280 /* Parses and returns an rvalue expression of the same type as
281    LVALUE, or a null pointer on failure. */
282 static struct expression *
283 parse_rvalue (struct lexer *lexer,
284               const struct lvalue *lvalue, struct dataset *ds)
285 {
286   bool is_numeric = lvalue_get_type (lvalue) == VAL_NUMERIC;
287
288   return expr_parse (lexer, ds, is_numeric ? EXPR_NUMBER : EXPR_STRING);
289 }
290
291 /* Returns a new struct compute_trns after initializing its fields. */
292 static struct compute_trns *
293 compute_trns_create (void)
294 {
295   struct compute_trns *compute = xmalloc (sizeof *compute);
296   compute->test = NULL;
297   compute->variable = NULL;
298   compute->vector = NULL;
299   compute->element = NULL;
300   compute->rvalue = NULL;
301   return compute;
302 }
303
304 /* Deletes all the fields in COMPUTE. */
305 static bool
306 compute_trns_free (void *compute_)
307 {
308   struct compute_trns *compute = compute_;
309
310   if (compute != NULL)
311     {
312       expr_free (compute->test);
313       expr_free (compute->element);
314       expr_free (compute->rvalue);
315       free (compute);
316     }
317   return true;
318 }
319 \f
320 /* COMPUTE or IF target variable or vector element.
321    For a variable, the `variable' member is non-null.
322    For a vector element, the `vector' member is non-null. */
323 struct lvalue
324   {
325     struct variable *variable;   /* Destination variable. */
326     bool is_new_variable;        /* Did we create the variable? */
327
328     const struct vector *vector; /* Destination vector, if any, or NULL. */
329     struct expression *element;  /* Destination vector element, or NULL. */
330   };
331
332 /* Parses the target variable or vector element into a new
333    `struct lvalue', which is returned. */
334 static struct lvalue *
335 lvalue_parse (struct lexer *lexer, struct dataset *ds)
336 {
337   struct dictionary *dict = dataset_dict (ds);
338   struct lvalue *lvalue;
339
340   lvalue = xmalloc (sizeof *lvalue);
341   lvalue->variable = NULL;
342   lvalue->is_new_variable = false;
343   lvalue->vector = NULL;
344   lvalue->element = NULL;
345
346   if (!lex_force_id (lexer))
347     goto lossage;
348
349   if (lex_next_token (lexer, 1) == T_LPAREN)
350     {
351       /* Vector. */
352       lvalue->vector = dict_lookup_vector (dict, lex_tokcstr (lexer));
353       if (lvalue->vector == NULL)
354         {
355           msg (SE, _("There is no vector named %s."), lex_tokcstr (lexer));
356           goto lossage;
357         }
358
359       /* Vector element. */
360       lex_get (lexer);
361       if (!lex_force_match (lexer, T_LPAREN))
362         goto lossage;
363       lvalue->element = expr_parse (lexer, ds, EXPR_NUMBER);
364       if (lvalue->element == NULL)
365         goto lossage;
366       if (!lex_force_match (lexer, T_RPAREN))
367         goto lossage;
368     }
369   else
370     {
371       /* Variable name. */
372       const char *var_name = lex_tokcstr (lexer);
373       lvalue->variable = dict_lookup_var (dict, var_name);
374       if (lvalue->variable == NULL)
375         {
376           lvalue->variable = dict_create_var_assert (dict, var_name, 0);
377           lvalue->is_new_variable = true;
378         }
379       lex_get (lexer);
380     }
381   return lvalue;
382
383  lossage:
384   lvalue_destroy (lvalue, dict);
385   return NULL;
386 }
387
388 /* Returns the type (NUMERIC or ALPHA) of the target variable or
389    vector in LVALUE. */
390 static int
391 lvalue_get_type (const struct lvalue *lvalue)
392 {
393   return (lvalue->variable != NULL
394           ? var_get_type (lvalue->variable)
395           : vector_get_type (lvalue->vector));
396 }
397
398 /* Returns true if LVALUE has a vector as its target. */
399 static bool
400 lvalue_is_vector (const struct lvalue *lvalue)
401 {
402   return lvalue->vector != NULL;
403 }
404
405 /* Finalizes making LVALUE the target of COMPUTE, by creating the
406    target variable if necessary and setting fields in COMPUTE. */
407 static void
408 lvalue_finalize (struct lvalue *lvalue,
409                  struct compute_trns *compute,
410                  struct dictionary *dict)
411 {
412   if (lvalue->vector == NULL)
413     {
414       compute->variable = lvalue->variable;
415       compute->width = var_get_width (compute->variable);
416
417       /* Goofy behavior, but compatible: Turn off LEAVE. */
418       if (!var_must_leave (compute->variable))
419         var_set_leave (compute->variable, false);
420
421       /* Prevent lvalue_destroy from deleting variable. */
422       lvalue->is_new_variable = false;
423     }
424   else
425     {
426       compute->vector = lvalue->vector;
427       compute->element = lvalue->element;
428       lvalue->element = NULL;
429     }
430
431   lvalue_destroy (lvalue, dict);
432 }
433
434 /* Destroys LVALUE. */
435 static void
436 lvalue_destroy (struct lvalue *lvalue, struct dictionary *dict)
437 {
438   if (lvalue == NULL)
439      return;
440
441   if (lvalue->is_new_variable)
442     dict_delete_var (dict, lvalue->variable);
443   expr_free (lvalue->element);
444   free (lvalue);
445 }