85e777a991b9218bfbf2c5f17555a9972093bd5e
[pspp-builds.git] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/casereader-provider.h>
20 #include <libpspp/message.h>
21 #include <gl/xalloc.h>
22 #include <data/dictionary.h>
23 #include <stdlib.h>
24
25 #include "psql-reader.h"
26 #include "variable.h"
27 #include "format.h"
28 #include "calendar.h"
29
30 #include <inttypes.h>
31 #include <libpspp/str.h>
32
33 #include "gettext.h"
34 #define _(msgid) gettext (msgid)
35 #define N_(msgid) (msgid)
36
37
38 #if !PSQL_SUPPORT
39 struct casereader *
40 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
41 {
42   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
43
44   return NULL;
45 }
46
47 #else
48
49 #include <stdint.h>
50 #include <libpq-fe.h>
51
52
53 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
54 #define BOOLOID            16
55 #define BYTEAOID           17
56 #define CHAROID            18
57 #define NAMEOID            19
58 #define INT8OID            20
59 #define INT2OID            21
60 #define INT4OID            23
61 #define TEXTOID            25
62 #define OIDOID             26
63 #define FLOAT4OID          700
64 #define FLOAT8OID          701
65 #define CASHOID            790
66 #define BPCHAROID          1042
67 #define VARCHAROID         1043
68 #define DATEOID            1082
69 #define TIMEOID            1083
70 #define TIMESTAMPOID       1114
71 #define TIMESTAMPTZOID     1184
72 #define INTERVALOID        1186
73 #define TIMETZOID          1266
74 #define NUMERICOID         1700
75
76 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
77
78 static struct ccase *psql_casereader_read (struct casereader *, void *);
79
80 static const struct casereader_class psql_casereader_class =
81   {
82     psql_casereader_read,
83     psql_casereader_destroy,
84     NULL,
85     NULL,
86   };
87
88 struct psql_reader
89 {
90   PGconn *conn;
91   PGresult *res;
92   int tuple;
93
94   bool integer_datetimes;
95
96   double postgres_epoch;
97
98   size_t value_cnt;
99   struct dictionary *dict;
100
101   /* An array of ints, which maps psql column numbers into
102      pspp variables */
103   struct variable **vmap;
104   size_t vmapsize;
105
106   struct string fetch_cmd;
107   int cache_size;
108 };
109
110
111 static struct ccase *set_value (struct psql_reader *r);
112
113
114
115 #if WORDS_BIGENDIAN
116 static void
117 data_to_native (const void *in_, void *out_, int len)
118 {
119   int i;
120   const unsigned char *in = in_;
121   unsigned char *out = out_;
122   for (i = 0 ; i < len ; ++i )
123     out[i] = in[i];
124 }
125 #else
126 static void
127 data_to_native (const void *in_, void *out_, int len)
128 {
129   int i;
130   const unsigned char *in = in_;
131   unsigned char *out = out_;
132   for (i = 0 ; i < len ; ++i )
133     out[len - i - 1] = in[i];
134 }
135 #endif
136
137
138 #define GET_VALUE(IN, OUT) do { \
139     size_t sz = sizeof (OUT); \
140     data_to_native (*(IN), &(OUT), sz) ; \
141     (*IN) += sz; \
142 } while (false)
143
144
145 #if 0
146 static void
147 dump (const unsigned char *x, int l)
148 {
149   int i;
150
151   for (i = 0; i < l ; ++i)
152     {
153       printf ("%02x ", x[i]);
154     }
155
156   putchar ('\n');
157
158   for (i = 0; i < l ; ++i)
159     {
160       if ( isprint (x[i]))
161         printf ("%c ", x[i]);
162       else
163         printf ("   ");
164     }
165
166   putchar ('\n');
167 }
168 #endif
169
170 static struct variable *
171 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
172             int width, const char *suggested_name, int col)
173 {
174   unsigned long int vx = 0;
175   struct variable *var;
176   char name[VAR_NAME_LEN + 1];
177
178   r->value_cnt += value_cnt_from_width (width);
179
180   if ( ! dict_make_unique_var_name (r->dict, suggested_name, &vx, name))
181     {
182       msg (ME, _("Cannot create variable name from %s"), suggested_name);
183       return NULL;
184     }
185
186   var = dict_create_var (r->dict, name, width);
187   var_set_both_formats (var, fmt);
188
189   if ( col != -1)
190     {
191       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
192
193       r->vmap[col] = var;
194       r->vmapsize = col + 1;
195     }
196
197   return var;
198 }
199
200
201
202
203 /* Fill the cache */
204 static bool
205 reload_cache (struct psql_reader *r)
206 {
207   PQclear (r->res);
208   r->tuple = 0;
209
210   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
211
212   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
213     {
214       PQclear (r->res);
215       r->res = NULL;
216       return false;
217     }
218
219   return true;
220 }
221
222
223 struct casereader *
224 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
225 {
226   int i;
227   int n_fields, n_tuples;
228   PGresult *qres = NULL;
229   casenumber n_cases = CASENUMBER_MAX;
230
231   struct psql_reader *r = xzalloc (sizeof *r);
232   struct string query ;
233
234   r->conn = PQconnectdb (info->conninfo);
235   if ( NULL == r->conn)
236     {
237       msg (ME, _("Memory error whilst opening psql source"));
238       goto error;
239     }
240
241   if ( PQstatus (r->conn) != CONNECTION_OK )
242     {
243       msg (ME, _("Error opening psql source: %s."),
244            PQerrorMessage (r->conn));
245
246       goto error;
247     }
248
249   {
250     int ver_num;
251     const char *vers = PQparameterStatus (r->conn, "server_version");
252
253     sscanf (vers, "%d", &ver_num);
254
255     if ( ver_num < 8)
256       {
257         msg (ME,
258              _("Postgres server is version %s."
259                " Reading from versions earlier than 8.0 is not supported."),
260              vers);
261
262         goto error;
263       }
264   }
265
266   {
267     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
268
269     r->integer_datetimes = ( 0 == strcasecmp (dt, "on"));
270   }
271
272 #if USE_SSL
273   if ( PQgetssl (r->conn) == NULL)
274 #endif
275     {
276       if (! info->allow_clear)
277         {
278           msg (ME, _("Connection is unencrypted, "
279                      "but unencrypted connections have not been permitted."));
280           goto error;
281         }
282     }
283
284   r->postgres_epoch =
285     calendar_gregorian_to_offset (2000, 1, 1, NULL, NULL);
286
287
288   /* Create the dictionary and populate it */
289   *dict = r->dict = dict_create ();
290
291   {
292     const int enc = PQclientEncoding (r->conn);
293
294     /* According to section 22.2 of the Postgresql manual
295        a value of zero (SQL_ASCII) indicates
296        "a declaration of ignorance about the encoding".
297        Accordingly, we don't set the dictionary's encoding
298        if we find this value.
299     */
300     if ( enc != 0 )
301       dict_set_encoding (r->dict, pg_encoding_to_char (enc));
302   }
303
304   /*
305     select count (*) from (select * from medium) stupid_sql_standard;
306   */
307   ds_init_cstr (&query,
308                 "BEGIN READ ONLY ISOLATION LEVEL SERIALIZABLE; "
309                 "DECLARE  pspp BINARY CURSOR FOR ");
310
311   ds_put_substring (&query, info->sql.ss);
312
313   qres = PQexec (r->conn, ds_cstr (&query));
314   ds_destroy (&query);
315   if ( PQresultStatus (qres) != PGRES_COMMAND_OK )
316     {
317       msg (ME, _("Error from psql source: %s."),
318            PQresultErrorMessage (qres));
319       goto error;
320     }
321
322   PQclear (qres);
323
324
325   /* Now use the count() function to find the total number of cases
326      that this query returns.
327      Doing this incurs some overhead.  The server has to iterate every
328      case in order to find this number.  However, it's performed on the
329      server side, and in all except the most huge databases the extra
330      overhead will be worth the effort.
331      On the other hand, most PSPP functions don't need to know this.
332      The GUI is the notable exception.
333   */
334   ds_init_cstr (&query, "SELECT count (*) FROM (");
335   ds_put_substring (&query, info->sql.ss);
336   ds_put_cstr (&query, ") stupid_sql_standard");
337
338   qres = PQexec (r->conn, ds_cstr (&query));
339   ds_destroy (&query);
340   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
341     {
342       msg (ME, _("Error from psql source: %s."),
343            PQresultErrorMessage (qres));
344       goto error;
345     }
346   n_cases = atol (PQgetvalue (qres, 0, 0));
347   PQclear (qres);
348
349   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
350   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
351     {
352       msg (ME, _("Error from psql source: %s."),
353            PQresultErrorMessage (qres));
354       goto error;
355     }
356
357   n_tuples = PQntuples (qres);
358   n_fields = PQnfields (qres);
359
360   r->value_cnt = 0;
361   r->vmap = NULL;
362   r->vmapsize = 0;
363
364   for (i = 0 ; i < n_fields ; ++i )
365     {
366       struct variable *var;
367       struct fmt_spec fmt = {FMT_F, 8, 2};
368       Oid type = PQftype (qres, i);
369       int width = 0;
370       int length ;
371
372       /* If there are no data then make a finger in the air 
373          guess at the contents */
374       if ( n_tuples > 0 )
375         length = PQgetlength (qres, 0, i);
376       else 
377         length = MAX_SHORT_STRING;
378
379       switch (type)
380         {
381         case BOOLOID:
382         case OIDOID:
383         case INT2OID:
384         case INT4OID:
385         case INT8OID:
386         case FLOAT4OID:
387         case FLOAT8OID:
388           fmt.type = FMT_F;
389           break;
390         case CASHOID:
391           fmt.type = FMT_DOLLAR;
392           break;
393         case CHAROID:
394           fmt.type = FMT_A;
395           width = length > 0 ? length : 1;
396           fmt.d = 0;
397           fmt.w = 1;
398           break;
399         case TEXTOID:
400         case VARCHAROID:
401         case BPCHAROID:
402           fmt.type = FMT_A;
403           width = (info->str_width == -1) ?
404             ROUND_UP (length, MAX_SHORT_STRING) : info->str_width;
405           fmt.w = width;
406           fmt.d = 0;
407           break;
408         case BYTEAOID:
409           fmt.type = FMT_AHEX;
410           width = length > 0 ? length : MAX_SHORT_STRING;
411           fmt.w = width * 2;
412           fmt.d = 0;
413           break;
414         case INTERVALOID:
415           fmt.type = FMT_DTIME;
416           width = 0;
417           fmt.d = 0;
418           fmt.w = 13;
419           break;
420         case DATEOID:
421           fmt.type = FMT_DATE;
422           width = 0;
423           fmt.w = 11;
424           fmt.d = 0;
425           break;
426         case TIMEOID:
427         case TIMETZOID:
428           fmt.type = FMT_TIME;
429           width = 0;
430           fmt.w = 11;
431           fmt.d = 0;
432           break;
433         case TIMESTAMPOID:
434         case TIMESTAMPTZOID:
435           fmt.type = FMT_DATETIME;
436           fmt.d = 0;
437           fmt.w = 22;
438           width = 0;
439           break;
440         case NUMERICOID:
441           fmt.type = FMT_E;
442           fmt.d = 2;
443           fmt.w = 40;
444           width = 0;
445           break;
446         default:
447           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
448           fmt.type = FMT_A;
449           width = length > 0 ? length : MAX_SHORT_STRING;
450           fmt.w = width ;
451           fmt.d = 0;
452           break;
453         }
454
455       if ( width == 0 && fmt_is_string (fmt.type))
456         fmt.w = width = MAX_SHORT_STRING;
457
458
459       var = create_var (r, &fmt, width, PQfname (qres, i), i);
460       if ( type == NUMERICOID && n_tuples > 0)
461         {
462           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
463           struct fmt_spec fmt;
464           int16_t n_digits, weight, dscale;
465           uint16_t sign;
466
467           GET_VALUE (&vptr, n_digits);
468           GET_VALUE (&vptr, weight);
469           GET_VALUE (&vptr, sign);
470           GET_VALUE (&vptr, dscale);
471
472           fmt.d = dscale;
473           fmt.type = FMT_E;
474           fmt.w = fmt_max_output_width (fmt.type) ;
475           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
476           var_set_both_formats (var, &fmt);
477         }
478
479       /* Timezones need an extra variable */
480       switch (type)
481         {
482         case TIMETZOID:
483           {
484             struct string name;
485             ds_init_cstr (&name, var_get_name (var));
486             ds_put_cstr (&name, "-zone");
487             fmt.type = FMT_F;
488             fmt.w = 8;
489             fmt.d = 2;
490
491             create_var (r, &fmt, 0, ds_cstr (&name), -1);
492
493             ds_destroy (&name);
494           }
495           break;
496
497         case INTERVALOID:
498           {
499             struct string name;
500             ds_init_cstr (&name, var_get_name (var));
501             ds_put_cstr (&name, "-months");
502             fmt.type = FMT_F;
503             fmt.w = 3;
504             fmt.d = 0;
505
506             create_var (r, &fmt, 0, ds_cstr (&name), -1);
507
508             ds_destroy (&name);
509           }
510         default:
511           break;
512         }
513     }
514
515   PQclear (qres);
516
517   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
518   if ( PQresultStatus (qres) != PGRES_COMMAND_OK)
519     {
520       PQclear (qres);
521       goto error;
522     }
523   PQclear (qres);
524
525   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
526
527   ds_init_empty (&r->fetch_cmd);
528   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
529
530   reload_cache (r);
531
532   return casereader_create_sequential
533     (NULL,
534      r->value_cnt,
535      n_cases,
536      &psql_casereader_class, r);
537
538  error:
539   dict_destroy (*dict);
540
541   psql_casereader_destroy (NULL, r);
542   return NULL;
543 }
544
545
546 static void
547 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
548 {
549   struct psql_reader *r = r_;
550   if (r == NULL)
551     return ;
552
553   ds_destroy (&r->fetch_cmd);
554   free (r->vmap);
555   if (r->res) PQclear (r->res);
556   PQfinish (r->conn);
557
558   free (r);
559 }
560
561
562
563 static struct ccase *
564 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
565 {
566   struct psql_reader *r = r_;
567
568   if ( NULL == r->res || r->tuple >= r->cache_size)
569     {
570       if ( ! reload_cache (r) )
571         return false;
572     }
573
574   return set_value (r);
575 }
576
577 static struct ccase *
578 set_value (struct psql_reader *r)
579 {
580   struct ccase *c;
581   int n_vars;
582   int i;
583
584   assert (r->res);
585
586   n_vars = PQnfields (r->res);
587
588   if ( r->tuple >= PQntuples (r->res))
589     return NULL;
590
591   c = case_create (r->value_cnt);
592   memset (case_data_rw_idx (c, 0)->s, ' ', MAX_SHORT_STRING * r->value_cnt);
593
594
595   for (i = 0 ; i < n_vars ; ++i )
596     {
597       Oid type = PQftype (r->res, i);
598       const struct variable *v = r->vmap[i];
599       union value *val = case_data_rw (c, v);
600
601       union value *val1 = NULL;
602
603       switch (type)
604         {
605         case INTERVALOID:
606         case TIMESTAMPTZOID:
607         case TIMETZOID:
608           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_var_cnt (r->dict))
609             {
610               const struct variable *v1 = NULL;
611               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
612
613               val1 = case_data_rw (c, v1);
614             }
615           break;
616         default:
617           break;
618         }
619
620
621       if (PQgetisnull (r->res, r->tuple, i))
622         {
623           value_set_missing (val, var_get_width (v));
624
625           switch (type)
626             {
627             case INTERVALOID:
628             case TIMESTAMPTZOID:
629             case TIMETZOID:
630               val1->f = SYSMIS;
631               break;
632             default:
633               break;
634             }
635         }
636       else
637         {
638           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
639           int length = PQgetlength (r->res, r->tuple, i);
640
641           int var_width = var_get_width (v);
642           switch (type)
643             {
644             case BOOLOID:
645               {
646                 int8_t x;
647                 GET_VALUE (&vptr, x);
648                 val->f = x;
649               }
650               break;
651
652             case OIDOID:
653             case INT2OID:
654               {
655                 int16_t x;
656                 GET_VALUE (&vptr, x);
657                 val->f = x;
658               }
659               break;
660
661             case INT4OID:
662               {
663                 int32_t x;
664                 GET_VALUE (&vptr, x);
665                 val->f = x;
666               }
667               break;
668
669             case INT8OID:
670               {
671                 int64_t x;
672                 GET_VALUE (&vptr, x);
673                 val->f = x;
674               }
675               break;
676
677             case FLOAT4OID:
678               {
679                 float n;
680                 GET_VALUE (&vptr, n);
681                 val->f = n;
682               }
683               break;
684
685             case FLOAT8OID:
686               {
687                 double n;
688                 GET_VALUE (&vptr, n);
689                 val->f = n;
690               }
691               break;
692
693             case CASHOID:
694               {
695                 /* Postgres 8.3 uses 64 bits.
696                    Earlier versions use 32 */
697                 switch (length)
698                   {
699                   case 8:
700                     {
701                       int64_t x;
702                       GET_VALUE (&vptr, x);
703                       val->f = x / 100.0;
704                     }
705                     break;
706                   case 4:
707                     {
708                       int32_t x;
709                       GET_VALUE (&vptr, x);
710                       val->f = x / 100.0;
711                     }
712                     break;
713                   default:
714                     val->f = SYSMIS;
715                     break;
716                   }
717               }
718               break;
719
720             case INTERVALOID:
721               {
722                 if ( r->integer_datetimes )
723                   {
724                     uint32_t months;
725                     uint32_t days;
726                     uint32_t us;
727                     uint32_t things;
728
729                     GET_VALUE (&vptr, things);
730                     GET_VALUE (&vptr, us);
731                     GET_VALUE (&vptr, days);
732                     GET_VALUE (&vptr, months);
733
734                     val->f = us / 1000000.0;
735                     val->f += days * 24 * 3600;
736
737                     val1->f = months;
738                   }
739                 else
740                   {
741                     uint32_t days, months;
742                     double seconds;
743
744                     GET_VALUE (&vptr, seconds);
745                     GET_VALUE (&vptr, days);
746                     GET_VALUE (&vptr, months);
747
748                     val->f = seconds;
749                     val->f += days * 24 * 3600;
750
751                     val1->f = months;
752                   }
753               }
754               break;
755
756             case DATEOID:
757               {
758                 int32_t x;
759
760                 GET_VALUE (&vptr, x);
761
762                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
763               }
764               break;
765
766             case TIMEOID:
767               {
768                 if ( r->integer_datetimes)
769                   {
770                     uint64_t x;
771                     GET_VALUE (&vptr, x);
772                     val->f = x / 1000000.0;
773                   }
774                 else
775                   {
776                     double x;
777                     GET_VALUE (&vptr, x);
778                     val->f = x;
779                   }
780               }
781               break;
782
783             case TIMETZOID:
784               {
785                 int32_t zone;
786                 if ( r->integer_datetimes)
787                   {
788                     uint64_t x;
789
790
791                     GET_VALUE (&vptr, x);
792                     val->f = x / 1000000.0;
793                   }
794                 else
795                   {
796                     double x;
797
798                     GET_VALUE (&vptr, x);
799                     val->f = x ;
800                   }
801
802                 GET_VALUE (&vptr, zone);
803                 val1->f = zone / 3600.0;
804               }
805               break;
806
807             case TIMESTAMPOID:
808             case TIMESTAMPTZOID:
809               {
810                 if ( r->integer_datetimes)
811                   {
812                     int64_t x;
813
814                     GET_VALUE (&vptr, x);
815
816                     x /= 1000000;
817
818                     val->f = (x + r->postgres_epoch * 24 * 3600 );
819                   }
820                 else
821                   {
822                     double x;
823
824                     GET_VALUE (&vptr, x);
825
826                     val->f = (x + r->postgres_epoch * 24 * 3600 );
827                   }
828               }
829               break;
830             case TEXTOID:
831             case VARCHAROID:
832             case BPCHAROID:
833             case BYTEAOID:
834               memcpy (val->s, (char *) vptr, MIN (length, var_width));
835               break;
836
837             case NUMERICOID:
838               {
839                 double f = 0.0;
840                 int i;
841                 int16_t n_digits, weight, dscale;
842                 uint16_t sign;
843
844                 GET_VALUE (&vptr, n_digits);
845                 GET_VALUE (&vptr, weight);
846                 GET_VALUE (&vptr, sign);
847                 GET_VALUE (&vptr, dscale);
848
849 #if 0
850                 {
851                   struct fmt_spec fmt;
852                   fmt.d = dscale;
853                   fmt.type = FMT_E;
854                   fmt.w = fmt_max_output_width (fmt.type) ;
855                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
856                   var_set_both_formats (v, &fmt);
857                 }
858 #endif
859
860                 for (i = 0 ; i < n_digits;  ++i)
861                   {
862                     uint16_t x;
863                     GET_VALUE (&vptr, x);
864                     f += x * pow (10000, weight--);
865                   }
866
867                 if ( sign == 0x4000)
868                   f *= -1.0;
869
870                 if ( sign == 0xC000)
871                   val->f = SYSMIS;
872                 else
873                   val->f = f;
874               }
875               break;
876
877             default:
878               val->f = SYSMIS;
879               break;
880             }
881         }
882     }
883
884   r->tuple++;
885
886   return c;
887 }
888
889 #endif