Add GRAPH command initially with just scatterplots and histograms.
[pspp] / src / language / stats / graph.c
1 /*
2   PSPP - a program for statistical analysis.
3   Copyright (C) 2012, 2013  Free Software Foundation, Inc.
4   
5   This program is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   This program is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14   
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 /*
20  * This module implements the graph command
21  */
22
23 #include <config.h>
24
25 #include <math.h>
26 #include <gsl/gsl_cdf.h>
27
28 #include "libpspp/assertion.h"
29 #include "libpspp/message.h"
30 #include "libpspp/pool.h"
31
32
33 #include "data/dataset.h"
34 #include "data/dictionary.h"
35 #include "data/casegrouper.h"
36 #include "data/casereader.h"
37 #include "data/casewriter.h"
38 #include "data/caseproto.h"
39 #include "data/subcase.h"
40
41
42 #include "data/format.h"
43
44 #include "math/chart-geometry.h"
45 #include "math/histogram.h"
46 #include "math/moments.h"
47 #include "math/sort.h"
48 #include "math/order-stats.h"
49 #include "output/charts/plot-hist.h"
50 #include "output/charts/scatterplot.h"
51
52 #include "language/command.h"
53 #include "language/lexer/lexer.h"
54 #include "language/lexer/value-parser.h"
55 #include "language/lexer/variable-parser.h"
56
57 #include "output/tab.h"
58
59 #include "gettext.h"
60 #define _(msgid) gettext (msgid)
61 #define N_(msgid) msgid
62
63 enum chart_type
64   {
65     CT_NONE,
66     CT_BAR,
67     CT_LINE,
68     CT_PIE,
69     CT_ERRORBAR,
70     CT_HILO,
71     CT_HISTOGRAM,
72     CT_SCATTERPLOT,
73     CT_PARETO
74   };
75
76 enum scatter_type
77   {
78     ST_BIVARIATE,
79     ST_OVERLAY,
80     ST_MATRIX,
81     ST_XYZ
82   };
83
84 struct exploratory_stats
85 {
86   double missing;
87   double non_missing;
88
89   struct moments *mom;
90
91   double minimum;
92   double maximum;
93
94   /* Total weight */
95   double cc;
96
97   /* The minimum weight */
98   double cmin;
99 };
100
101
102 struct graph
103 {
104   struct pool *pool;
105
106   size_t n_dep_vars;
107   const struct variable **dep_vars;
108   struct exploratory_stats *es;
109
110   enum mv_class dep_excl;
111   enum mv_class fctr_excl;
112
113   const struct dictionary *dict;
114
115   bool missing_pw;
116
117   /* ------------ Graph ---------------- */
118   enum chart_type chart_type;
119   enum scatter_type scatter_type;
120   const struct variable *byvar;
121 };
122
123
124 static void
125 show_scatterplot (const struct graph *cmd, const struct casereader *input)
126 {
127   struct string title;
128   struct scatterplot_chart *scatterplot;
129   bool byvar_overflow = false;
130
131   ds_init_cstr (&title, var_to_string (cmd->dep_vars[0]));
132   ds_put_cstr (&title, " vs ");              
133   ds_put_cstr (&title, var_to_string (cmd->dep_vars[1]));
134   if (cmd->byvar)
135     {
136       ds_put_cstr (&title, " by ");                
137       ds_put_cstr (&title, var_to_string (cmd->byvar));
138     }    
139
140   scatterplot = scatterplot_create(input,
141                                    cmd->dep_vars[0], 
142                                    cmd->dep_vars[1],
143                                    cmd->byvar,
144                                    &byvar_overflow,
145                                    ds_cstr (&title),
146                                    cmd->es[0].minimum, cmd->es[0].maximum,
147                                    cmd->es[1].minimum, cmd->es[1].maximum);
148   scatterplot_chart_submit(scatterplot);
149   ds_destroy(&title);
150
151   if (byvar_overflow)
152     {
153       msg (MW, _("Maximum number of scatterplot categories reached." 
154                  "Your BY variable has too many distinct values."
155                  "The colouring of the plot will not be correct"));
156     }
157
158
159 }
160
161 static void
162 show_histogr (const struct graph *cmd, const struct casereader *input)
163 {
164   struct histogram *histogram;
165   struct ccase *c;
166   struct casereader *reader;
167
168   {
169     /* Sturges Rule */
170     double bin_width = fabs (cmd->es[0].minimum - cmd->es[0].maximum)
171       / (1 + log2 (cmd->es[0].cc))
172       ;
173
174     histogram =
175       histogram_create (bin_width, cmd->es[0].minimum, cmd->es[0].maximum);
176   }
177
178
179   for (reader=casereader_clone(input);(c = casereader_read (reader)) != NULL; case_unref (c))
180     {
181       const struct variable *var = cmd->dep_vars[0];
182       const double x = case_data (c, var)->f;
183       const double weight = dict_get_case_weight(cmd->dict,c,NULL);
184       moments_pass_two (cmd->es[0].mom, x, weight);
185       histogram_add (histogram, x, weight);
186     }
187   casereader_destroy(reader);
188
189
190   {
191     double n, mean, var;
192
193     struct string label;
194
195     ds_init_cstr (&label, 
196                   var_to_string (cmd->dep_vars[0]));
197
198     moments_calculate (cmd->es[0].mom, &n, &mean, &var, NULL, NULL);
199
200     chart_item_submit
201       ( histogram_chart_create (histogram->gsl_hist,
202                                 ds_cstr (&label), n, mean,
203                                 sqrt (var), false));
204
205     statistic_destroy(&histogram->parent);      
206     ds_destroy (&label);
207   }
208 }
209
210 static void
211 cleanup_exploratory_stats (struct graph *cmd)
212
213   int v;
214
215   for (v = 0; v < cmd->n_dep_vars; ++v)
216     {
217       moments_destroy (cmd->es[v].mom);
218     }
219 }
220
221
222 static void
223 run_graph (struct graph *cmd, struct casereader *input)
224 {
225   struct ccase *c;
226   struct casereader *reader;
227
228
229   cmd->es = pool_calloc(cmd->pool,cmd->n_dep_vars,sizeof(struct exploratory_stats));
230   for(int v=0;v<cmd->n_dep_vars;v++)
231     {
232       cmd->es[v].mom = moments_create (MOMENT_KURTOSIS);
233       cmd->es[v].cmin = DBL_MAX;
234       cmd->es[v].maximum = -DBL_MAX;
235       cmd->es[v].minimum =  DBL_MAX;
236     }
237   /* Always remove cases listwise. This is correct for */
238   /* the histogram because there is only one variable  */
239   /* and a simple bivariate scatterplot                */
240   /* if ( cmd->missing_pw == false)                    */
241     input = casereader_create_filter_missing (input,
242                                               cmd->dep_vars,
243                                               cmd->n_dep_vars,
244                                               cmd->dep_excl,
245                                               NULL,
246                                               NULL);
247
248   for (reader = casereader_clone (input);
249        (c = casereader_read (reader)) != NULL; case_unref (c))
250     {
251       const double weight = dict_get_case_weight(cmd->dict,c,NULL);      
252       for(int v=0;v<cmd->n_dep_vars;v++)
253         {
254           const struct variable *var = cmd->dep_vars[v];
255           const double x = case_data (c, var)->f;
256
257           if (var_is_value_missing (var, case_data (c, var), cmd->dep_excl))
258             {
259               cmd->es[v].missing += weight;
260               continue;
261             }
262
263           if (x > cmd->es[v].maximum)
264             cmd->es[v].maximum = x;
265
266           if (x < cmd->es[v].minimum)
267             cmd->es[v].minimum =  x;
268
269           cmd->es[v].non_missing += weight;
270
271           moments_pass_one (cmd->es[v].mom, x, weight);
272
273           cmd->es[v].cc += weight;
274
275           if (cmd->es[v].cmin > weight)
276             cmd->es[v].cmin = weight;
277         }
278     }
279   casereader_destroy (reader);
280
281   switch (cmd->chart_type)
282     {
283     case CT_HISTOGRAM:
284       reader = casereader_clone(input);
285       show_histogr(cmd,reader);
286       casereader_destroy(reader);
287       break;
288     case CT_SCATTERPLOT:
289       reader = casereader_clone(input);
290       show_scatterplot(cmd,reader);
291       casereader_destroy(reader);
292       break;
293     default:
294       NOT_REACHED ();
295       break;
296     };
297
298   casereader_destroy(input);
299
300   cleanup_exploratory_stats (cmd);
301 }
302
303
304 int
305 cmd_graph (struct lexer *lexer, struct dataset *ds)
306 {
307   struct graph graph;
308
309   graph.missing_pw = false;
310   
311   graph.pool = pool_create ();
312
313   graph.dep_excl = MV_ANY;
314   graph.fctr_excl = MV_ANY;
315   
316   graph.dict = dataset_dict (ds);
317   
318
319   /* ---------------- graph ------------------ */
320   graph.dep_vars = NULL;
321   graph.chart_type = CT_NONE;
322   graph.scatter_type = ST_BIVARIATE;
323   graph.byvar = NULL;
324
325   while (lex_token (lexer) != T_ENDCMD)
326     {
327       lex_match (lexer, T_SLASH);
328
329       if (lex_match_id(lexer, "HISTOGRAM"))
330         {
331           if (graph.chart_type != CT_NONE)
332             {
333               lex_error(lexer, _("Only one chart type is allowed."));
334               goto error;
335             }
336           if (!lex_force_match (lexer, T_EQUALS))
337             goto error;
338           graph.chart_type = CT_HISTOGRAM;
339           if (!parse_variables_const (lexer, graph.dict,
340                                       &graph.dep_vars, &graph.n_dep_vars,
341                                       PV_NO_DUPLICATE | PV_NUMERIC))
342             goto error;
343           if (graph.n_dep_vars > 1)
344             {
345               lex_error(lexer, _("Only one variable allowed"));
346               goto error;
347             }
348         }
349       else if (lex_match_id (lexer, "SCATTERPLOT"))
350         {
351           if (graph.chart_type != CT_NONE)
352             {
353               lex_error(lexer, _("Only one chart type is allowed."));
354               goto error;
355             }
356           graph.chart_type = CT_SCATTERPLOT;
357           if (lex_match (lexer, T_LPAREN)) 
358             {
359               if (lex_match_id (lexer, "BIVARIATE"))
360                 {
361                   /* This is the default anyway */
362                 }
363               else if (lex_match_id (lexer, "OVERLAY"))  
364                 {
365                   lex_error(lexer, _("%s is not yet implemented."),"OVERLAY");
366                   goto error;
367                 }
368               else if (lex_match_id (lexer, "MATRIX"))  
369                 {
370                   lex_error(lexer, _("%s is not yet implemented."),"MATRIX");
371                   goto error;
372                 }
373               else if (lex_match_id (lexer, "XYZ"))  
374                 {
375                   lex_error(lexer, _("%s is not yet implemented."),"XYZ");
376                   goto error;
377                 }
378               else
379                 {
380                   lex_error_expecting(lexer, "BIVARIATE", NULL);
381                   goto error;
382                 }
383               if (!lex_force_match (lexer, T_RPAREN))
384                 goto error;
385             }
386           if (!lex_force_match (lexer, T_EQUALS))
387             goto error;
388
389           if (!parse_variables_const (lexer, graph.dict,
390                                       &graph.dep_vars, &graph.n_dep_vars,
391                                       PV_NO_DUPLICATE | PV_NUMERIC))
392             goto error;
393          
394           if (graph.scatter_type == ST_BIVARIATE && graph.n_dep_vars != 1)
395             {
396               lex_error(lexer, _("Only one variable allowed"));
397               goto error;
398             }
399
400           if (!lex_force_match (lexer, T_WITH))
401             goto error;
402
403           if (!parse_variables_const (lexer, graph.dict,
404                                       &graph.dep_vars, &graph.n_dep_vars,
405                                       PV_NO_DUPLICATE | PV_NUMERIC | PV_APPEND))
406             goto error;
407
408           if (graph.scatter_type == ST_BIVARIATE && graph.n_dep_vars != 2)
409             {
410               lex_error(lexer, _("Only one variable allowed"));
411               goto error;
412             }
413           
414           if (lex_match(lexer, T_BY))
415             {
416               const struct variable *v = NULL;
417               if (!lex_match_variable (lexer,graph.dict,&v))
418                 {
419                   lex_error(lexer, _("Variable expected"));
420                   goto error;
421                 }
422               graph.byvar = v;
423             }
424         }
425       else if (lex_match_id (lexer, "BAR"))
426         {
427           lex_error (lexer, _("%s is not yet implemented."),"BAR");
428           goto error;
429         }
430       else if (lex_match_id (lexer, "LINE"))
431         {
432           lex_error (lexer, _("%s is not yet implemented."),"LINE");
433           goto error;
434         }
435       else if (lex_match_id (lexer, "PIE"))
436         {
437           lex_error (lexer, _("%s is not yet implemented."),"PIE");
438           goto error;
439         }
440       else if (lex_match_id (lexer, "ERRORBAR"))
441         {
442           lex_error (lexer, _("%s is not yet implemented."),"ERRORBAR");
443           goto error;
444         }
445       else if (lex_match_id (lexer, "PARETO"))
446         {
447           lex_error (lexer, _("%s is not yet implemented."),"PARETO");
448           goto error;
449         }
450       else if (lex_match_id (lexer, "TITLE"))
451         {
452           lex_error (lexer, _("%s is not yet implemented."),"TITLE");
453           goto error;
454         }
455       else if (lex_match_id (lexer, "SUBTITLE"))
456         {
457           lex_error (lexer, _("%s is not yet implemented."),"SUBTITLE");
458           goto error;
459         }
460       else if (lex_match_id (lexer, "FOOTNOTE"))
461         {
462           lex_error (lexer, _("%s is not yet implemented."),"FOOTNOTE");
463           lex_error (lexer, _("FOOTNOTE is not implemented yet for GRAPH"));
464           goto error;
465         }
466       else if (lex_match_id (lexer, "MISSING"))
467         {
468           lex_match (lexer, T_EQUALS);
469
470           while (lex_token (lexer) != T_ENDCMD
471                  && lex_token (lexer) != T_SLASH)
472             {
473               if (lex_match_id (lexer, "LISTWISE"))
474                 {
475                   graph.missing_pw = false;
476                 }
477               else if (lex_match_id (lexer, "VARIABLE"))
478                 {
479                   graph.missing_pw = true;
480                 }
481               else if (lex_match_id (lexer, "EXCLUDE"))
482                 {
483                   graph.dep_excl = MV_ANY;
484                 }
485               else if (lex_match_id (lexer, "INCLUDE"))
486                 {
487                   graph.dep_excl = MV_SYSTEM;
488                 }
489               else if (lex_match_id (lexer, "REPORT"))
490                 {
491                   graph.fctr_excl = MV_NEVER;
492                 }
493               else if (lex_match_id (lexer, "NOREPORT"))
494                 {
495                   graph.fctr_excl = MV_ANY;
496                 }
497               else
498                 {
499                   lex_error (lexer, NULL);
500                   goto error;
501                 }
502             }
503         }
504       else
505         {
506           lex_error (lexer, NULL);
507           goto error;
508         }
509     }
510
511   if (graph.chart_type == CT_NONE)
512     {
513       lex_error_expecting(lexer,"HISTOGRAM","SCATTERPLOT",NULL);
514       goto error;
515     }
516
517
518   {
519     struct casegrouper *grouper;
520     struct casereader *group;
521     bool ok;
522     
523     grouper = casegrouper_create_splits (proc_open (ds), graph.dict);
524     while (casegrouper_get_next_group (grouper, &group))
525       run_graph (&graph, group);
526     ok = casegrouper_destroy (grouper);
527     ok = proc_commit (ds) && ok;
528   }
529
530   free (graph.dep_vars);
531   pool_destroy (graph.pool);
532
533   return CMD_SUCCESS;
534
535  error:
536   free (graph.dep_vars);
537   pool_destroy (graph.pool);
538
539   return CMD_FAILURE;
540 }