dictionary: Make dict_create() take the new dictionary's encoding.
[pspp-builds.git] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009, 2010, 2011 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include "data/psql-reader.h"
20
21 #include <inttypes.h>
22 #include <math.h>
23 #include <stdlib.h>
24
25 #include "data/calendar.h"
26 #include "data/casereader-provider.h"
27 #include "data/dictionary.h"
28 #include "data/format.h"
29 #include "data/variable.h"
30 #include "libpspp/i18n.h"
31 #include "libpspp/message.h"
32 #include "libpspp/misc.h"
33 #include "libpspp/str.h"
34
35 #include "gl/xalloc.h"
36 #include "gl/minmax.h"
37
38 #include "gettext.h"
39 #define _(msgid) gettext (msgid)
40 #define N_(msgid) (msgid)
41
42
43 #if !PSQL_SUPPORT
44 struct casereader *
45 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
46 {
47   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
48
49   return NULL;
50 }
51
52 #else
53
54 #include <stdint.h>
55 #include <libpq-fe.h>
56
57
58 /* Default width of string variables. */
59 #define PSQL_DEFAULT_WIDTH 8
60
61 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
62 #define BOOLOID            16
63 #define BYTEAOID           17
64 #define CHAROID            18
65 #define NAMEOID            19
66 #define INT8OID            20
67 #define INT2OID            21
68 #define INT4OID            23
69 #define TEXTOID            25
70 #define OIDOID             26
71 #define FLOAT4OID          700
72 #define FLOAT8OID          701
73 #define CASHOID            790
74 #define BPCHAROID          1042
75 #define VARCHAROID         1043
76 #define DATEOID            1082
77 #define TIMEOID            1083
78 #define TIMESTAMPOID       1114
79 #define TIMESTAMPTZOID     1184
80 #define INTERVALOID        1186
81 #define TIMETZOID          1266
82 #define NUMERICOID         1700
83
84 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
85
86 static struct ccase *psql_casereader_read (struct casereader *, void *);
87
88 static const struct casereader_class psql_casereader_class =
89   {
90     psql_casereader_read,
91     psql_casereader_destroy,
92     NULL,
93     NULL,
94   };
95
96 struct psql_reader
97 {
98   PGconn *conn;
99   PGresult *res;
100   int tuple;
101
102   bool integer_datetimes;
103
104   double postgres_epoch;
105
106   struct caseproto *proto;
107   struct dictionary *dict;
108
109   /* An array of ints, which maps psql column numbers into
110      pspp variables */
111   struct variable **vmap;
112   size_t vmapsize;
113
114   struct string fetch_cmd;
115   int cache_size;
116 };
117
118
119 static struct ccase *set_value (struct psql_reader *r);
120
121
122
123 #if WORDS_BIGENDIAN
124 static void
125 data_to_native (const void *in_, void *out_, int len)
126 {
127   int i;
128   const unsigned char *in = in_;
129   unsigned char *out = out_;
130   for (i = 0 ; i < len ; ++i )
131     out[i] = in[i];
132 }
133 #else
134 static void
135 data_to_native (const void *in_, void *out_, int len)
136 {
137   int i;
138   const unsigned char *in = in_;
139   unsigned char *out = out_;
140   for (i = 0 ; i < len ; ++i )
141     out[len - i - 1] = in[i];
142 }
143 #endif
144
145
146 #define GET_VALUE(IN, OUT) do { \
147     size_t sz = sizeof (OUT); \
148     data_to_native (*(IN), &(OUT), sz) ; \
149     (*IN) += sz; \
150 } while (false)
151
152
153 #if 0
154 static void
155 dump (const unsigned char *x, int l)
156 {
157   int i;
158
159   for (i = 0; i < l ; ++i)
160     {
161       printf ("%02x ", x[i]);
162     }
163
164   putchar ('\n');
165
166   for (i = 0; i < l ; ++i)
167     {
168       if ( isprint (x[i]))
169         printf ("%c ", x[i]);
170       else
171         printf ("   ");
172     }
173
174   putchar ('\n');
175 }
176 #endif
177
178 static struct variable *
179 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
180             int width, const char *suggested_name, int col)
181 {
182   unsigned long int vx = 0;
183   struct variable *var;
184   char *name;
185
186   name = dict_make_unique_var_name (r->dict, suggested_name, &vx);
187   var = dict_create_var (r->dict, name, width);
188   free (name);
189
190   var_set_both_formats (var, fmt);
191
192   if ( col != -1)
193     {
194       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
195
196       r->vmap[col] = var;
197       r->vmapsize = col + 1;
198     }
199
200   return var;
201 }
202
203
204
205
206 /* Fill the cache */
207 static bool
208 reload_cache (struct psql_reader *r)
209 {
210   PQclear (r->res);
211   r->tuple = 0;
212
213   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
214
215   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
216     {
217       PQclear (r->res);
218       r->res = NULL;
219       return false;
220     }
221
222   return true;
223 }
224
225
226 struct casereader *
227 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
228 {
229   int i;
230   int n_fields, n_tuples;
231   PGresult *qres = NULL;
232   casenumber n_cases = CASENUMBER_MAX;
233   const char *encoding;
234
235   struct psql_reader *r = xzalloc (sizeof *r);
236   struct string query ;
237
238   r->conn = PQconnectdb (info->conninfo);
239   if ( NULL == r->conn)
240     {
241       msg (ME, _("Memory error whilst opening psql source"));
242       goto error;
243     }
244
245   if ( PQstatus (r->conn) != CONNECTION_OK )
246     {
247       msg (ME, _("Error opening psql source: %s."),
248            PQerrorMessage (r->conn));
249
250       goto error;
251     }
252
253   {
254     int ver_num;
255     const char *vers = PQparameterStatus (r->conn, "server_version");
256
257     sscanf (vers, "%d", &ver_num);
258
259     if ( ver_num < 8)
260       {
261         msg (ME,
262              _("Postgres server is version %s."
263                " Reading from versions earlier than 8.0 is not supported."),
264              vers);
265
266         goto error;
267       }
268   }
269
270   {
271     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
272
273     r->integer_datetimes = ( 0 == strcasecmp (dt, "on"));
274   }
275
276 #if USE_SSL
277   if ( PQgetssl (r->conn) == NULL)
278 #endif
279     {
280       if (! info->allow_clear)
281         {
282           msg (ME, _("Connection is unencrypted, "
283                      "but unencrypted connections have not been permitted."));
284           goto error;
285         }
286     }
287
288   r->postgres_epoch = calendar_gregorian_to_offset (2000, 1, 1, NULL);
289
290   {
291     const int enc = PQclientEncoding (r->conn);
292
293     /* According to section 22.2 of the Postgresql manual
294        a value of zero (SQL_ASCII) indicates
295        "a declaration of ignorance about the encoding".
296        Accordingly, we use the default encoding
297        if we find this value.
298     */
299     encoding = enc ? pg_encoding_to_char (enc) : get_default_encoding ();
300   }
301
302   /* Create the dictionary and populate it */
303   *dict = r->dict = dict_create ();
304
305   /*
306     select count (*) from (select * from medium) stupid_sql_standard;
307   */
308   ds_init_cstr (&query,
309                 "BEGIN READ ONLY ISOLATION LEVEL SERIALIZABLE; "
310                 "DECLARE  pspp BINARY CURSOR FOR ");
311
312   ds_put_substring (&query, info->sql.ss);
313
314   qres = PQexec (r->conn, ds_cstr (&query));
315   ds_destroy (&query);
316   if ( PQresultStatus (qres) != PGRES_COMMAND_OK )
317     {
318       msg (ME, _("Error from psql source: %s."),
319            PQresultErrorMessage (qres));
320       goto error;
321     }
322
323   PQclear (qres);
324
325
326   /* Now use the count() function to find the total number of cases
327      that this query returns.
328      Doing this incurs some overhead.  The server has to iterate every
329      case in order to find this number.  However, it's performed on the
330      server side, and in all except the most huge databases the extra
331      overhead will be worth the effort.
332      On the other hand, most PSPP functions don't need to know this.
333      The GUI is the notable exception.
334   */
335   ds_init_cstr (&query, "SELECT count (*) FROM (");
336   ds_put_substring (&query, info->sql.ss);
337   ds_put_cstr (&query, ") stupid_sql_standard");
338
339   qres = PQexec (r->conn, ds_cstr (&query));
340   ds_destroy (&query);
341   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
342     {
343       msg (ME, _("Error from psql source: %s."),
344            PQresultErrorMessage (qres));
345       goto error;
346     }
347   n_cases = atol (PQgetvalue (qres, 0, 0));
348   PQclear (qres);
349
350   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
351   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
352     {
353       msg (ME, _("Error from psql source: %s."),
354            PQresultErrorMessage (qres));
355       goto error;
356     }
357
358   n_tuples = PQntuples (qres);
359   n_fields = PQnfields (qres);
360
361   r->proto = NULL;
362   r->vmap = NULL;
363   r->vmapsize = 0;
364
365   for (i = 0 ; i < n_fields ; ++i )
366     {
367       struct variable *var;
368       struct fmt_spec fmt = {FMT_F, 8, 2};
369       Oid type = PQftype (qres, i);
370       int width = 0;
371       int length ;
372
373       /* If there are no data then make a finger in the air 
374          guess at the contents */
375       if ( n_tuples > 0 )
376         length = PQgetlength (qres, 0, i);
377       else 
378         length = PSQL_DEFAULT_WIDTH;
379
380       switch (type)
381         {
382         case BOOLOID:
383         case OIDOID:
384         case INT2OID:
385         case INT4OID:
386         case INT8OID:
387         case FLOAT4OID:
388         case FLOAT8OID:
389           fmt.type = FMT_F;
390           break;
391         case CASHOID:
392           fmt.type = FMT_DOLLAR;
393           break;
394         case CHAROID:
395           fmt.type = FMT_A;
396           width = length > 0 ? length : 1;
397           fmt.d = 0;
398           fmt.w = 1;
399           break;
400         case TEXTOID:
401         case VARCHAROID:
402         case BPCHAROID:
403           fmt.type = FMT_A;
404           width = (info->str_width == -1) ?
405             ROUND_UP (length, PSQL_DEFAULT_WIDTH) : info->str_width;
406           fmt.w = width;
407           fmt.d = 0;
408           break;
409         case BYTEAOID:
410           fmt.type = FMT_AHEX;
411           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
412           fmt.w = width * 2;
413           fmt.d = 0;
414           break;
415         case INTERVALOID:
416           fmt.type = FMT_DTIME;
417           width = 0;
418           fmt.d = 0;
419           fmt.w = 13;
420           break;
421         case DATEOID:
422           fmt.type = FMT_DATE;
423           width = 0;
424           fmt.w = 11;
425           fmt.d = 0;
426           break;
427         case TIMEOID:
428         case TIMETZOID:
429           fmt.type = FMT_TIME;
430           width = 0;
431           fmt.w = 11;
432           fmt.d = 0;
433           break;
434         case TIMESTAMPOID:
435         case TIMESTAMPTZOID:
436           fmt.type = FMT_DATETIME;
437           fmt.d = 0;
438           fmt.w = 22;
439           width = 0;
440           break;
441         case NUMERICOID:
442           fmt.type = FMT_E;
443           fmt.d = 2;
444           fmt.w = 40;
445           width = 0;
446           break;
447         default:
448           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
449           fmt.type = FMT_A;
450           width = length > 0 ? length : PSQL_DEFAULT_WIDTH;
451           fmt.w = width ;
452           fmt.d = 0;
453           break;
454         }
455
456       if ( width == 0 && fmt_is_string (fmt.type))
457         fmt.w = width = PSQL_DEFAULT_WIDTH;
458
459
460       var = create_var (r, &fmt, width, PQfname (qres, i), i);
461       if ( type == NUMERICOID && n_tuples > 0)
462         {
463           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
464           struct fmt_spec fmt;
465           int16_t n_digits, weight, dscale;
466           uint16_t sign;
467
468           GET_VALUE (&vptr, n_digits);
469           GET_VALUE (&vptr, weight);
470           GET_VALUE (&vptr, sign);
471           GET_VALUE (&vptr, dscale);
472
473           fmt.d = dscale;
474           fmt.type = FMT_E;
475           fmt.w = fmt_max_output_width (fmt.type) ;
476           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
477           var_set_both_formats (var, &fmt);
478         }
479
480       /* Timezones need an extra variable */
481       switch (type)
482         {
483         case TIMETZOID:
484           {
485             struct string name;
486             ds_init_cstr (&name, var_get_name (var));
487             ds_put_cstr (&name, "-zone");
488             fmt.type = FMT_F;
489             fmt.w = 8;
490             fmt.d = 2;
491
492             create_var (r, &fmt, 0, ds_cstr (&name), -1);
493
494             ds_destroy (&name);
495           }
496           break;
497
498         case INTERVALOID:
499           {
500             struct string name;
501             ds_init_cstr (&name, var_get_name (var));
502             ds_put_cstr (&name, "-months");
503             fmt.type = FMT_F;
504             fmt.w = 3;
505             fmt.d = 0;
506
507             create_var (r, &fmt, 0, ds_cstr (&name), -1);
508
509             ds_destroy (&name);
510           }
511         default:
512           break;
513         }
514     }
515
516   PQclear (qres);
517
518   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
519   if ( PQresultStatus (qres) != PGRES_COMMAND_OK)
520     {
521       PQclear (qres);
522       goto error;
523     }
524   PQclear (qres);
525
526   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
527
528   ds_init_empty (&r->fetch_cmd);
529   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
530
531   reload_cache (r);
532   r->proto = caseproto_ref (dict_get_proto (*dict));
533
534   return casereader_create_sequential
535     (NULL,
536      r->proto,
537      n_cases,
538      &psql_casereader_class, r);
539
540  error:
541   dict_destroy (*dict);
542
543   psql_casereader_destroy (NULL, r);
544   return NULL;
545 }
546
547
548 static void
549 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
550 {
551   struct psql_reader *r = r_;
552   if (r == NULL)
553     return ;
554
555   ds_destroy (&r->fetch_cmd);
556   free (r->vmap);
557   if (r->res) PQclear (r->res);
558   PQfinish (r->conn);
559   caseproto_unref (r->proto);
560
561   free (r);
562 }
563
564
565
566 static struct ccase *
567 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
568 {
569   struct psql_reader *r = r_;
570
571   if ( NULL == r->res || r->tuple >= r->cache_size)
572     {
573       if ( ! reload_cache (r) )
574         return false;
575     }
576
577   return set_value (r);
578 }
579
580 static struct ccase *
581 set_value (struct psql_reader *r)
582 {
583   struct ccase *c;
584   int n_vars;
585   int i;
586
587   assert (r->res);
588
589   n_vars = PQnfields (r->res);
590
591   if ( r->tuple >= PQntuples (r->res))
592     return NULL;
593
594   c = case_create (r->proto);
595   case_set_missing (c);
596
597
598   for (i = 0 ; i < n_vars ; ++i )
599     {
600       Oid type = PQftype (r->res, i);
601       const struct variable *v = r->vmap[i];
602       union value *val = case_data_rw (c, v);
603
604       union value *val1 = NULL;
605
606       switch (type)
607         {
608         case INTERVALOID:
609         case TIMESTAMPTZOID:
610         case TIMETZOID:
611           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_var_cnt (r->dict))
612             {
613               const struct variable *v1 = NULL;
614               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
615
616               val1 = case_data_rw (c, v1);
617             }
618           break;
619         default:
620           break;
621         }
622
623
624       if (PQgetisnull (r->res, r->tuple, i))
625         {
626           value_set_missing (val, var_get_width (v));
627
628           switch (type)
629             {
630             case INTERVALOID:
631             case TIMESTAMPTZOID:
632             case TIMETZOID:
633               val1->f = SYSMIS;
634               break;
635             default:
636               break;
637             }
638         }
639       else
640         {
641           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
642           int length = PQgetlength (r->res, r->tuple, i);
643
644           int var_width = var_get_width (v);
645           switch (type)
646             {
647             case BOOLOID:
648               {
649                 int8_t x;
650                 GET_VALUE (&vptr, x);
651                 val->f = x;
652               }
653               break;
654
655             case OIDOID:
656             case INT2OID:
657               {
658                 int16_t x;
659                 GET_VALUE (&vptr, x);
660                 val->f = x;
661               }
662               break;
663
664             case INT4OID:
665               {
666                 int32_t x;
667                 GET_VALUE (&vptr, x);
668                 val->f = x;
669               }
670               break;
671
672             case INT8OID:
673               {
674                 int64_t x;
675                 GET_VALUE (&vptr, x);
676                 val->f = x;
677               }
678               break;
679
680             case FLOAT4OID:
681               {
682                 float n;
683                 GET_VALUE (&vptr, n);
684                 val->f = n;
685               }
686               break;
687
688             case FLOAT8OID:
689               {
690                 double n;
691                 GET_VALUE (&vptr, n);
692                 val->f = n;
693               }
694               break;
695
696             case CASHOID:
697               {
698                 /* Postgres 8.3 uses 64 bits.
699                    Earlier versions use 32 */
700                 switch (length)
701                   {
702                   case 8:
703                     {
704                       int64_t x;
705                       GET_VALUE (&vptr, x);
706                       val->f = x / 100.0;
707                     }
708                     break;
709                   case 4:
710                     {
711                       int32_t x;
712                       GET_VALUE (&vptr, x);
713                       val->f = x / 100.0;
714                     }
715                     break;
716                   default:
717                     val->f = SYSMIS;
718                     break;
719                   }
720               }
721               break;
722
723             case INTERVALOID:
724               {
725                 if ( r->integer_datetimes )
726                   {
727                     uint32_t months;
728                     uint32_t days;
729                     uint32_t us;
730                     uint32_t things;
731
732                     GET_VALUE (&vptr, things);
733                     GET_VALUE (&vptr, us);
734                     GET_VALUE (&vptr, days);
735                     GET_VALUE (&vptr, months);
736
737                     val->f = us / 1000000.0;
738                     val->f += days * 24 * 3600;
739
740                     val1->f = months;
741                   }
742                 else
743                   {
744                     uint32_t days, months;
745                     double seconds;
746
747                     GET_VALUE (&vptr, seconds);
748                     GET_VALUE (&vptr, days);
749                     GET_VALUE (&vptr, months);
750
751                     val->f = seconds;
752                     val->f += days * 24 * 3600;
753
754                     val1->f = months;
755                   }
756               }
757               break;
758
759             case DATEOID:
760               {
761                 int32_t x;
762
763                 GET_VALUE (&vptr, x);
764
765                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
766               }
767               break;
768
769             case TIMEOID:
770               {
771                 if ( r->integer_datetimes)
772                   {
773                     uint64_t x;
774                     GET_VALUE (&vptr, x);
775                     val->f = x / 1000000.0;
776                   }
777                 else
778                   {
779                     double x;
780                     GET_VALUE (&vptr, x);
781                     val->f = x;
782                   }
783               }
784               break;
785
786             case TIMETZOID:
787               {
788                 int32_t zone;
789                 if ( r->integer_datetimes)
790                   {
791                     uint64_t x;
792
793
794                     GET_VALUE (&vptr, x);
795                     val->f = x / 1000000.0;
796                   }
797                 else
798                   {
799                     double x;
800
801                     GET_VALUE (&vptr, x);
802                     val->f = x ;
803                   }
804
805                 GET_VALUE (&vptr, zone);
806                 val1->f = zone / 3600.0;
807               }
808               break;
809
810             case TIMESTAMPOID:
811             case TIMESTAMPTZOID:
812               {
813                 if ( r->integer_datetimes)
814                   {
815                     int64_t x;
816
817                     GET_VALUE (&vptr, x);
818
819                     x /= 1000000;
820
821                     val->f = (x + r->postgres_epoch * 24 * 3600 );
822                   }
823                 else
824                   {
825                     double x;
826
827                     GET_VALUE (&vptr, x);
828
829                     val->f = (x + r->postgres_epoch * 24 * 3600 );
830                   }
831               }
832               break;
833             case TEXTOID:
834             case VARCHAROID:
835             case BPCHAROID:
836             case BYTEAOID:
837               memcpy (value_str_rw (val, var_width), vptr,
838                       MIN (length, var_width));
839               break;
840
841             case NUMERICOID:
842               {
843                 double f = 0.0;
844                 int i;
845                 int16_t n_digits, weight, dscale;
846                 uint16_t sign;
847
848                 GET_VALUE (&vptr, n_digits);
849                 GET_VALUE (&vptr, weight);
850                 GET_VALUE (&vptr, sign);
851                 GET_VALUE (&vptr, dscale);
852
853 #if 0
854                 {
855                   struct fmt_spec fmt;
856                   fmt.d = dscale;
857                   fmt.type = FMT_E;
858                   fmt.w = fmt_max_output_width (fmt.type) ;
859                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
860                   var_set_both_formats (v, &fmt);
861                 }
862 #endif
863
864                 for (i = 0 ; i < n_digits;  ++i)
865                   {
866                     uint16_t x;
867                     GET_VALUE (&vptr, x);
868                     f += x * pow (10000, weight--);
869                   }
870
871                 if ( sign == 0x4000)
872                   f *= -1.0;
873
874                 if ( sign == 0xC000)
875                   val->f = SYSMIS;
876                 else
877                   val->f = f;
878               }
879               break;
880
881             default:
882               val->f = SYSMIS;
883               break;
884             }
885         }
886     }
887
888   r->tuple++;
889
890   return c;
891 }
892
893 #endif