378a56affbb45f66fd04a889fcc8c53d5fc66941
[pspp-builds.git] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/casereader-provider.h>
20 #include <libpspp/message.h>
21 #include <gl/xalloc.h>
22 #include <data/dictionary.h>
23 #include <math.h>
24 #include <stdlib.h>
25
26 #include "psql-reader.h"
27 #include "variable.h"
28 #include "format.h"
29 #include "calendar.h"
30
31 #include <inttypes.h>
32 #include <libpspp/misc.h>
33 #include <libpspp/str.h>
34
35 #include "minmax.h"
36
37 #include "gettext.h"
38 #define _(msgid) gettext (msgid)
39 #define N_(msgid) (msgid)
40
41
42 #if !PSQL_SUPPORT
43 struct casereader *
44 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
45 {
46   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
47
48   return NULL;
49 }
50
51 #else
52
53 #include <stdint.h>
54 #include <libpq-fe.h>
55
56
57 /* Default width of string variables. */
58 #define PSQL_DEFAULT_WIDTH 8
59
60 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
61 #define BOOLOID            16
62 #define BYTEAOID           17
63 #define CHAROID            18
64 #define NAMEOID            19
65 #define INT8OID            20
66 #define INT2OID            21
67 #define INT4OID            23
68 #define TEXTOID            25
69 #define OIDOID             26
70 #define FLOAT4OID          700
71 #define FLOAT8OID          701
72 #define CASHOID            790
73 #define BPCHAROID          1042
74 #define VARCHAROID         1043
75 #define DATEOID            1082
76 #define TIMEOID            1083
77 #define TIMESTAMPOID       1114
78 #define TIMESTAMPTZOID     1184
79 #define INTERVALOID        1186
80 #define TIMETZOID          1266
81 #define NUMERICOID         1700
82
83 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
84
85 static struct ccase *psql_casereader_read (struct casereader *, void *);
86
87 static const struct casereader_class psql_casereader_class =
88   {
89     psql_casereader_read,
90     psql_casereader_destroy,
91     NULL,
92     NULL,
93   };
94
95 struct psql_reader
96 {
97   PGconn *conn;
98   PGresult *res;
99   int tuple;
100
101   bool integer_datetimes;
102
103   double postgres_epoch;
104
105   struct caseproto *proto;
106   struct dictionary *dict;
107
108   /* An array of ints, which maps psql column numbers into
109      pspp variables */
110   struct variable **vmap;
111   size_t vmapsize;
112
113   struct string fetch_cmd;
114   int cache_size;
115 };
116
117
118 static struct ccase *set_value (struct psql_reader *r);
119
120
121
122 #if WORDS_BIGENDIAN
123 static void
124 data_to_native (const void *in_, void *out_, int len)
125 {
126   int i;
127   const unsigned char *in = in_;
128   unsigned char *out = out_;
129   for (i = 0 ; i < len ; ++i )
130     out[i] = in[i];
131 }
132 #else
133 static void
134 data_to_native (const void *in_, void *out_, int len)
135 {
136   int i;
137   const unsigned char *in = in_;
138   unsigned char *out = out_;
139   for (i = 0 ; i < len ; ++i )
140     out[len - i - 1] = in[i];
141 }
142 #endif
143
144
145 #define GET_VALUE(IN, OUT) do { \
146     size_t sz = sizeof (OUT); \
147     data_to_native (*(IN), &(OUT), sz) ; \
148     (*IN) += sz; \
149 } while (false)
150
151
152 #if 0
153 static void
154 dump (const unsigned char *x, int l)
155 {
156   int i;
157
158   for (i = 0; i < l ; ++i)
159     {
160       printf ("%02x ", x[i]);
161     }
162
163   putchar ('\n');
164
165   for (i = 0; i < l ; ++i)
166     {
167       if ( isprint (x[i]))
168         printf ("%c ", x[i]);
169       else
170         printf ("   ");
171     }
172
173   putchar ('\n');
174 }
175 #endif
176
177 static struct variable *
178 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
179             int width, const char *suggested_name, int col)
180 {
181   unsigned long int vx = 0;
182   struct variable *var;
183   char *name;
184
185   name = dict_make_unique_var_name (r->dict, suggested_name, &vx);
186   var = dict_create_var (r->dict, name, width);
187   free (name);
188
189   var_set_both_formats (var, fmt);
190
191   if ( col != -1)
192     {
193       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
194
195       r->vmap[col] = var;
196       r->vmapsize = col + 1;
197     }
198
199   return var;
200 }
201
202
203
204
205 /* Fill the cache */
206 static bool
207 reload_cache (struct psql_reader *r)
208 {
209   PQclear (r->res);
210   r->tuple = 0;
211
212   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
213
214   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
215     {
216       PQclear (r->res);
217       r->res = NULL;
218       return false;
219     }
220
221   return true;
222 }
223
224
225 struct casereader *
226 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
227 {
228   int i;
229   int n_fields, n_tuples;
230   PGresult *qres = NULL;
231   casenumber n_cases = CASENUMBER_MAX;
232
233   struct psql_reader *r = xzalloc (sizeof *r);
234   struct string query ;
235
236   r->conn = PQconnectdb (info->conninfo);
237   if ( NULL == r->conn)
238     {
239       msg (ME, _("Memory error whilst opening psql source"));
240       goto error;
241     }
242
243   if ( PQstatus (r->conn) != CONNECTION_OK )
244     {
245       msg (ME, _("Error opening psql source: %s."),
246            PQerrorMessage (r->conn));
247
248       goto error;
249     }
250
251   {
252     int ver_num;
253     const char *vers = PQparameterStatus (r->conn, "server_version");
254
255     sscanf (vers, "%d", &ver_num);
256
257     if ( ver_num < 8)
258       {
259         msg (ME,
260              _("Postgres server is version %s."
261                " Reading from versions earlier than 8.0 is not supported."),
262              vers);
263
264         goto error;
265       }
266   }
267
268   {
269     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
270
271     r->integer_datetimes = ( 0 == strcasecmp (dt, "on"));
272   }
273
274 #if USE_SSL
275   if ( PQgetssl (r->conn) == NULL)
276 #endif
277     {
278       if (! info->allow_clear)
279         {
280           msg (ME, _("Connection is unencrypted, "
281                      "but unencrypted connections have not been permitted."));
282           goto error;
283         }
284     }
285
286   r->postgres_epoch = calendar_gregorian_to_offset (2000, 1, 1, NULL);
287
288
289   /* Create the dictionary and populate it */
290   *dict = r->dict = dict_create ();
291
292   {
293     const int enc = PQclientEncoding (r->conn);
294
295     /* According to section 22.2 of the Postgresql manual
296        a value of zero (SQL_ASCII) indicates
297        "a declaration of ignorance about the encoding".
298        Accordingly, we don't set the dictionary's encoding
299        if we find this value.
300     */
301     if ( enc != 0 )
302       dict_set_encoding (r->dict, pg_encoding_to_char (enc));
303   }
304
305   /*
306     select count (*) from (select * from medium) stupid_sql_standard;
307   */
308   ds_init_cstr (&query,
309                 "BEGIN READ ONLY ISOLATION LEVEL SERIALIZABLE; "
310                 "DECLARE  pspp BINARY CURSOR FOR ");
311
312   ds_put_substring (&query, info->sql.ss);
313
314   qres = PQexec (r->conn, ds_cstr (&query));
315   ds_destroy (&query);
316   if ( PQresultStatus (qres) != PGRES_COMMAND_OK )
317     {
318       msg (ME, _("Error from psql source: %s."),
319            PQresultErrorMessage (qres));
320       goto error;
321     }
322
323   PQclear (qres);
324
325
326   /* Now use the count() function to find the total number of cases
327      that this query returns.
328      Doing this incurs some overhead.  The server has to iterate every
329      case in order to find this number.  However, it's performed on the
330      server side, and in all except the most huge databases the extra
331      overhead will be worth the effort.
332      On the other hand, most PSPP functions don't need to know this.
333      The GUI is the notable exception.
334   */
335   ds_init_cstr (&query, "SELECT count (*) FROM (");
336   ds_put_substring (&query, info->sql.ss);
337   ds_put_cstr (&query, ") stupid_sql_standard");
338
339   qres = PQexec (r->conn, ds_cstr (&query));
340   ds_destroy (&query);
341   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
342     {
343       msg (ME, _("Error from psql source: %s."),
344            PQresultErrorMessage (qres));
345       goto error;
346     }
347   n_cases = atol (PQgetvalue (qres, 0, 0));
348   PQclear (qres);
349
350   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
351   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
352     {
353       msg (ME, _("Error from psql source: %s."),
354            PQresultErrorMessage (qres));
355       goto error;
356     }
357
358   n_tuples = PQntuples (qres);
359   n_fields = PQnfields (qres);
360
361   r->proto = NULL;
362   r->vmap = NULL;
363   r->vmapsize = 0;
364
365   for (i = 0 ; i < n_fields ; ++i )
366     {
367       struct variable *var;
368       struct fmt_spec fmt = {FMT_F, 8, 2};
369       Oid type = PQftype (qres, i);
370       int width = 0;
371       int length ;
372
373       /* If there are no data then make a finger in the air 
374          guess at the contents */
375       if ( n_tuples > 0 )
376         length = PQgetlength (qres, 0, i);
377       else 
378         length = PSQL_DEFAULT_WIDTH;
379
380       switch (type)
381         {
382         case BOOLOID:
383         case OIDOID:
384         case INT2OID:
385         case INT4OID:
386         case INT8OID:
387         case FLOAT4OID:
388         case FLOAT8OID:
389           fmt.type = FMT_F;
390           break;
391         case CASHOID:
392           fmt.type = FMT_DOLLAR;
393           break;
394         case CHAROID:
395           fmt.type = FMT_A;
396           width = length > 0 ? length : 1;
397           fmt.d = 0;
398           fmt.w = 1;
399           break;
400         case TEXTOID:
401         case VARCHAROID:
402         case BPCHAROID:
403           fmt.type = FMT_A;
404           width = (info->str_width == -1) ?
405             ROUND_UP (length, PSQL_DEFAULT_WIDTH) : info->str_width;
406           fmt.w = width;
407           fmt.d = 0;
408           break;
409         case BYTEAOID:
410           fmt.type = FMT_AHEX;
411           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
412           fmt.w = width * 2;
413           fmt.d = 0;
414           break;
415         case INTERVALOID:
416           fmt.type = FMT_DTIME;
417           width = 0;
418           fmt.d = 0;
419           fmt.w = 13;
420           break;
421         case DATEOID:
422           fmt.type = FMT_DATE;
423           width = 0;
424           fmt.w = 11;
425           fmt.d = 0;
426           break;
427         case TIMEOID:
428         case TIMETZOID:
429           fmt.type = FMT_TIME;
430           width = 0;
431           fmt.w = 11;
432           fmt.d = 0;
433           break;
434         case TIMESTAMPOID:
435         case TIMESTAMPTZOID:
436           fmt.type = FMT_DATETIME;
437           fmt.d = 0;
438           fmt.w = 22;
439           width = 0;
440           break;
441         case NUMERICOID:
442           fmt.type = FMT_E;
443           fmt.d = 2;
444           fmt.w = 40;
445           width = 0;
446           break;
447         default:
448           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
449           fmt.type = FMT_A;
450           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
451           fmt.w = width ;
452           fmt.d = 0;
453           break;
454         }
455
456       if ( width == 0 && fmt_is_string (fmt.type))
457         fmt.w = width = PSQL_DEFAULT_WIDTH;
458
459
460       var = create_var (r, &fmt, width, PQfname (qres, i), i);
461       if ( type == NUMERICOID && n_tuples > 0)
462         {
463           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
464           struct fmt_spec fmt;
465           int16_t n_digits, weight, dscale;
466           uint16_t sign;
467
468           GET_VALUE (&vptr, n_digits);
469           GET_VALUE (&vptr, weight);
470           GET_VALUE (&vptr, sign);
471           GET_VALUE (&vptr, dscale);
472
473           fmt.d = dscale;
474           fmt.type = FMT_E;
475           fmt.w = fmt_max_output_width (fmt.type) ;
476           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
477           var_set_both_formats (var, &fmt);
478         }
479
480       /* Timezones need an extra variable */
481       switch (type)
482         {
483         case TIMETZOID:
484           {
485             struct string name;
486             ds_init_cstr (&name, var_get_name (var));
487             ds_put_cstr (&name, "-zone");
488             fmt.type = FMT_F;
489             fmt.w = 8;
490             fmt.d = 2;
491
492             create_var (r, &fmt, 0, ds_cstr (&name), -1);
493
494             ds_destroy (&name);
495           }
496           break;
497
498         case INTERVALOID:
499           {
500             struct string name;
501             ds_init_cstr (&name, var_get_name (var));
502             ds_put_cstr (&name, "-months");
503             fmt.type = FMT_F;
504             fmt.w = 3;
505             fmt.d = 0;
506
507             create_var (r, &fmt, 0, ds_cstr (&name), -1);
508
509             ds_destroy (&name);
510           }
511         default:
512           break;
513         }
514     }
515
516   PQclear (qres);
517
518   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
519   if ( PQresultStatus (qres) != PGRES_COMMAND_OK)
520     {
521       PQclear (qres);
522       goto error;
523     }
524   PQclear (qres);
525
526   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
527
528   ds_init_empty (&r->fetch_cmd);
529   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
530
531   reload_cache (r);
532   r->proto = caseproto_ref (dict_get_proto (*dict));
533
534   return casereader_create_sequential
535     (NULL,
536      r->proto,
537      n_cases,
538      &psql_casereader_class, r);
539
540  error:
541   dict_destroy (*dict);
542
543   psql_casereader_destroy (NULL, r);
544   return NULL;
545 }
546
547
548 static void
549 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
550 {
551   struct psql_reader *r = r_;
552   if (r == NULL)
553     return ;
554
555   ds_destroy (&r->fetch_cmd);
556   free (r->vmap);
557   if (r->res) PQclear (r->res);
558   PQfinish (r->conn);
559   caseproto_unref (r->proto);
560
561   free (r);
562 }
563
564
565
566 static struct ccase *
567 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
568 {
569   struct psql_reader *r = r_;
570
571   if ( NULL == r->res || r->tuple >= r->cache_size)
572     {
573       if ( ! reload_cache (r) )
574         return false;
575     }
576
577   return set_value (r);
578 }
579
580 static struct ccase *
581 set_value (struct psql_reader *r)
582 {
583   struct ccase *c;
584   int n_vars;
585   int i;
586
587   assert (r->res);
588
589   n_vars = PQnfields (r->res);
590
591   if ( r->tuple >= PQntuples (r->res))
592     return NULL;
593
594   c = case_create (r->proto);
595   case_set_missing (c);
596
597
598   for (i = 0 ; i < n_vars ; ++i )
599     {
600       Oid type = PQftype (r->res, i);
601       const struct variable *v = r->vmap[i];
602       union value *val = case_data_rw (c, v);
603
604       union value *val1 = NULL;
605
606       switch (type)
607         {
608         case INTERVALOID:
609         case TIMESTAMPTZOID:
610         case TIMETZOID:
611           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_var_cnt (r->dict))
612             {
613               const struct variable *v1 = NULL;
614               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
615
616               val1 = case_data_rw (c, v1);
617             }
618           break;
619         default:
620           break;
621         }
622
623
624       if (PQgetisnull (r->res, r->tuple, i))
625         {
626           value_set_missing (val, var_get_width (v));
627
628           switch (type)
629             {
630             case INTERVALOID:
631             case TIMESTAMPTZOID:
632             case TIMETZOID:
633               val1->f = SYSMIS;
634               break;
635             default:
636               break;
637             }
638         }
639       else
640         {
641           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
642           int length = PQgetlength (r->res, r->tuple, i);
643
644           int var_width = var_get_width (v);
645           switch (type)
646             {
647             case BOOLOID:
648               {
649                 int8_t x;
650                 GET_VALUE (&vptr, x);
651                 val->f = x;
652               }
653               break;
654
655             case OIDOID:
656             case INT2OID:
657               {
658                 int16_t x;
659                 GET_VALUE (&vptr, x);
660                 val->f = x;
661               }
662               break;
663
664             case INT4OID:
665               {
666                 int32_t x;
667                 GET_VALUE (&vptr, x);
668                 val->f = x;
669               }
670               break;
671
672             case INT8OID:
673               {
674                 int64_t x;
675                 GET_VALUE (&vptr, x);
676                 val->f = x;
677               }
678               break;
679
680             case FLOAT4OID:
681               {
682                 float n;
683                 GET_VALUE (&vptr, n);
684                 val->f = n;
685               }
686               break;
687
688             case FLOAT8OID:
689               {
690                 double n;
691                 GET_VALUE (&vptr, n);
692                 val->f = n;
693               }
694               break;
695
696             case CASHOID:
697               {
698                 /* Postgres 8.3 uses 64 bits.
699                    Earlier versions use 32 */
700                 switch (length)
701                   {
702                   case 8:
703                     {
704                       int64_t x;
705                       GET_VALUE (&vptr, x);
706                       val->f = x / 100.0;
707                     }
708                     break;
709                   case 4:
710                     {
711                       int32_t x;
712                       GET_VALUE (&vptr, x);
713                       val->f = x / 100.0;
714                     }
715                     break;
716                   default:
717                     val->f = SYSMIS;
718                     break;
719                   }
720               }
721               break;
722
723             case INTERVALOID:
724               {
725                 if ( r->integer_datetimes )
726                   {
727                     uint32_t months;
728                     uint32_t days;
729                     uint32_t us;
730                     uint32_t things;
731
732                     GET_VALUE (&vptr, things);
733                     GET_VALUE (&vptr, us);
734                     GET_VALUE (&vptr, days);
735                     GET_VALUE (&vptr, months);
736
737                     val->f = us / 1000000.0;
738                     val->f += days * 24 * 3600;
739
740                     val1->f = months;
741                   }
742                 else
743                   {
744                     uint32_t days, months;
745                     double seconds;
746
747                     GET_VALUE (&vptr, seconds);
748                     GET_VALUE (&vptr, days);
749                     GET_VALUE (&vptr, months);
750
751                     val->f = seconds;
752                     val->f += days * 24 * 3600;
753
754                     val1->f = months;
755                   }
756               }
757               break;
758
759             case DATEOID:
760               {
761                 int32_t x;
762
763                 GET_VALUE (&vptr, x);
764
765                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
766               }
767               break;
768
769             case TIMEOID:
770               {
771                 if ( r->integer_datetimes)
772                   {
773                     uint64_t x;
774                     GET_VALUE (&vptr, x);
775                     val->f = x / 1000000.0;
776                   }
777                 else
778                   {
779                     double x;
780                     GET_VALUE (&vptr, x);
781                     val->f = x;
782                   }
783               }
784               break;
785
786             case TIMETZOID:
787               {
788                 int32_t zone;
789                 if ( r->integer_datetimes)
790                   {
791                     uint64_t x;
792
793
794                     GET_VALUE (&vptr, x);
795                     val->f = x / 1000000.0;
796                   }
797                 else
798                   {
799                     double x;
800
801                     GET_VALUE (&vptr, x);
802                     val->f = x ;
803                   }
804
805                 GET_VALUE (&vptr, zone);
806                 val1->f = zone / 3600.0;
807               }
808               break;
809
810             case TIMESTAMPOID:
811             case TIMESTAMPTZOID:
812               {
813                 if ( r->integer_datetimes)
814                   {
815                     int64_t x;
816
817                     GET_VALUE (&vptr, x);
818
819                     x /= 1000000;
820
821                     val->f = (x + r->postgres_epoch * 24 * 3600 );
822                   }
823                 else
824                   {
825                     double x;
826
827                     GET_VALUE (&vptr, x);
828
829                     val->f = (x + r->postgres_epoch * 24 * 3600 );
830                   }
831               }
832               break;
833             case TEXTOID:
834             case VARCHAROID:
835             case BPCHAROID:
836             case BYTEAOID:
837               memcpy (value_str_rw (val, var_width), vptr,
838                       MIN (length, var_width));
839               break;
840
841             case NUMERICOID:
842               {
843                 double f = 0.0;
844                 int i;
845                 int16_t n_digits, weight, dscale;
846                 uint16_t sign;
847
848                 GET_VALUE (&vptr, n_digits);
849                 GET_VALUE (&vptr, weight);
850                 GET_VALUE (&vptr, sign);
851                 GET_VALUE (&vptr, dscale);
852
853 #if 0
854                 {
855                   struct fmt_spec fmt;
856                   fmt.d = dscale;
857                   fmt.type = FMT_E;
858                   fmt.w = fmt_max_output_width (fmt.type) ;
859                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
860                   var_set_both_formats (v, &fmt);
861                 }
862 #endif
863
864                 for (i = 0 ; i < n_digits;  ++i)
865                   {
866                     uint16_t x;
867                     GET_VALUE (&vptr, x);
868                     f += x * pow (10000, weight--);
869                   }
870
871                 if ( sign == 0x4000)
872                   f *= -1.0;
873
874                 if ( sign == 0xC000)
875                   val->f = SYSMIS;
876                 else
877                   val->f = f;
878               }
879               break;
880
881             default:
882               val->f = SYSMIS;
883               break;
884             }
885         }
886     }
887
888   r->tuple++;
889
890   return c;
891 }
892
893 #endif