20aa57dbfc17f28a8c342e763e137c4580d35fcb
[pspp-builds.git] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/casereader-provider.h>
20 #include <libpspp/message.h>
21 #include <gl/xalloc.h>
22 #include <data/dictionary.h>
23 #include <stdlib.h>
24
25 #include "psql-reader.h"
26 #include "variable.h"
27 #include "format.h"
28 #include "calendar.h"
29
30 #include <inttypes.h>
31
32 #include "gettext.h"
33 #define _(msgid) gettext (msgid)
34 #define N_(msgid) (msgid)
35
36
37 #if !PSQL_SUPPORT
38 struct casereader *
39 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
40 {
41   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
42
43   return NULL;
44 }
45
46 #else
47
48 #include <stdint.h>
49 #include <libpq-fe.h>
50
51
52 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
53 #define BOOLOID            16
54 #define BYTEAOID           17
55 #define CHAROID            18
56 #define NAMEOID            19
57 #define INT8OID            20
58 #define INT2OID            21
59 #define INT4OID            23
60 #define TEXTOID            25
61 #define OIDOID             26
62 #define FLOAT4OID          700
63 #define FLOAT8OID          701
64 #define CASHOID            790
65 #define BPCHAROID          1042
66 #define VARCHAROID         1043
67 #define DATEOID            1082
68 #define TIMEOID            1083
69 #define TIMESTAMPOID       1114
70 #define TIMESTAMPTZOID     1184
71 #define INTERVALOID        1186
72 #define TIMETZOID          1266
73 #define NUMERICOID         1700
74
75 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
76
77 static bool psql_casereader_read (struct casereader *, void *,
78                                   struct ccase *);
79
80 static const struct casereader_class psql_casereader_class =
81   {
82     psql_casereader_read,
83     psql_casereader_destroy,
84     NULL,
85     NULL,
86   };
87
88 struct psql_reader
89 {
90   PGconn *conn;
91
92   bool integer_datetimes;
93
94   double postgres_epoch;
95
96   size_t value_cnt;
97   struct dictionary *dict;
98
99   bool used_first_case;
100   struct ccase first_case;
101
102   /* An array of ints, which maps psql column numbers into
103      pspp variable numbers */
104   int *vmap;
105   size_t vmapsize;
106 };
107
108
109 static void set_value (const struct psql_reader *r,
110                        PGresult *res, struct ccase *c);
111
112
113
114 #if WORDS_BIGENDIAN
115 static void
116 data_to_native (const void *in_, void *out_, int len)
117 {
118   int i;
119   const unsigned char *in = in_;
120   unsigned char *out = out_;
121   for (i = 0 ; i < len ; ++i )
122     out[i] = in[i];
123 }
124 #else
125 static void
126 data_to_native (const void *in_, void *out_, int len)
127 {
128   int i;
129   const unsigned char *in = in_;
130   unsigned char *out = out_;
131   for (i = 0 ; i < len ; ++i )
132     out[len - i - 1] = in[i];
133 }
134 #endif
135
136
137 #define GET_VALUE(IN, OUT) do { \
138     size_t sz = sizeof (OUT); \
139     data_to_native (*(IN), &(OUT), sz) ; \
140     (*IN) += sz; \
141 } while (false)
142
143
144 #if 0
145 static void
146 dump (const unsigned char *x, int l)
147 {
148   int i;
149
150   for (i = 0; i < l ; ++i)
151     {
152       printf ("%02x ", x[i]);
153     }
154
155   putchar ('\n');
156
157   for (i = 0; i < l ; ++i)
158     {
159       if ( isprint (x[i]))
160         printf ("%c ", x[i]);
161       else
162         printf ("   ");
163     }
164
165   putchar ('\n');
166 }
167 #endif
168
169 static struct variable *
170 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
171             int width, const char *suggested_name, int col)
172 {
173   unsigned long int vx = 0;
174   int vidx;
175   struct variable *var;
176   char name[VAR_NAME_LEN + 1];
177
178   r->value_cnt += value_cnt_from_width (width);
179
180   if ( ! dict_make_unique_var_name (r->dict, suggested_name, &vx, name))
181     {
182       msg (ME, _("Cannot create variable name from %s"), suggested_name);
183       return NULL;
184     }
185
186   var = dict_create_var (r->dict, name, width);
187   var_set_both_formats (var, fmt);
188
189   vidx = var_get_dict_index (var);
190
191   if ( col != -1)
192     {
193       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (int));
194
195       r->vmap[col] = vidx;
196       r->vmapsize = col + 1;
197     }
198
199   return var;
200 }
201
202 struct casereader *
203 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
204 {
205   int i;
206   int n_fields;
207   PGresult *res = NULL;
208
209   struct psql_reader *r = xzalloc (sizeof *r);
210   struct string query ;
211
212
213   r->conn = PQconnectdb (info->conninfo);
214   if ( NULL == r->conn)
215     {
216       msg (ME, _("Memory error whilst opening psql source"));
217       goto error;
218     }
219
220   if ( PQstatus (r->conn) != CONNECTION_OK )
221     {
222       msg (ME, _("Error opening psql source: %s."),
223            PQerrorMessage (r->conn));
224
225       goto error;
226     }
227
228   {
229     int v1;
230     const char *vers = PQparameterStatus (r->conn, "server_version");
231
232     sscanf (vers, "%d", &v1);
233
234     if ( v1 < 8)
235       {
236         msg (ME,
237              _("Postgres server is version %s."
238                " Reading from versions earlier than 8.0 is not supported."),
239              vers);
240
241         goto error;
242       }
243   }
244
245   {
246     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
247
248     r->integer_datetimes = ( 0 == strcasecmp (dt, "on"));
249   }
250
251 #if USE_SSL
252   if ( PQgetssl (r->conn) == NULL)
253 #endif
254     {
255       if (! info->allow_clear)
256         {
257           msg (ME, _("Connection is unencrypted, "
258                      "but unencrypted connections have not been permitted."));
259           goto error;
260         }
261     }
262
263   r->postgres_epoch =
264     calendar_gregorian_to_offset (2000, 1, 1, NULL, NULL);
265
266
267   /* Create the dictionary and populate it */
268   *dict = r->dict = dict_create ();
269
270   ds_init_cstr (&query, "BEGIN READ ONLY ISOLATION LEVEL SERIALIZABLE; DECLARE  pspp BINARY CURSOR FOR ");
271   ds_put_substring (&query, info->sql.ss);
272
273   res = PQexec (r->conn, ds_cstr (&query));
274   ds_destroy (&query);
275   if ( PQresultStatus (res) != PGRES_COMMAND_OK )
276     {
277       msg (ME, _("Error from psql source: %s."),
278            PQresultErrorMessage (res));
279       goto error;
280     }
281
282   PQclear (res);
283
284   res = PQexec (r->conn, "FETCH FIRST FROM pspp");
285   if ( PQresultStatus (res) != PGRES_TUPLES_OK )
286     {
287       msg (ME, _("Error from psql source: %s."),
288            PQresultErrorMessage (res));
289       goto error;
290     }
291
292   n_fields = PQnfields (res);
293
294   r->value_cnt = 0;
295   r->vmap = NULL;
296   r->vmapsize = 0;
297
298   for (i = 0 ; i < n_fields ; ++i )
299     {
300       struct variable *var;
301       struct fmt_spec fmt = {FMT_F, 8, 2};
302       Oid type = PQftype (res, i);
303       int width = 0;
304       int length = PQgetlength (res, 0, i);
305
306       switch (type)
307         {
308         case BOOLOID:
309         case OIDOID:
310         case INT2OID:
311         case INT4OID:
312         case INT8OID:
313         case FLOAT4OID:
314         case FLOAT8OID:
315           fmt.type = FMT_F;
316           break;
317         case CASHOID:
318           fmt.type = FMT_DOLLAR;
319           break;
320         case CHAROID:
321           fmt.type = FMT_A;
322           width = length > 0 ? length : 1;
323           fmt.d = 0;
324           fmt.w = 1;
325           break;
326         case TEXTOID:
327         case VARCHAROID:
328         case BPCHAROID:
329           fmt.type = FMT_A;
330           width = (info->str_width == -1) ?
331             ROUND_UP (length, MAX_SHORT_STRING) : info->str_width;
332           fmt.w = width;
333           fmt.d = 0;
334           break;
335         case BYTEAOID:
336           fmt.type = FMT_AHEX;
337           width = length > 0 ? length : MAX_SHORT_STRING;
338           fmt.w = width * 2;
339           fmt.d = 0;
340           break;
341         case INTERVALOID:
342           fmt.type = FMT_DTIME;
343           width = 0;
344           fmt.d = 0;
345           fmt.w = 13;
346           break;
347         case DATEOID:
348           fmt.type = FMT_DATE;
349           width = 0;
350           fmt.w = 11;
351           fmt.d = 0;
352           break;
353         case TIMEOID:
354         case TIMETZOID:
355           fmt.type = FMT_TIME;
356           width = 0;
357           fmt.w = 11;
358           fmt.d = 0;
359           break;
360         case TIMESTAMPOID:
361         case TIMESTAMPTZOID:
362           fmt.type = FMT_DATETIME;
363           fmt.d = 0;
364           fmt.w = 22;
365           width = 0;
366           break;
367         case NUMERICOID:
368           fmt.type = FMT_E;
369           fmt.d = 2;
370           fmt.w = 40;
371           width = 0;
372           break;
373         default:
374           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
375           fmt.type = FMT_A;
376           width = length > 0 ? length : MAX_SHORT_STRING;
377           fmt.w = width ;
378           fmt.d = 0;
379
380           break;
381         }
382
383       var = create_var (r, &fmt, width, PQfname (res, i), i);
384
385       /* Timezones need an extra variable */
386       switch (type)
387         {
388         case TIMETZOID:
389           {
390             struct string name;
391             ds_init_cstr (&name, var_get_name (var));
392             ds_put_cstr (&name, "-zone");
393             fmt.type = FMT_F;
394             fmt.w = 8;
395             fmt.d = 2;
396
397             create_var (r, &fmt, 0, ds_cstr (&name), -1);
398
399             ds_destroy (&name);
400           }
401           break;
402
403         case INTERVALOID:
404           {
405             struct string name;
406             ds_init_cstr (&name, var_get_name (var));
407             ds_put_cstr (&name, "-months");
408             fmt.type = FMT_F;
409             fmt.w = 3;
410             fmt.d = 0;
411
412             create_var (r, &fmt, 0, ds_cstr (&name), -1);
413
414             ds_destroy (&name);
415           }
416         default:
417           break;
418         }
419
420     }
421
422   /* Create the first case, and cache it */
423   r->used_first_case = false;
424
425
426   case_create (&r->first_case, r->value_cnt);
427   memset (case_data_rw_idx (&r->first_case, 0)->s,
428           ' ', MAX_SHORT_STRING * r->value_cnt);
429
430   set_value (r, res, &r->first_case);
431
432   PQclear (res);
433
434   return casereader_create_sequential
435     (NULL,
436      r->value_cnt,
437      CASENUMBER_MAX,
438      &psql_casereader_class, r);
439
440  error:
441   PQclear (res);
442   dict_destroy (*dict);
443
444   psql_casereader_destroy (NULL, r);
445   return NULL;
446 }
447
448
449 static void
450 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
451 {
452   struct psql_reader *r = r_;
453   if (r == NULL)
454     return ;
455
456   free (r->vmap);
457   PQfinish (r->conn);
458
459   free (r);
460 }
461
462 static bool
463 psql_casereader_read (struct casereader *reader UNUSED, void *r_,
464                       struct ccase *cc)
465 {
466   PGresult *res;
467
468   struct psql_reader *r = r_;
469
470   if ( !r->used_first_case )
471     {
472       *cc = r->first_case;
473       r->used_first_case = true;
474       return true;
475     }
476
477   case_create (cc, r->value_cnt);
478   memset (case_data_rw_idx (cc, 0)->s, ' ', MAX_SHORT_STRING * r->value_cnt);
479
480   res = PQexec (r->conn, "FETCH NEXT FROM pspp");
481   if ( PQresultStatus (res) != PGRES_TUPLES_OK || PQntuples (res) < 1)
482     {
483       PQclear (res);
484       case_destroy (cc);
485       return false;
486     }
487
488   set_value (r, res, cc);
489
490   PQclear (res);
491
492   return true;
493 }
494
495 static void
496 set_value (const struct psql_reader *r,
497            PGresult *res, struct ccase *c)
498 {
499   int i;
500   int n_vars = PQnfields (res);
501
502   for (i = 0 ; i < n_vars ; ++i )
503     {
504       Oid type = PQftype (res, i);
505       struct variable *v = dict_get_var (r->dict, r->vmap[i]);
506       union value *val = case_data_rw (c, v);
507       const struct variable *v1 = NULL;
508       union value *val1 = NULL;
509
510       if (i < r->vmapsize && r->vmap[i] + 1 < dict_get_var_cnt (r->dict))
511         {
512           v1 = dict_get_var (r->dict, r->vmap[i] + 1);
513
514           val1 = case_data_rw (c, v1);
515         }
516
517
518       if (PQgetisnull (res, 0, i))
519         {
520           value_set_missing (val, var_get_width (v));
521
522           switch (type)
523             {
524             case INTERVALOID:
525             case TIMESTAMPTZOID:
526             case TIMETZOID:
527               val1->f = SYSMIS;
528               break;
529             default:
530               break;
531             }
532         }
533       else
534         {
535           const uint8_t *vptr = (const uint8_t *) PQgetvalue (res, 0, i);
536           int length = PQgetlength (res, 0, i);
537
538           int var_width = var_get_width (v);
539           switch (type)
540             {
541             case BOOLOID:
542               {
543                 int8_t x;
544                 GET_VALUE (&vptr, x);
545                 val->f = x;
546               }
547               break;
548
549             case OIDOID:
550             case INT2OID:
551               {
552                 int16_t x;
553                 GET_VALUE (&vptr, x);
554                 val->f = x;
555               }
556               break;
557
558             case INT4OID:
559               {
560                 int32_t x;
561                 GET_VALUE (&vptr, x);
562                 val->f = x;
563               }
564               break;
565
566             case INT8OID:
567               {
568                 int64_t x;
569                 GET_VALUE (&vptr, x);
570                 val->f = x;
571               }
572               break;
573
574             case FLOAT4OID:
575               {
576                 float n;
577                 GET_VALUE (&vptr, n);
578                 val->f = n;
579               }
580               break;
581
582             case FLOAT8OID:
583               {
584                 double n;
585                 GET_VALUE (&vptr, n);
586                 val->f = n;
587               }
588               break;
589
590             case CASHOID:
591               {
592                 int32_t x;
593                 GET_VALUE (&vptr, x);
594                 val->f = x / 100.0;
595               }
596               break;
597
598             case INTERVALOID:
599               {
600                 if ( r->integer_datetimes )
601                   {
602                     uint32_t months;
603                     uint32_t days;
604                     uint32_t us;
605                     uint32_t things;
606
607                     GET_VALUE (&vptr, things);
608                     GET_VALUE (&vptr, us);
609                     GET_VALUE (&vptr, days);
610                     GET_VALUE (&vptr, months);
611
612                     val->f = us / 1000000.0;
613                     val->f += days * 24 * 3600;
614
615                     val1->f = months;
616                   }
617                 else
618                   {
619                     uint32_t days, months;
620                     double seconds;
621
622                     GET_VALUE (&vptr, seconds);
623                     GET_VALUE (&vptr, days);
624                     GET_VALUE (&vptr, months);
625
626                     val->f = seconds;
627                     val->f += days * 24 * 3600;
628
629                     val1->f = months;
630                   }
631               }
632               break;
633
634             case DATEOID:
635               {
636                 int32_t x;
637
638                 GET_VALUE (&vptr, x);
639
640                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
641               }
642               break;
643
644             case TIMEOID:
645               {
646                 if ( r->integer_datetimes)
647                   {
648                     uint64_t x;
649                     GET_VALUE (&vptr, x);
650                     val->f = x / 1000000.0;
651                   }
652                 else
653                   {
654                     double x;
655                     GET_VALUE (&vptr, x);
656                     val->f = x;
657                   }
658               }
659               break;
660
661             case TIMETZOID:
662               {
663                 int32_t zone;
664                 if ( r->integer_datetimes)
665                   {
666                     uint64_t x;
667
668
669                     GET_VALUE (&vptr, x);
670                     val->f = x / 1000000.0;
671                   }
672                 else
673                   {
674                     double x;
675
676                     GET_VALUE (&vptr, x);
677                     val->f = x ;
678                   }
679
680                 GET_VALUE (&vptr, zone);
681                 val1->f = zone / 3600.0;
682               }
683               break;
684
685             case TIMESTAMPOID:
686             case TIMESTAMPTZOID:
687               {
688                 if ( r->integer_datetimes)
689                   {
690                     int64_t x;
691
692                     GET_VALUE (&vptr, x);
693
694                     x /= 1000000;
695
696                     val->f = (x + r->postgres_epoch * 24 * 3600 );
697                   }
698                 else
699                   {
700                     double x;
701
702                     GET_VALUE (&vptr, x);
703
704                     val->f = (x + r->postgres_epoch * 24 * 3600 );
705                   }
706               }
707               break;
708             case TEXTOID:
709             case VARCHAROID:
710             case BPCHAROID:
711             case BYTEAOID:
712               memcpy (val->s, (char *) vptr, MIN (length, var_width));
713               break;
714
715             case NUMERICOID:
716               {
717                 double f = 0.0;
718                 int i;
719                 int16_t n_digits, weight, dscale;
720                 uint16_t sign;
721
722                 GET_VALUE (&vptr, n_digits);
723                 GET_VALUE (&vptr, weight);
724                 GET_VALUE (&vptr, sign);
725                 GET_VALUE (&vptr, dscale);
726
727                 {
728                   struct fmt_spec fmt;
729                   fmt.d = dscale;
730                   fmt.type = FMT_E;
731                   fmt.w = fmt_max_output_width (fmt.type) ;
732                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
733                   var_set_both_formats (v, &fmt);
734                 }
735
736                 for (i = 0 ; i < n_digits;  ++i)
737                   {
738                     uint16_t x;
739                     GET_VALUE (&vptr, x);
740                     f += x * pow (10000, weight--);
741                   }
742
743                 if ( sign == 0x4000)
744                   f *= -1.0;
745
746                 if ( sign == 0xC000)
747                   val->f = SYSMIS;
748                 else
749                   val->f = f;
750               }
751               break;
752
753             default:
754               val->f = SYSMIS;
755               break;
756             }
757         }
758     }
759 }
760
761 #endif