GET DATA: Improve coding style and tests.
[pspp] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009, 2010, 2011, 2012 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/psql-reader.h"
20
21 #include <inttypes.h>
22 #include <math.h>
23 #include <stdlib.h>
24
25 #include "data/calendar.h"
26 #include "data/casereader-provider.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/variable.h"
30 #include "libpspp/i18n.h"
31 #include "libpspp/message.h"
32 #include "libpspp/misc.h"
33 #include "libpspp/str.h"
34
35 #include "gl/c-strcase.h"
36 #include "gl/minmax.h"
37 #include "gl/xalloc.h"
38
39 #include "gettext.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) (msgid)
42
43
44 #if !PSQL_SUPPORT
45 struct casereader *
46 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
47 {
48   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
49
50   return NULL;
51 }
52
53 #else
54
55 #include <stdint.h>
56 #include <libpq-fe.h>
57
58
59 /* Default width of string variables. */
60 #define PSQL_DEFAULT_WIDTH 8
61
62 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
63 #define BOOLOID            16
64 #define BYTEAOID           17
65 #define CHAROID            18
66 #define NAMEOID            19
67 #define INT8OID            20
68 #define INT2OID            21
69 #define INT4OID            23
70 #define TEXTOID            25
71 #define OIDOID             26
72 #define FLOAT4OID          700
73 #define FLOAT8OID          701
74 #define CASHOID            790
75 #define BPCHAROID          1042
76 #define VARCHAROID         1043
77 #define DATEOID            1082
78 #define TIMEOID            1083
79 #define TIMESTAMPOID       1114
80 #define TIMESTAMPTZOID     1184
81 #define INTERVALOID        1186
82 #define TIMETZOID          1266
83 #define NUMERICOID         1700
84
85 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
86
87 static struct ccase *psql_casereader_read (struct casereader *, void *);
88
89 static const struct casereader_class psql_casereader_class =
90   {
91     psql_casereader_read,
92     psql_casereader_destroy,
93     NULL,
94     NULL,
95   };
96
97 struct psql_reader
98 {
99   PGconn *conn;
100   PGresult *res;
101   int tuple;
102
103   bool integer_datetimes;
104
105   double postgres_epoch;
106
107   struct caseproto *proto;
108   struct dictionary *dict;
109
110   /* An array of ints, which maps psql column numbers into
111      pspp variables */
112   struct variable **vmap;
113   size_t vmapsize;
114
115   struct string fetch_cmd;
116   int cache_size;
117 };
118
119
120 static struct ccase *set_value (struct psql_reader *r);
121
122
123
124 #if WORDS_BIGENDIAN
125 static void
126 data_to_native (const void *in_, void *out_, int len)
127 {
128   int i;
129   const unsigned char *in = in_;
130   unsigned char *out = out_;
131   for (i = 0 ; i < len ; ++i)
132     out[i] = in[i];
133 }
134 #else
135 static void
136 data_to_native (const void *in_, void *out_, int len)
137 {
138   int i;
139   const unsigned char *in = in_;
140   unsigned char *out = out_;
141   for (i = 0 ; i < len ; ++i)
142     out[len - i - 1] = in[i];
143 }
144 #endif
145
146
147 #define GET_VALUE(IN, OUT) do { \
148     size_t sz = sizeof (OUT); \
149     data_to_native (*(IN), &(OUT), sz) ; \
150     (*IN) += sz; \
151 } while (false)
152
153
154 #if 0
155 static void
156 dump (const unsigned char *x, int l)
157 {
158   int i;
159
160   for (i = 0; i < l ; ++i)
161     {
162       printf ("%02x ", x[i]);
163     }
164
165   putchar ('\n');
166
167   for (i = 0; i < l ; ++i)
168     {
169       if (isprint (x[i]))
170         printf ("%c ", x[i]);
171       else
172         printf ("   ");
173     }
174
175   putchar ('\n');
176 }
177 #endif
178
179 static struct variable *
180 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
181             int width, const char *suggested_name, int col)
182 {
183   unsigned long int vx = 0;
184   struct variable *var;
185   char *name;
186
187   name = dict_make_unique_var_name (r->dict, suggested_name, &vx);
188   var = dict_create_var (r->dict, name, width);
189   free (name);
190
191   var_set_both_formats (var, fmt);
192
193   if (col != -1)
194     {
195       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
196
197       r->vmap[col] = var;
198       r->vmapsize = col + 1;
199     }
200
201   return var;
202 }
203
204
205
206
207 /* Fill the cache */
208 static bool
209 reload_cache (struct psql_reader *r)
210 {
211   PQclear (r->res);
212   r->tuple = 0;
213
214   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
215
216   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
217     {
218       PQclear (r->res);
219       r->res = NULL;
220       return false;
221     }
222
223   return true;
224 }
225
226
227 struct casereader *
228 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
229 {
230   int i;
231   int n_fields, n_tuples;
232   PGresult *qres = NULL;
233   casenumber n_cases = CASENUMBER_MAX;
234   const char *encoding;
235
236   struct psql_reader *r = XZALLOC (struct psql_reader);
237
238   r->conn = PQconnectdb (info->conninfo);
239   if (NULL == r->conn)
240     {
241       msg (ME, _("Memory error whilst opening psql source"));
242       goto error;
243     }
244
245   if (PQstatus (r->conn) != CONNECTION_OK)
246     {
247       msg (ME, _("Error opening psql source: %s."),
248            PQerrorMessage (r->conn));
249
250       goto error;
251     }
252
253   {
254     int ver_num = 0;
255     const char *vers = PQparameterStatus (r->conn, "server_version");
256
257     sscanf (vers, "%d", &ver_num);
258
259     if (ver_num < 8)
260       {
261         msg (ME,
262              _("Postgres server is version %s."
263                " Reading from versions earlier than 8.0 is not supported."),
264              vers);
265
266         goto error;
267       }
268   }
269
270   {
271     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
272
273     r->integer_datetimes = (0 == c_strcasecmp (dt, "on"));
274   }
275
276 #if USE_SSL
277   if (PQgetssl (r->conn) == NULL)
278 #endif
279     {
280       if (! info->allow_clear)
281         {
282           msg (ME, _("Connection is unencrypted, "
283                      "but unencrypted connections have not been permitted."));
284           goto error;
285         }
286     }
287
288   r->postgres_epoch = calendar_gregorian_to_offset (
289     2000, 1, 1, settings_get_fmt_settings (), NULL);
290
291   {
292     const int enc = PQclientEncoding (r->conn);
293
294     /* According to section 22.2 of the Postgresql manual
295        a value of zero (SQL_ASCII) indicates
296        "a declaration of ignorance about the encoding".
297        Accordingly, we use the default encoding
298        if we find this value.
299     */
300     encoding = enc ? pg_encoding_to_char (enc) : get_default_encoding ();
301
302     /* Create the dictionary and populate it */
303     *dict = r->dict = dict_create (encoding);
304   }
305
306   const int version = PQserverVersion (r->conn);
307   /*
308     Versions before 9.1 don't have the REPEATABLE READ isolation level.
309     However according to <a12321aabb@gmail.com> if the server is in the
310     "hot standby" mode then SERIALIZABLE won't work.
311    */
312   char *query = xasprintf (
313     "BEGIN READ ONLY ISOLATION LEVEL %s; "
314     "DECLARE  pspp BINARY CURSOR FOR %s",
315     (version < 90100) ? "SERIALIZABLE" : "REPEATABLE READ",
316     info->sql);
317   qres = PQexec (r->conn, query);
318   free (query);
319
320   if (PQresultStatus (qres) != PGRES_COMMAND_OK)
321     {
322       msg (ME, _("Error from psql source: %s."),
323            PQresultErrorMessage (qres));
324       goto error;
325     }
326
327   PQclear (qres);
328
329
330   /* Now use the count() function to find the total number of cases
331      that this query returns.
332      Doing this incurs some overhead.  The server has to iterate every
333      case in order to find this number.  However, it's performed on the
334      server side, and in all except the most huge databases the extra
335      overhead will be worth the effort.
336      On the other hand, most PSPP functions don't need to know this.
337      The GUI is the notable exception.
338   */
339   query = xasprintf ("SELECT count (*) FROM (%s) stupid_sql_standard",
340                      info->sql);
341   qres = PQexec (r->conn, query);
342   free (query);
343
344   if (PQresultStatus (qres) != PGRES_TUPLES_OK)
345     {
346       msg (ME, _("Error from psql source: %s."),
347            PQresultErrorMessage (qres));
348       goto error;
349     }
350   n_cases = atol (PQgetvalue (qres, 0, 0));
351   PQclear (qres);
352
353   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
354   if (PQresultStatus (qres) != PGRES_TUPLES_OK)
355     {
356       msg (ME, _("Error from psql source: %s."),
357            PQresultErrorMessage (qres));
358       goto error;
359     }
360
361   n_tuples = PQntuples (qres);
362   n_fields = PQnfields (qres);
363
364   r->proto = NULL;
365   r->vmap = NULL;
366   r->vmapsize = 0;
367
368   for (i = 0 ; i < n_fields ; ++i)
369     {
370       struct variable *var;
371       struct fmt_spec fmt = { .type = FMT_F, .w = 8, .d = 2 };
372       Oid type = PQftype (qres, i);
373       int width = 0;
374       int length ;
375
376       /* If there are no data then make a finger in the air
377          guess at the contents */
378       if (n_tuples > 0)
379         length = PQgetlength (qres, 0, i);
380       else
381         length = PSQL_DEFAULT_WIDTH;
382
383       switch (type)
384         {
385         case BOOLOID:
386         case OIDOID:
387         case INT2OID:
388         case INT4OID:
389         case INT8OID:
390         case FLOAT4OID:
391         case FLOAT8OID:
392           fmt.type = FMT_F;
393           break;
394         case CASHOID:
395           fmt.type = FMT_DOLLAR;
396           break;
397         case CHAROID:
398           fmt.type = FMT_A;
399           width = length > 0 ? length : 1;
400           fmt.d = 0;
401           fmt.w = 1;
402           break;
403         case TEXTOID:
404         case VARCHAROID:
405         case BPCHAROID:
406           fmt.type = FMT_A;
407           width = (info->str_width == -1) ?
408             ROUND_UP (length, PSQL_DEFAULT_WIDTH) : info->str_width;
409           fmt.w = width;
410           fmt.d = 0;
411           break;
412         case BYTEAOID:
413           fmt.type = FMT_AHEX;
414           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
415           fmt.w = width * 2;
416           fmt.d = 0;
417           break;
418         case INTERVALOID:
419           fmt.type = FMT_DTIME;
420           width = 0;
421           fmt.d = 0;
422           fmt.w = 13;
423           break;
424         case DATEOID:
425           fmt.type = FMT_DATE;
426           width = 0;
427           fmt.w = 11;
428           fmt.d = 0;
429           break;
430         case TIMEOID:
431         case TIMETZOID:
432           fmt.type = FMT_TIME;
433           width = 0;
434           fmt.w = 11;
435           fmt.d = 0;
436           break;
437         case TIMESTAMPOID:
438         case TIMESTAMPTZOID:
439           fmt.type = FMT_DATETIME;
440           fmt.d = 0;
441           fmt.w = 22;
442           width = 0;
443           break;
444         case NUMERICOID:
445           fmt.type = FMT_E;
446           fmt.d = 2;
447           fmt.w = 40;
448           width = 0;
449           break;
450         default:
451           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
452           fmt.type = FMT_A;
453           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
454           fmt.w = width ;
455           fmt.d = 0;
456           break;
457         }
458
459       if (width == 0 && fmt_is_string (fmt.type))
460         fmt.w = width = PSQL_DEFAULT_WIDTH;
461
462
463       var = create_var (r, &fmt, width, PQfname (qres, i), i);
464       if (type == NUMERICOID && n_tuples > 0)
465         {
466           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
467           struct fmt_spec fmt;
468           int16_t n_digits, weight, dscale;
469           uint16_t sign;
470
471           GET_VALUE (&vptr, n_digits);
472           GET_VALUE (&vptr, weight);
473           GET_VALUE (&vptr, sign);
474           GET_VALUE (&vptr, dscale);
475
476           fmt.d = dscale;
477           fmt.type = FMT_E;
478           fmt.w = fmt_max_output_width (fmt.type) ;
479           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
480           var_set_both_formats (var, &fmt);
481         }
482
483       /* Timezones need an extra variable */
484       switch (type)
485         {
486         case TIMETZOID:
487           {
488             struct string name;
489             ds_init_cstr (&name, var_get_name (var));
490             ds_put_cstr (&name, "-zone");
491             fmt.type = FMT_F;
492             fmt.w = 8;
493             fmt.d = 2;
494
495             create_var (r, &fmt, 0, ds_cstr (&name), -1);
496
497             ds_destroy (&name);
498           }
499           break;
500
501         case INTERVALOID:
502           {
503             struct string name;
504             ds_init_cstr (&name, var_get_name (var));
505             ds_put_cstr (&name, "-months");
506             fmt.type = FMT_F;
507             fmt.w = 3;
508             fmt.d = 0;
509
510             create_var (r, &fmt, 0, ds_cstr (&name), -1);
511
512             ds_destroy (&name);
513           }
514         default:
515           break;
516         }
517     }
518
519   PQclear (qres);
520
521   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
522   if (PQresultStatus (qres) != PGRES_COMMAND_OK)
523     {
524       PQclear (qres);
525       goto error;
526     }
527   PQclear (qres);
528
529   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
530
531   ds_init_empty (&r->fetch_cmd);
532   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
533
534   reload_cache (r);
535   r->proto = caseproto_ref (dict_get_proto (*dict));
536
537   return casereader_create_sequential
538     (NULL,
539      r->proto,
540      n_cases,
541      &psql_casereader_class, r);
542
543  error:
544   dict_unref (*dict);
545
546   psql_casereader_destroy (NULL, r);
547   return NULL;
548 }
549
550
551 static void
552 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
553 {
554   struct psql_reader *r = r_;
555   if (r == NULL)
556     return ;
557
558   ds_destroy (&r->fetch_cmd);
559   free (r->vmap);
560   if (r->res) PQclear (r->res);
561   PQfinish (r->conn);
562   caseproto_unref (r->proto);
563
564   free (r);
565 }
566
567
568
569 static struct ccase *
570 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
571 {
572   struct psql_reader *r = r_;
573
574   if (NULL == r->res || r->tuple >= r->cache_size)
575     {
576       if (! reload_cache (r))
577         return false;
578     }
579
580   return set_value (r);
581 }
582
583 static struct ccase *
584 set_value (struct psql_reader *r)
585 {
586   struct ccase *c;
587   int n_vars;
588   int i;
589
590   assert (r->res);
591
592   n_vars = PQnfields (r->res);
593
594   if (r->tuple >= PQntuples (r->res))
595     return NULL;
596
597   c = case_create (r->proto);
598   case_set_missing (c);
599
600
601   for (i = 0 ; i < n_vars ; ++i)
602     {
603       Oid type = PQftype (r->res, i);
604       const struct variable *v = r->vmap[i];
605       union value *val = case_data_rw (c, v);
606
607       union value *val1 = NULL;
608
609       switch (type)
610         {
611         case INTERVALOID:
612         case TIMESTAMPTZOID:
613         case TIMETZOID:
614           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_n_vars (r->dict))
615             {
616               const struct variable *v1 = NULL;
617               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
618
619               val1 = case_data_rw (c, v1);
620             }
621           break;
622         default:
623           break;
624         }
625
626
627       if (PQgetisnull (r->res, r->tuple, i))
628         {
629           value_set_missing (val, var_get_width (v));
630
631           switch (type)
632             {
633             case INTERVALOID:
634             case TIMESTAMPTZOID:
635             case TIMETZOID:
636               val1->f = SYSMIS;
637               break;
638             default:
639               break;
640             }
641         }
642       else
643         {
644           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
645           int length = PQgetlength (r->res, r->tuple, i);
646
647           int var_width = var_get_width (v);
648           switch (type)
649             {
650             case BOOLOID:
651               {
652                 int8_t x;
653                 GET_VALUE (&vptr, x);
654                 val->f = x;
655               }
656               break;
657
658             case OIDOID:
659             case INT2OID:
660               {
661                 int16_t x;
662                 GET_VALUE (&vptr, x);
663                 val->f = x;
664               }
665               break;
666
667             case INT4OID:
668               {
669                 int32_t x;
670                 GET_VALUE (&vptr, x);
671                 val->f = x;
672               }
673               break;
674
675             case INT8OID:
676               {
677                 int64_t x;
678                 GET_VALUE (&vptr, x);
679                 val->f = x;
680               }
681               break;
682
683             case FLOAT4OID:
684               {
685                 float n;
686                 GET_VALUE (&vptr, n);
687                 val->f = n;
688               }
689               break;
690
691             case FLOAT8OID:
692               {
693                 double n;
694                 GET_VALUE (&vptr, n);
695                 val->f = n;
696               }
697               break;
698
699             case CASHOID:
700               {
701                 /* Postgres 8.3 uses 64 bits.
702                    Earlier versions use 32 */
703                 switch (length)
704                   {
705                   case 8:
706                     {
707                       int64_t x;
708                       GET_VALUE (&vptr, x);
709                       val->f = x / 100.0;
710                     }
711                     break;
712                   case 4:
713                     {
714                       int32_t x;
715                       GET_VALUE (&vptr, x);
716                       val->f = x / 100.0;
717                     }
718                     break;
719                   default:
720                     val->f = SYSMIS;
721                     break;
722                   }
723               }
724               break;
725
726             case INTERVALOID:
727               {
728                 if (r->integer_datetimes)
729                   {
730                     uint32_t months;
731                     uint32_t days;
732                     uint32_t us;
733                     uint32_t things;
734
735                     GET_VALUE (&vptr, things);
736                     GET_VALUE (&vptr, us);
737                     GET_VALUE (&vptr, days);
738                     GET_VALUE (&vptr, months);
739
740                     val->f = us / 1000000.0;
741                     val->f += days * 24 * 3600;
742
743                     val1->f = months;
744                   }
745                 else
746                   {
747                     uint32_t days, months;
748                     double seconds;
749
750                     GET_VALUE (&vptr, seconds);
751                     GET_VALUE (&vptr, days);
752                     GET_VALUE (&vptr, months);
753
754                     val->f = seconds;
755                     val->f += days * 24 * 3600;
756
757                     val1->f = months;
758                   }
759               }
760               break;
761
762             case DATEOID:
763               {
764                 int32_t x;
765
766                 GET_VALUE (&vptr, x);
767
768                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
769               }
770               break;
771
772             case TIMEOID:
773               {
774                 if (r->integer_datetimes)
775                   {
776                     uint64_t x;
777                     GET_VALUE (&vptr, x);
778                     val->f = x / 1000000.0;
779                   }
780                 else
781                   {
782                     double x;
783                     GET_VALUE (&vptr, x);
784                     val->f = x;
785                   }
786               }
787               break;
788
789             case TIMETZOID:
790               {
791                 int32_t zone;
792                 if (r->integer_datetimes)
793                   {
794                     uint64_t x;
795
796
797                     GET_VALUE (&vptr, x);
798                     val->f = x / 1000000.0;
799                   }
800                 else
801                   {
802                     double x;
803
804                     GET_VALUE (&vptr, x);
805                     val->f = x ;
806                   }
807
808                 GET_VALUE (&vptr, zone);
809                 val1->f = zone / 3600.0;
810               }
811               break;
812
813             case TIMESTAMPOID:
814             case TIMESTAMPTZOID:
815               {
816                 if (r->integer_datetimes)
817                   {
818                     int64_t x;
819
820                     GET_VALUE (&vptr, x);
821
822                     x /= 1000000;
823
824                     val->f = (x + r->postgres_epoch * 24 * 3600);
825                   }
826                 else
827                   {
828                     double x;
829
830                     GET_VALUE (&vptr, x);
831
832                     val->f = (x + r->postgres_epoch * 24 * 3600);
833                   }
834               }
835               break;
836             case TEXTOID:
837             case VARCHAROID:
838             case BPCHAROID:
839             case BYTEAOID:
840               memcpy (val->s, vptr, MIN (length, var_width));
841               break;
842
843             case NUMERICOID:
844               {
845                 double f = 0.0;
846                 int i;
847                 int16_t n_digits, weight, dscale;
848                 uint16_t sign;
849
850                 GET_VALUE (&vptr, n_digits);
851                 GET_VALUE (&vptr, weight);
852                 GET_VALUE (&vptr, sign);
853                 GET_VALUE (&vptr, dscale);
854
855 #if 0
856                 {
857                   struct fmt_spec fmt;
858                   fmt.d = dscale;
859                   fmt.type = FMT_E;
860                   fmt.w = fmt_max_output_width (fmt.type) ;
861                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
862                   var_set_both_formats (v, &fmt);
863                 }
864 #endif
865
866                 for (i = 0 ; i < n_digits;  ++i)
867                   {
868                     uint16_t x;
869                     GET_VALUE (&vptr, x);
870                     f += x * pow (10000, weight--);
871                   }
872
873                 if (sign == 0x4000)
874                   f *= -1.0;
875
876                 if (sign == 0xC000)
877                   val->f = SYSMIS;
878                 else
879                   val->f = f;
880               }
881               break;
882
883             default:
884               val->f = SYSMIS;
885               break;
886             }
887         }
888     }
889
890   r->tuple++;
891
892   return c;
893 }
894
895 #endif