Reference count struct dictionary.
[pspp] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009, 2010, 2011, 2012 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/psql-reader.h"
20
21 #include <inttypes.h>
22 #include <math.h>
23 #include <stdlib.h>
24
25 #include "data/calendar.h"
26 #include "data/casereader-provider.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/variable.h"
30 #include "libpspp/i18n.h"
31 #include "libpspp/message.h"
32 #include "libpspp/misc.h"
33 #include "libpspp/str.h"
34
35 #include "gl/c-strcase.h"
36 #include "gl/minmax.h"
37 #include "gl/xalloc.h"
38
39 #include "gettext.h"
40 #define _(msgid) gettext (msgid)
41 #define N_(msgid) (msgid)
42
43
44 #if !PSQL_SUPPORT
45 struct casereader *
46 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
47 {
48   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
49
50   return NULL;
51 }
52
53 #else
54
55 #include <stdint.h>
56 #include <libpq-fe.h>
57
58
59 /* Default width of string variables. */
60 #define PSQL_DEFAULT_WIDTH 8
61
62 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
63 #define BOOLOID            16
64 #define BYTEAOID           17
65 #define CHAROID            18
66 #define NAMEOID            19
67 #define INT8OID            20
68 #define INT2OID            21
69 #define INT4OID            23
70 #define TEXTOID            25
71 #define OIDOID             26
72 #define FLOAT4OID          700
73 #define FLOAT8OID          701
74 #define CASHOID            790
75 #define BPCHAROID          1042
76 #define VARCHAROID         1043
77 #define DATEOID            1082
78 #define TIMEOID            1083
79 #define TIMESTAMPOID       1114
80 #define TIMESTAMPTZOID     1184
81 #define INTERVALOID        1186
82 #define TIMETZOID          1266
83 #define NUMERICOID         1700
84
85 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
86
87 static struct ccase *psql_casereader_read (struct casereader *, void *);
88
89 static const struct casereader_class psql_casereader_class =
90   {
91     psql_casereader_read,
92     psql_casereader_destroy,
93     NULL,
94     NULL,
95   };
96
97 struct psql_reader
98 {
99   PGconn *conn;
100   PGresult *res;
101   int tuple;
102
103   bool integer_datetimes;
104
105   double postgres_epoch;
106
107   struct caseproto *proto;
108   struct dictionary *dict;
109
110   /* An array of ints, which maps psql column numbers into
111      pspp variables */
112   struct variable **vmap;
113   size_t vmapsize;
114
115   struct string fetch_cmd;
116   int cache_size;
117 };
118
119
120 static struct ccase *set_value (struct psql_reader *r);
121
122
123
124 #if WORDS_BIGENDIAN
125 static void
126 data_to_native (const void *in_, void *out_, int len)
127 {
128   int i;
129   const unsigned char *in = in_;
130   unsigned char *out = out_;
131   for (i = 0 ; i < len ; ++i )
132     out[i] = in[i];
133 }
134 #else
135 static void
136 data_to_native (const void *in_, void *out_, int len)
137 {
138   int i;
139   const unsigned char *in = in_;
140   unsigned char *out = out_;
141   for (i = 0 ; i < len ; ++i )
142     out[len - i - 1] = in[i];
143 }
144 #endif
145
146
147 #define GET_VALUE(IN, OUT) do { \
148     size_t sz = sizeof (OUT); \
149     data_to_native (*(IN), &(OUT), sz) ; \
150     (*IN) += sz; \
151 } while (false)
152
153
154 #if 0
155 static void
156 dump (const unsigned char *x, int l)
157 {
158   int i;
159
160   for (i = 0; i < l ; ++i)
161     {
162       printf ("%02x ", x[i]);
163     }
164
165   putchar ('\n');
166
167   for (i = 0; i < l ; ++i)
168     {
169       if ( isprint (x[i]))
170         printf ("%c ", x[i]);
171       else
172         printf ("   ");
173     }
174
175   putchar ('\n');
176 }
177 #endif
178
179 static struct variable *
180 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
181             int width, const char *suggested_name, int col)
182 {
183   unsigned long int vx = 0;
184   struct variable *var;
185   char *name;
186
187   name = dict_make_unique_var_name (r->dict, suggested_name, &vx);
188   var = dict_create_var (r->dict, name, width);
189   free (name);
190
191   var_set_both_formats (var, fmt);
192
193   if ( col != -1)
194     {
195       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
196
197       r->vmap[col] = var;
198       r->vmapsize = col + 1;
199     }
200
201   return var;
202 }
203
204
205
206
207 /* Fill the cache */
208 static bool
209 reload_cache (struct psql_reader *r)
210 {
211   PQclear (r->res);
212   r->tuple = 0;
213
214   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
215
216   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
217     {
218       PQclear (r->res);
219       r->res = NULL;
220       return false;
221     }
222
223   return true;
224 }
225
226
227 struct casereader *
228 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
229 {
230   int i;
231   int n_fields, n_tuples;
232   PGresult *qres = NULL;
233   casenumber n_cases = CASENUMBER_MAX;
234   const char *encoding;
235
236   struct psql_reader *r = xzalloc (sizeof *r);
237   struct string query ;
238
239   r->conn = PQconnectdb (info->conninfo);
240   if ( NULL == r->conn)
241     {
242       msg (ME, _("Memory error whilst opening psql source"));
243       goto error;
244     }
245
246   if ( PQstatus (r->conn) != CONNECTION_OK )
247     {
248       msg (ME, _("Error opening psql source: %s."),
249            PQerrorMessage (r->conn));
250
251       goto error;
252     }
253
254   {
255     int ver_num = 0;
256     const char *vers = PQparameterStatus (r->conn, "server_version");
257
258     sscanf (vers, "%d", &ver_num);
259
260     if ( ver_num < 8)
261       {
262         msg (ME,
263              _("Postgres server is version %s."
264                " Reading from versions earlier than 8.0 is not supported."),
265              vers);
266
267         goto error;
268       }
269   }
270
271   {
272     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
273
274     r->integer_datetimes = ( 0 == c_strcasecmp (dt, "on"));
275   }
276
277 #if USE_SSL
278   if ( PQgetssl (r->conn) == NULL)
279 #endif
280     {
281       if (! info->allow_clear)
282         {
283           msg (ME, _("Connection is unencrypted, "
284                      "but unencrypted connections have not been permitted."));
285           goto error;
286         }
287     }
288
289   r->postgres_epoch = calendar_gregorian_to_offset (2000, 1, 1, NULL);
290
291   {
292     const int enc = PQclientEncoding (r->conn);
293
294     /* According to section 22.2 of the Postgresql manual
295        a value of zero (SQL_ASCII) indicates
296        "a declaration of ignorance about the encoding".
297        Accordingly, we use the default encoding
298        if we find this value.
299     */
300     encoding = enc ? pg_encoding_to_char (enc) : get_default_encoding ();
301
302     /* Create the dictionary and populate it */
303     *dict = r->dict = dict_create (encoding);
304   }
305
306   const int version = PQserverVersion (r->conn);
307   ds_init_empty (&query);
308   /*
309     Versions before 9.1 don't have the REPEATABLE READ isolation level.
310     However according to <a12321aabb@gmail.com> if the server is in the
311     "hot standby" mode then SERIALIZABLE won't work.
312    */
313   ds_put_c_format (&query,
314                    "BEGIN READ ONLY ISOLATION LEVEL %s; "
315                    "DECLARE  pspp BINARY CURSOR FOR ",
316                    (version < 90100) ? "SERIALIZABLE" : "REPEATABLE READ");
317
318   ds_put_substring (&query, info->sql.ss);
319
320   qres = PQexec (r->conn, ds_cstr (&query));
321   ds_destroy (&query);
322   if ( PQresultStatus (qres) != PGRES_COMMAND_OK )
323     {
324       msg (ME, _("Error from psql source: %s."),
325            PQresultErrorMessage (qres));
326       goto error;
327     }
328
329   PQclear (qres);
330
331
332   /* Now use the count() function to find the total number of cases
333      that this query returns.
334      Doing this incurs some overhead.  The server has to iterate every
335      case in order to find this number.  However, it's performed on the
336      server side, and in all except the most huge databases the extra
337      overhead will be worth the effort.
338      On the other hand, most PSPP functions don't need to know this.
339      The GUI is the notable exception.
340   */
341   ds_init_cstr (&query, "SELECT count (*) FROM (");
342   ds_put_substring (&query, info->sql.ss);
343   ds_put_cstr (&query, ") stupid_sql_standard");
344
345   qres = PQexec (r->conn, ds_cstr (&query));
346   ds_destroy (&query);
347   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
348     {
349       msg (ME, _("Error from psql source: %s."),
350            PQresultErrorMessage (qres));
351       goto error;
352     }
353   n_cases = atol (PQgetvalue (qres, 0, 0));
354   PQclear (qres);
355
356   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
357   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
358     {
359       msg (ME, _("Error from psql source: %s."),
360            PQresultErrorMessage (qres));
361       goto error;
362     }
363
364   n_tuples = PQntuples (qres);
365   n_fields = PQnfields (qres);
366
367   r->proto = NULL;
368   r->vmap = NULL;
369   r->vmapsize = 0;
370
371   for (i = 0 ; i < n_fields ; ++i )
372     {
373       struct variable *var;
374       struct fmt_spec fmt = {FMT_F, 8, 2};
375       Oid type = PQftype (qres, i);
376       int width = 0;
377       int length ;
378
379       /* If there are no data then make a finger in the air
380          guess at the contents */
381       if ( n_tuples > 0 )
382         length = PQgetlength (qres, 0, i);
383       else
384         length = PSQL_DEFAULT_WIDTH;
385
386       switch (type)
387         {
388         case BOOLOID:
389         case OIDOID:
390         case INT2OID:
391         case INT4OID:
392         case INT8OID:
393         case FLOAT4OID:
394         case FLOAT8OID:
395           fmt.type = FMT_F;
396           break;
397         case CASHOID:
398           fmt.type = FMT_DOLLAR;
399           break;
400         case CHAROID:
401           fmt.type = FMT_A;
402           width = length > 0 ? length : 1;
403           fmt.d = 0;
404           fmt.w = 1;
405           break;
406         case TEXTOID:
407         case VARCHAROID:
408         case BPCHAROID:
409           fmt.type = FMT_A;
410           width = (info->str_width == -1) ?
411             ROUND_UP (length, PSQL_DEFAULT_WIDTH) : info->str_width;
412           fmt.w = width;
413           fmt.d = 0;
414           break;
415         case BYTEAOID:
416           fmt.type = FMT_AHEX;
417           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
418           fmt.w = width * 2;
419           fmt.d = 0;
420           break;
421         case INTERVALOID:
422           fmt.type = FMT_DTIME;
423           width = 0;
424           fmt.d = 0;
425           fmt.w = 13;
426           break;
427         case DATEOID:
428           fmt.type = FMT_DATE;
429           width = 0;
430           fmt.w = 11;
431           fmt.d = 0;
432           break;
433         case TIMEOID:
434         case TIMETZOID:
435           fmt.type = FMT_TIME;
436           width = 0;
437           fmt.w = 11;
438           fmt.d = 0;
439           break;
440         case TIMESTAMPOID:
441         case TIMESTAMPTZOID:
442           fmt.type = FMT_DATETIME;
443           fmt.d = 0;
444           fmt.w = 22;
445           width = 0;
446           break;
447         case NUMERICOID:
448           fmt.type = FMT_E;
449           fmt.d = 2;
450           fmt.w = 40;
451           width = 0;
452           break;
453         default:
454           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
455           fmt.type = FMT_A;
456           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
457           fmt.w = width ;
458           fmt.d = 0;
459           break;
460         }
461
462       if ( width == 0 && fmt_is_string (fmt.type))
463         fmt.w = width = PSQL_DEFAULT_WIDTH;
464
465
466       var = create_var (r, &fmt, width, PQfname (qres, i), i);
467       if ( type == NUMERICOID && n_tuples > 0)
468         {
469           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
470           struct fmt_spec fmt;
471           int16_t n_digits, weight, dscale;
472           uint16_t sign;
473
474           GET_VALUE (&vptr, n_digits);
475           GET_VALUE (&vptr, weight);
476           GET_VALUE (&vptr, sign);
477           GET_VALUE (&vptr, dscale);
478
479           fmt.d = dscale;
480           fmt.type = FMT_E;
481           fmt.w = fmt_max_output_width (fmt.type) ;
482           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
483           var_set_both_formats (var, &fmt);
484         }
485
486       /* Timezones need an extra variable */
487       switch (type)
488         {
489         case TIMETZOID:
490           {
491             struct string name;
492             ds_init_cstr (&name, var_get_name (var));
493             ds_put_cstr (&name, "-zone");
494             fmt.type = FMT_F;
495             fmt.w = 8;
496             fmt.d = 2;
497
498             create_var (r, &fmt, 0, ds_cstr (&name), -1);
499
500             ds_destroy (&name);
501           }
502           break;
503
504         case INTERVALOID:
505           {
506             struct string name;
507             ds_init_cstr (&name, var_get_name (var));
508             ds_put_cstr (&name, "-months");
509             fmt.type = FMT_F;
510             fmt.w = 3;
511             fmt.d = 0;
512
513             create_var (r, &fmt, 0, ds_cstr (&name), -1);
514
515             ds_destroy (&name);
516           }
517         default:
518           break;
519         }
520     }
521
522   PQclear (qres);
523
524   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
525   if ( PQresultStatus (qres) != PGRES_COMMAND_OK)
526     {
527       PQclear (qres);
528       goto error;
529     }
530   PQclear (qres);
531
532   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
533
534   ds_init_empty (&r->fetch_cmd);
535   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
536
537   reload_cache (r);
538   r->proto = caseproto_ref (dict_get_proto (*dict));
539
540   return casereader_create_sequential
541     (NULL,
542      r->proto,
543      n_cases,
544      &psql_casereader_class, r);
545
546  error:
547   dict_unref (*dict);
548
549   psql_casereader_destroy (NULL, r);
550   return NULL;
551 }
552
553
554 static void
555 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
556 {
557   struct psql_reader *r = r_;
558   if (r == NULL)
559     return ;
560
561   ds_destroy (&r->fetch_cmd);
562   free (r->vmap);
563   if (r->res) PQclear (r->res);
564   PQfinish (r->conn);
565   caseproto_unref (r->proto);
566
567   free (r);
568 }
569
570
571
572 static struct ccase *
573 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
574 {
575   struct psql_reader *r = r_;
576
577   if ( NULL == r->res || r->tuple >= r->cache_size)
578     {
579       if ( ! reload_cache (r) )
580         return false;
581     }
582
583   return set_value (r);
584 }
585
586 static struct ccase *
587 set_value (struct psql_reader *r)
588 {
589   struct ccase *c;
590   int n_vars;
591   int i;
592
593   assert (r->res);
594
595   n_vars = PQnfields (r->res);
596
597   if ( r->tuple >= PQntuples (r->res))
598     return NULL;
599
600   c = case_create (r->proto);
601   case_set_missing (c);
602
603
604   for (i = 0 ; i < n_vars ; ++i )
605     {
606       Oid type = PQftype (r->res, i);
607       const struct variable *v = r->vmap[i];
608       union value *val = case_data_rw (c, v);
609
610       union value *val1 = NULL;
611
612       switch (type)
613         {
614         case INTERVALOID:
615         case TIMESTAMPTZOID:
616         case TIMETZOID:
617           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_var_cnt (r->dict))
618             {
619               const struct variable *v1 = NULL;
620               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
621
622               val1 = case_data_rw (c, v1);
623             }
624           break;
625         default:
626           break;
627         }
628
629
630       if (PQgetisnull (r->res, r->tuple, i))
631         {
632           value_set_missing (val, var_get_width (v));
633
634           switch (type)
635             {
636             case INTERVALOID:
637             case TIMESTAMPTZOID:
638             case TIMETZOID:
639               val1->f = SYSMIS;
640               break;
641             default:
642               break;
643             }
644         }
645       else
646         {
647           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
648           int length = PQgetlength (r->res, r->tuple, i);
649
650           int var_width = var_get_width (v);
651           switch (type)
652             {
653             case BOOLOID:
654               {
655                 int8_t x;
656                 GET_VALUE (&vptr, x);
657                 val->f = x;
658               }
659               break;
660
661             case OIDOID:
662             case INT2OID:
663               {
664                 int16_t x;
665                 GET_VALUE (&vptr, x);
666                 val->f = x;
667               }
668               break;
669
670             case INT4OID:
671               {
672                 int32_t x;
673                 GET_VALUE (&vptr, x);
674                 val->f = x;
675               }
676               break;
677
678             case INT8OID:
679               {
680                 int64_t x;
681                 GET_VALUE (&vptr, x);
682                 val->f = x;
683               }
684               break;
685
686             case FLOAT4OID:
687               {
688                 float n;
689                 GET_VALUE (&vptr, n);
690                 val->f = n;
691               }
692               break;
693
694             case FLOAT8OID:
695               {
696                 double n;
697                 GET_VALUE (&vptr, n);
698                 val->f = n;
699               }
700               break;
701
702             case CASHOID:
703               {
704                 /* Postgres 8.3 uses 64 bits.
705                    Earlier versions use 32 */
706                 switch (length)
707                   {
708                   case 8:
709                     {
710                       int64_t x;
711                       GET_VALUE (&vptr, x);
712                       val->f = x / 100.0;
713                     }
714                     break;
715                   case 4:
716                     {
717                       int32_t x;
718                       GET_VALUE (&vptr, x);
719                       val->f = x / 100.0;
720                     }
721                     break;
722                   default:
723                     val->f = SYSMIS;
724                     break;
725                   }
726               }
727               break;
728
729             case INTERVALOID:
730               {
731                 if ( r->integer_datetimes )
732                   {
733                     uint32_t months;
734                     uint32_t days;
735                     uint32_t us;
736                     uint32_t things;
737
738                     GET_VALUE (&vptr, things);
739                     GET_VALUE (&vptr, us);
740                     GET_VALUE (&vptr, days);
741                     GET_VALUE (&vptr, months);
742
743                     val->f = us / 1000000.0;
744                     val->f += days * 24 * 3600;
745
746                     val1->f = months;
747                   }
748                 else
749                   {
750                     uint32_t days, months;
751                     double seconds;
752
753                     GET_VALUE (&vptr, seconds);
754                     GET_VALUE (&vptr, days);
755                     GET_VALUE (&vptr, months);
756
757                     val->f = seconds;
758                     val->f += days * 24 * 3600;
759
760                     val1->f = months;
761                   }
762               }
763               break;
764
765             case DATEOID:
766               {
767                 int32_t x;
768
769                 GET_VALUE (&vptr, x);
770
771                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
772               }
773               break;
774
775             case TIMEOID:
776               {
777                 if ( r->integer_datetimes)
778                   {
779                     uint64_t x;
780                     GET_VALUE (&vptr, x);
781                     val->f = x / 1000000.0;
782                   }
783                 else
784                   {
785                     double x;
786                     GET_VALUE (&vptr, x);
787                     val->f = x;
788                   }
789               }
790               break;
791
792             case TIMETZOID:
793               {
794                 int32_t zone;
795                 if ( r->integer_datetimes)
796                   {
797                     uint64_t x;
798
799
800                     GET_VALUE (&vptr, x);
801                     val->f = x / 1000000.0;
802                   }
803                 else
804                   {
805                     double x;
806
807                     GET_VALUE (&vptr, x);
808                     val->f = x ;
809                   }
810
811                 GET_VALUE (&vptr, zone);
812                 val1->f = zone / 3600.0;
813               }
814               break;
815
816             case TIMESTAMPOID:
817             case TIMESTAMPTZOID:
818               {
819                 if ( r->integer_datetimes)
820                   {
821                     int64_t x;
822
823                     GET_VALUE (&vptr, x);
824
825                     x /= 1000000;
826
827                     val->f = (x + r->postgres_epoch * 24 * 3600 );
828                   }
829                 else
830                   {
831                     double x;
832
833                     GET_VALUE (&vptr, x);
834
835                     val->f = (x + r->postgres_epoch * 24 * 3600 );
836                   }
837               }
838               break;
839             case TEXTOID:
840             case VARCHAROID:
841             case BPCHAROID:
842             case BYTEAOID:
843               memcpy (value_str_rw (val, var_width), vptr,
844                       MIN (length, var_width));
845               break;
846
847             case NUMERICOID:
848               {
849                 double f = 0.0;
850                 int i;
851                 int16_t n_digits, weight, dscale;
852                 uint16_t sign;
853
854                 GET_VALUE (&vptr, n_digits);
855                 GET_VALUE (&vptr, weight);
856                 GET_VALUE (&vptr, sign);
857                 GET_VALUE (&vptr, dscale);
858
859 #if 0
860                 {
861                   struct fmt_spec fmt;
862                   fmt.d = dscale;
863                   fmt.type = FMT_E;
864                   fmt.w = fmt_max_output_width (fmt.type) ;
865                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
866                   var_set_both_formats (v, &fmt);
867                 }
868 #endif
869
870                 for (i = 0 ; i < n_digits;  ++i)
871                   {
872                     uint16_t x;
873                     GET_VALUE (&vptr, x);
874                     f += x * pow (10000, weight--);
875                   }
876
877                 if ( sign == 0x4000)
878                   f *= -1.0;
879
880                 if ( sign == 0xC000)
881                   val->f = SYSMIS;
882                 else
883                   val->f = f;
884               }
885               break;
886
887             default:
888               val->f = SYSMIS;
889               break;
890             }
891         }
892     }
893
894   r->tuple++;
895
896   return c;
897 }
898
899 #endif