better tests
[pspp] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009, 2010, 2011, 2012 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/psql-reader.h"
20
21 #include <inttypes.h>
22 #include <math.h>
23 #include <stdlib.h>
24
25 #include "data/calendar.h"
26 #include "data/casereader-provider.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/variable.h"
30 #include "libpspp/i18n.h"
31 #include "libpspp/message.h"
32 #include "libpspp/misc.h"
33 #include "libpspp/str.h"
34
35 #include "gl/c-strcase.h"
36 #include "gl/minmax.h"
37 #include "gl/xalloc.h"
38
39 #include "gettext.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) (msgid)
42
43
44 #if !PSQL_SUPPORT
45 struct casereader *
46 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
47 {
48   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
49
50   return NULL;
51 }
52
53 #else
54
55 #include <stdint.h>
56 #include <libpq-fe.h>
57
58
59 /* Default width of string variables. */
60 #define PSQL_DEFAULT_WIDTH 8
61
62 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
63 #define BOOLOID            16
64 #define BYTEAOID           17
65 #define CHAROID            18
66 #define NAMEOID            19
67 #define INT8OID            20
68 #define INT2OID            21
69 #define INT4OID            23
70 #define TEXTOID            25
71 #define OIDOID             26
72 #define FLOAT4OID          700
73 #define FLOAT8OID          701
74 #define CASHOID            790
75 #define BPCHAROID          1042
76 #define VARCHAROID         1043
77 #define DATEOID            1082
78 #define TIMEOID            1083
79 #define TIMESTAMPOID       1114
80 #define TIMESTAMPTZOID     1184
81 #define INTERVALOID        1186
82 #define TIMETZOID          1266
83 #define NUMERICOID         1700
84
85 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
86
87 static struct ccase *psql_casereader_read (struct casereader *, void *);
88
89 static const struct casereader_class psql_casereader_class =
90   {
91     psql_casereader_read,
92     psql_casereader_destroy,
93     NULL,
94     NULL,
95   };
96
97 struct psql_reader
98 {
99   PGconn *conn;
100   PGresult *res;
101   int tuple;
102
103   bool integer_datetimes;
104
105   double postgres_epoch;
106
107   struct caseproto *proto;
108   struct dictionary *dict;
109
110   /* An array of ints, which maps psql column numbers into
111      pspp variables */
112   struct variable **vmap;
113   size_t vmapsize;
114
115   struct string fetch_cmd;
116   int cache_size;
117 };
118
119
120 static struct ccase *set_value (struct psql_reader *r);
121
122
123
124 #if WORDS_BIGENDIAN
125 static void
126 data_to_native (const void *in_, void *out_, int len)
127 {
128   int i;
129   const unsigned char *in = in_;
130   unsigned char *out = out_;
131   for (i = 0 ; i < len ; ++i)
132     out[i] = in[i];
133 }
134 #else
135 static void
136 data_to_native (const void *in_, void *out_, int len)
137 {
138   int i;
139   const unsigned char *in = in_;
140   unsigned char *out = out_;
141   for (i = 0 ; i < len ; ++i)
142     out[len - i - 1] = in[i];
143 }
144 #endif
145
146
147 #define GET_VALUE(IN, OUT) do { \
148     size_t sz = sizeof (OUT); \
149     data_to_native (*(IN), &(OUT), sz) ; \
150     (*IN) += sz; \
151 } while (false)
152
153
154 #if 0
155 static void
156 dump (const unsigned char *x, int l)
157 {
158   int i;
159
160   for (i = 0; i < l ; ++i)
161     {
162       printf ("%02x ", x[i]);
163     }
164
165   putchar ('\n');
166
167   for (i = 0; i < l ; ++i)
168     {
169       if (isprint (x[i]))
170         printf ("%c ", x[i]);
171       else
172         printf ("   ");
173     }
174
175   putchar ('\n');
176 }
177 #endif
178
179 static struct variable *
180 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
181             int width, const char *suggested_name, int col)
182 {
183   unsigned long int vx = 0;
184   struct variable *var;
185   char *name;
186
187   name = dict_make_unique_var_name (r->dict, suggested_name, &vx);
188   var = dict_create_var (r->dict, name, width);
189   free (name);
190
191   var_set_both_formats (var, fmt);
192
193   if (col != -1)
194     {
195       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
196
197       r->vmap[col] = var;
198       r->vmapsize = col + 1;
199     }
200
201   return var;
202 }
203
204
205
206
207 /* Fill the cache */
208 static bool
209 reload_cache (struct psql_reader *r)
210 {
211   PQclear (r->res);
212   r->tuple = 0;
213
214   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
215
216   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
217     {
218       PQclear (r->res);
219       r->res = NULL;
220       return false;
221     }
222
223   return true;
224 }
225
226
227 struct casereader *
228 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
229 {
230   int i;
231   int n_fields, n_tuples;
232   PGresult *qres = NULL;
233   casenumber n_cases = CASENUMBER_MAX;
234   const char *encoding;
235
236   struct psql_reader *r = XZALLOC (struct psql_reader);
237   struct string query ;
238
239   r->conn = PQconnectdb (info->conninfo);
240   if (NULL == r->conn)
241     {
242       msg (ME, _("Memory error whilst opening psql source"));
243       goto error;
244     }
245
246   if (PQstatus (r->conn) != CONNECTION_OK)
247     {
248       msg (ME, _("Error opening psql source: %s."),
249            PQerrorMessage (r->conn));
250
251       goto error;
252     }
253
254   {
255     int ver_num = 0;
256     const char *vers = PQparameterStatus (r->conn, "server_version");
257
258     sscanf (vers, "%d", &ver_num);
259
260     if (ver_num < 8)
261       {
262         msg (ME,
263              _("Postgres server is version %s."
264                " Reading from versions earlier than 8.0 is not supported."),
265              vers);
266
267         goto error;
268       }
269   }
270
271   {
272     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
273
274     r->integer_datetimes = (0 == c_strcasecmp (dt, "on"));
275   }
276
277 #if USE_SSL
278   if (PQgetssl (r->conn) == NULL)
279 #endif
280     {
281       if (! info->allow_clear)
282         {
283           msg (ME, _("Connection is unencrypted, "
284                      "but unencrypted connections have not been permitted."));
285           goto error;
286         }
287     }
288
289   r->postgres_epoch = calendar_gregorian_to_offset (
290     2000, 1, 1, settings_get_fmt_settings (), NULL);
291
292   {
293     const int enc = PQclientEncoding (r->conn);
294
295     /* According to section 22.2 of the Postgresql manual
296        a value of zero (SQL_ASCII) indicates
297        "a declaration of ignorance about the encoding".
298        Accordingly, we use the default encoding
299        if we find this value.
300     */
301     encoding = enc ? pg_encoding_to_char (enc) : get_default_encoding ();
302
303     /* Create the dictionary and populate it */
304     *dict = r->dict = dict_create (encoding);
305   }
306
307   const int version = PQserverVersion (r->conn);
308   ds_init_empty (&query);
309   /*
310     Versions before 9.1 don't have the REPEATABLE READ isolation level.
311     However according to <a12321aabb@gmail.com> if the server is in the
312     "hot standby" mode then SERIALIZABLE won't work.
313    */
314   ds_put_c_format (&query,
315                    "BEGIN READ ONLY ISOLATION LEVEL %s; "
316                    "DECLARE  pspp BINARY CURSOR FOR ",
317                    (version < 90100) ? "SERIALIZABLE" : "REPEATABLE READ");
318
319   ds_put_substring (&query, info->sql.ss);
320
321   qres = PQexec (r->conn, ds_cstr (&query));
322   ds_destroy (&query);
323   if (PQresultStatus (qres) != PGRES_COMMAND_OK)
324     {
325       msg (ME, _("Error from psql source: %s."),
326            PQresultErrorMessage (qres));
327       goto error;
328     }
329
330   PQclear (qres);
331
332
333   /* Now use the count() function to find the total number of cases
334      that this query returns.
335      Doing this incurs some overhead.  The server has to iterate every
336      case in order to find this number.  However, it's performed on the
337      server side, and in all except the most huge databases the extra
338      overhead will be worth the effort.
339      On the other hand, most PSPP functions don't need to know this.
340      The GUI is the notable exception.
341   */
342   ds_init_cstr (&query, "SELECT count (*) FROM (");
343   ds_put_substring (&query, info->sql.ss);
344   ds_put_cstr (&query, ") stupid_sql_standard");
345
346   qres = PQexec (r->conn, ds_cstr (&query));
347   ds_destroy (&query);
348   if (PQresultStatus (qres) != PGRES_TUPLES_OK)
349     {
350       msg (ME, _("Error from psql source: %s."),
351            PQresultErrorMessage (qres));
352       goto error;
353     }
354   n_cases = atol (PQgetvalue (qres, 0, 0));
355   PQclear (qres);
356
357   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
358   if (PQresultStatus (qres) != PGRES_TUPLES_OK)
359     {
360       msg (ME, _("Error from psql source: %s."),
361            PQresultErrorMessage (qres));
362       goto error;
363     }
364
365   n_tuples = PQntuples (qres);
366   n_fields = PQnfields (qres);
367
368   r->proto = NULL;
369   r->vmap = NULL;
370   r->vmapsize = 0;
371
372   for (i = 0 ; i < n_fields ; ++i)
373     {
374       struct variable *var;
375       struct fmt_spec fmt = { .type = FMT_F, .w = 8, .d = 2 };
376       Oid type = PQftype (qres, i);
377       int width = 0;
378       int length ;
379
380       /* If there are no data then make a finger in the air
381          guess at the contents */
382       if (n_tuples > 0)
383         length = PQgetlength (qres, 0, i);
384       else
385         length = PSQL_DEFAULT_WIDTH;
386
387       switch (type)
388         {
389         case BOOLOID:
390         case OIDOID:
391         case INT2OID:
392         case INT4OID:
393         case INT8OID:
394         case FLOAT4OID:
395         case FLOAT8OID:
396           fmt.type = FMT_F;
397           break;
398         case CASHOID:
399           fmt.type = FMT_DOLLAR;
400           break;
401         case CHAROID:
402           fmt.type = FMT_A;
403           width = length > 0 ? length : 1;
404           fmt.d = 0;
405           fmt.w = 1;
406           break;
407         case TEXTOID:
408         case VARCHAROID:
409         case BPCHAROID:
410           fmt.type = FMT_A;
411           width = (info->str_width == -1) ?
412             ROUND_UP (length, PSQL_DEFAULT_WIDTH) : info->str_width;
413           fmt.w = width;
414           fmt.d = 0;
415           break;
416         case BYTEAOID:
417           fmt.type = FMT_AHEX;
418           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
419           fmt.w = width * 2;
420           fmt.d = 0;
421           break;
422         case INTERVALOID:
423           fmt.type = FMT_DTIME;
424           width = 0;
425           fmt.d = 0;
426           fmt.w = 13;
427           break;
428         case DATEOID:
429           fmt.type = FMT_DATE;
430           width = 0;
431           fmt.w = 11;
432           fmt.d = 0;
433           break;
434         case TIMEOID:
435         case TIMETZOID:
436           fmt.type = FMT_TIME;
437           width = 0;
438           fmt.w = 11;
439           fmt.d = 0;
440           break;
441         case TIMESTAMPOID:
442         case TIMESTAMPTZOID:
443           fmt.type = FMT_DATETIME;
444           fmt.d = 0;
445           fmt.w = 22;
446           width = 0;
447           break;
448         case NUMERICOID:
449           fmt.type = FMT_E;
450           fmt.d = 2;
451           fmt.w = 40;
452           width = 0;
453           break;
454         default:
455           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
456           fmt.type = FMT_A;
457           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
458           fmt.w = width ;
459           fmt.d = 0;
460           break;
461         }
462
463       if (width == 0 && fmt_is_string (fmt.type))
464         fmt.w = width = PSQL_DEFAULT_WIDTH;
465
466
467       var = create_var (r, &fmt, width, PQfname (qres, i), i);
468       if (type == NUMERICOID && n_tuples > 0)
469         {
470           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
471           struct fmt_spec fmt;
472           int16_t n_digits, weight, dscale;
473           uint16_t sign;
474
475           GET_VALUE (&vptr, n_digits);
476           GET_VALUE (&vptr, weight);
477           GET_VALUE (&vptr, sign);
478           GET_VALUE (&vptr, dscale);
479
480           fmt.d = dscale;
481           fmt.type = FMT_E;
482           fmt.w = fmt_max_output_width (fmt.type) ;
483           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
484           var_set_both_formats (var, &fmt);
485         }
486
487       /* Timezones need an extra variable */
488       switch (type)
489         {
490         case TIMETZOID:
491           {
492             struct string name;
493             ds_init_cstr (&name, var_get_name (var));
494             ds_put_cstr (&name, "-zone");
495             fmt.type = FMT_F;
496             fmt.w = 8;
497             fmt.d = 2;
498
499             create_var (r, &fmt, 0, ds_cstr (&name), -1);
500
501             ds_destroy (&name);
502           }
503           break;
504
505         case INTERVALOID:
506           {
507             struct string name;
508             ds_init_cstr (&name, var_get_name (var));
509             ds_put_cstr (&name, "-months");
510             fmt.type = FMT_F;
511             fmt.w = 3;
512             fmt.d = 0;
513
514             create_var (r, &fmt, 0, ds_cstr (&name), -1);
515
516             ds_destroy (&name);
517           }
518         default:
519           break;
520         }
521     }
522
523   PQclear (qres);
524
525   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
526   if (PQresultStatus (qres) != PGRES_COMMAND_OK)
527     {
528       PQclear (qres);
529       goto error;
530     }
531   PQclear (qres);
532
533   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
534
535   ds_init_empty (&r->fetch_cmd);
536   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
537
538   reload_cache (r);
539   r->proto = caseproto_ref (dict_get_proto (*dict));
540
541   return casereader_create_sequential
542     (NULL,
543      r->proto,
544      n_cases,
545      &psql_casereader_class, r);
546
547  error:
548   dict_unref (*dict);
549
550   psql_casereader_destroy (NULL, r);
551   return NULL;
552 }
553
554
555 static void
556 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
557 {
558   struct psql_reader *r = r_;
559   if (r == NULL)
560     return ;
561
562   ds_destroy (&r->fetch_cmd);
563   free (r->vmap);
564   if (r->res) PQclear (r->res);
565   PQfinish (r->conn);
566   caseproto_unref (r->proto);
567
568   free (r);
569 }
570
571
572
573 static struct ccase *
574 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
575 {
576   struct psql_reader *r = r_;
577
578   if (NULL == r->res || r->tuple >= r->cache_size)
579     {
580       if (! reload_cache (r))
581         return false;
582     }
583
584   return set_value (r);
585 }
586
587 static struct ccase *
588 set_value (struct psql_reader *r)
589 {
590   struct ccase *c;
591   int n_vars;
592   int i;
593
594   assert (r->res);
595
596   n_vars = PQnfields (r->res);
597
598   if (r->tuple >= PQntuples (r->res))
599     return NULL;
600
601   c = case_create (r->proto);
602   case_set_missing (c);
603
604
605   for (i = 0 ; i < n_vars ; ++i)
606     {
607       Oid type = PQftype (r->res, i);
608       const struct variable *v = r->vmap[i];
609       union value *val = case_data_rw (c, v);
610
611       union value *val1 = NULL;
612
613       switch (type)
614         {
615         case INTERVALOID:
616         case TIMESTAMPTZOID:
617         case TIMETZOID:
618           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_n_vars (r->dict))
619             {
620               const struct variable *v1 = NULL;
621               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
622
623               val1 = case_data_rw (c, v1);
624             }
625           break;
626         default:
627           break;
628         }
629
630
631       if (PQgetisnull (r->res, r->tuple, i))
632         {
633           value_set_missing (val, var_get_width (v));
634
635           switch (type)
636             {
637             case INTERVALOID:
638             case TIMESTAMPTZOID:
639             case TIMETZOID:
640               val1->f = SYSMIS;
641               break;
642             default:
643               break;
644             }
645         }
646       else
647         {
648           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
649           int length = PQgetlength (r->res, r->tuple, i);
650
651           int var_width = var_get_width (v);
652           switch (type)
653             {
654             case BOOLOID:
655               {
656                 int8_t x;
657                 GET_VALUE (&vptr, x);
658                 val->f = x;
659               }
660               break;
661
662             case OIDOID:
663             case INT2OID:
664               {
665                 int16_t x;
666                 GET_VALUE (&vptr, x);
667                 val->f = x;
668               }
669               break;
670
671             case INT4OID:
672               {
673                 int32_t x;
674                 GET_VALUE (&vptr, x);
675                 val->f = x;
676               }
677               break;
678
679             case INT8OID:
680               {
681                 int64_t x;
682                 GET_VALUE (&vptr, x);
683                 val->f = x;
684               }
685               break;
686
687             case FLOAT4OID:
688               {
689                 float n;
690                 GET_VALUE (&vptr, n);
691                 val->f = n;
692               }
693               break;
694
695             case FLOAT8OID:
696               {
697                 double n;
698                 GET_VALUE (&vptr, n);
699                 val->f = n;
700               }
701               break;
702
703             case CASHOID:
704               {
705                 /* Postgres 8.3 uses 64 bits.
706                    Earlier versions use 32 */
707                 switch (length)
708                   {
709                   case 8:
710                     {
711                       int64_t x;
712                       GET_VALUE (&vptr, x);
713                       val->f = x / 100.0;
714                     }
715                     break;
716                   case 4:
717                     {
718                       int32_t x;
719                       GET_VALUE (&vptr, x);
720                       val->f = x / 100.0;
721                     }
722                     break;
723                   default:
724                     val->f = SYSMIS;
725                     break;
726                   }
727               }
728               break;
729
730             case INTERVALOID:
731               {
732                 if (r->integer_datetimes)
733                   {
734                     uint32_t months;
735                     uint32_t days;
736                     uint32_t us;
737                     uint32_t things;
738
739                     GET_VALUE (&vptr, things);
740                     GET_VALUE (&vptr, us);
741                     GET_VALUE (&vptr, days);
742                     GET_VALUE (&vptr, months);
743
744                     val->f = us / 1000000.0;
745                     val->f += days * 24 * 3600;
746
747                     val1->f = months;
748                   }
749                 else
750                   {
751                     uint32_t days, months;
752                     double seconds;
753
754                     GET_VALUE (&vptr, seconds);
755                     GET_VALUE (&vptr, days);
756                     GET_VALUE (&vptr, months);
757
758                     val->f = seconds;
759                     val->f += days * 24 * 3600;
760
761                     val1->f = months;
762                   }
763               }
764               break;
765
766             case DATEOID:
767               {
768                 int32_t x;
769
770                 GET_VALUE (&vptr, x);
771
772                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
773               }
774               break;
775
776             case TIMEOID:
777               {
778                 if (r->integer_datetimes)
779                   {
780                     uint64_t x;
781                     GET_VALUE (&vptr, x);
782                     val->f = x / 1000000.0;
783                   }
784                 else
785                   {
786                     double x;
787                     GET_VALUE (&vptr, x);
788                     val->f = x;
789                   }
790               }
791               break;
792
793             case TIMETZOID:
794               {
795                 int32_t zone;
796                 if (r->integer_datetimes)
797                   {
798                     uint64_t x;
799
800
801                     GET_VALUE (&vptr, x);
802                     val->f = x / 1000000.0;
803                   }
804                 else
805                   {
806                     double x;
807
808                     GET_VALUE (&vptr, x);
809                     val->f = x ;
810                   }
811
812                 GET_VALUE (&vptr, zone);
813                 val1->f = zone / 3600.0;
814               }
815               break;
816
817             case TIMESTAMPOID:
818             case TIMESTAMPTZOID:
819               {
820                 if (r->integer_datetimes)
821                   {
822                     int64_t x;
823
824                     GET_VALUE (&vptr, x);
825
826                     x /= 1000000;
827
828                     val->f = (x + r->postgres_epoch * 24 * 3600);
829                   }
830                 else
831                   {
832                     double x;
833
834                     GET_VALUE (&vptr, x);
835
836                     val->f = (x + r->postgres_epoch * 24 * 3600);
837                   }
838               }
839               break;
840             case TEXTOID:
841             case VARCHAROID:
842             case BPCHAROID:
843             case BYTEAOID:
844               memcpy (val->s, vptr, MIN (length, var_width));
845               break;
846
847             case NUMERICOID:
848               {
849                 double f = 0.0;
850                 int i;
851                 int16_t n_digits, weight, dscale;
852                 uint16_t sign;
853
854                 GET_VALUE (&vptr, n_digits);
855                 GET_VALUE (&vptr, weight);
856                 GET_VALUE (&vptr, sign);
857                 GET_VALUE (&vptr, dscale);
858
859 #if 0
860                 {
861                   struct fmt_spec fmt;
862                   fmt.d = dscale;
863                   fmt.type = FMT_E;
864                   fmt.w = fmt_max_output_width (fmt.type) ;
865                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
866                   var_set_both_formats (v, &fmt);
867                 }
868 #endif
869
870                 for (i = 0 ; i < n_digits;  ++i)
871                   {
872                     uint16_t x;
873                     GET_VALUE (&vptr, x);
874                     f += x * pow (10000, weight--);
875                   }
876
877                 if (sign == 0x4000)
878                   f *= -1.0;
879
880                 if (sign == 0xC000)
881                   val->f = SYSMIS;
882                 else
883                   val->f = f;
884               }
885               break;
886
887             default:
888               val->f = SYSMIS;
889               break;
890             }
891         }
892     }
893
894   r->tuple++;
895
896   return c;
897 }
898
899 #endif