Change "union value" to dynamically allocate long strings.
[pspp-builds.git] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/casereader-provider.h>
20 #include <libpspp/message.h>
21 #include <gl/xalloc.h>
22 #include <data/dictionary.h>
23 #include <math.h>
24 #include <stdlib.h>
25
26 #include "psql-reader.h"
27 #include "variable.h"
28 #include "format.h"
29 #include "calendar.h"
30
31 #include <inttypes.h>
32 #include <libpspp/misc.h>
33 #include <libpspp/str.h>
34
35 #include "minmax.h"
36
37 #include "gettext.h"
38 #define _(msgid) gettext (msgid)
39 #define N_(msgid) (msgid)
40
41
42 #if !PSQL_SUPPORT
43 struct casereader *
44 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
45 {
46   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
47
48   return NULL;
49 }
50
51 #else
52
53 #include <stdint.h>
54 #include <libpq-fe.h>
55
56
57 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
58 #define BOOLOID            16
59 #define BYTEAOID           17
60 #define CHAROID            18
61 #define NAMEOID            19
62 #define INT8OID            20
63 #define INT2OID            21
64 #define INT4OID            23
65 #define TEXTOID            25
66 #define OIDOID             26
67 #define FLOAT4OID          700
68 #define FLOAT8OID          701
69 #define CASHOID            790
70 #define BPCHAROID          1042
71 #define VARCHAROID         1043
72 #define DATEOID            1082
73 #define TIMEOID            1083
74 #define TIMESTAMPOID       1114
75 #define TIMESTAMPTZOID     1184
76 #define INTERVALOID        1186
77 #define TIMETZOID          1266
78 #define NUMERICOID         1700
79
80 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
81
82 static struct ccase *psql_casereader_read (struct casereader *, void *);
83
84 static const struct casereader_class psql_casereader_class =
85   {
86     psql_casereader_read,
87     psql_casereader_destroy,
88     NULL,
89     NULL,
90   };
91
92 struct psql_reader
93 {
94   PGconn *conn;
95   PGresult *res;
96   int tuple;
97
98   bool integer_datetimes;
99
100   double postgres_epoch;
101
102   struct caseproto *proto;
103   struct dictionary *dict;
104
105   /* An array of ints, which maps psql column numbers into
106      pspp variables */
107   struct variable **vmap;
108   size_t vmapsize;
109
110   struct string fetch_cmd;
111   int cache_size;
112 };
113
114
115 static struct ccase *set_value (struct psql_reader *r);
116
117
118
119 #if WORDS_BIGENDIAN
120 static void
121 data_to_native (const void *in_, void *out_, int len)
122 {
123   int i;
124   const unsigned char *in = in_;
125   unsigned char *out = out_;
126   for (i = 0 ; i < len ; ++i )
127     out[i] = in[i];
128 }
129 #else
130 static void
131 data_to_native (const void *in_, void *out_, int len)
132 {
133   int i;
134   const unsigned char *in = in_;
135   unsigned char *out = out_;
136   for (i = 0 ; i < len ; ++i )
137     out[len - i - 1] = in[i];
138 }
139 #endif
140
141
142 #define GET_VALUE(IN, OUT) do { \
143     size_t sz = sizeof (OUT); \
144     data_to_native (*(IN), &(OUT), sz) ; \
145     (*IN) += sz; \
146 } while (false)
147
148
149 #if 0
150 static void
151 dump (const unsigned char *x, int l)
152 {
153   int i;
154
155   for (i = 0; i < l ; ++i)
156     {
157       printf ("%02x ", x[i]);
158     }
159
160   putchar ('\n');
161
162   for (i = 0; i < l ; ++i)
163     {
164       if ( isprint (x[i]))
165         printf ("%c ", x[i]);
166       else
167         printf ("   ");
168     }
169
170   putchar ('\n');
171 }
172 #endif
173
174 static struct variable *
175 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
176             int width, const char *suggested_name, int col)
177 {
178   unsigned long int vx = 0;
179   struct variable *var;
180   char name[VAR_NAME_LEN + 1];
181
182   if ( ! dict_make_unique_var_name (r->dict, suggested_name, &vx, name))
183     {
184       msg (ME, _("Cannot create variable name from %s"), suggested_name);
185       return NULL;
186     }
187
188   var = dict_create_var (r->dict, name, width);
189   var_set_both_formats (var, fmt);
190
191   if ( col != -1)
192     {
193       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
194
195       r->vmap[col] = var;
196       r->vmapsize = col + 1;
197     }
198
199   return var;
200 }
201
202
203
204
205 /* Fill the cache */
206 static bool
207 reload_cache (struct psql_reader *r)
208 {
209   PQclear (r->res);
210   r->tuple = 0;
211
212   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
213
214   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
215     {
216       PQclear (r->res);
217       r->res = NULL;
218       return false;
219     }
220
221   return true;
222 }
223
224
225 struct casereader *
226 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
227 {
228   int i;
229   int n_fields, n_tuples;
230   PGresult *qres = NULL;
231   casenumber n_cases = CASENUMBER_MAX;
232
233   struct psql_reader *r = xzalloc (sizeof *r);
234   struct string query ;
235
236   r->conn = PQconnectdb (info->conninfo);
237   if ( NULL == r->conn)
238     {
239       msg (ME, _("Memory error whilst opening psql source"));
240       goto error;
241     }
242
243   if ( PQstatus (r->conn) != CONNECTION_OK )
244     {
245       msg (ME, _("Error opening psql source: %s."),
246            PQerrorMessage (r->conn));
247
248       goto error;
249     }
250
251   {
252     int ver_num;
253     const char *vers = PQparameterStatus (r->conn, "server_version");
254
255     sscanf (vers, "%d", &ver_num);
256
257     if ( ver_num < 8)
258       {
259         msg (ME,
260              _("Postgres server is version %s."
261                " Reading from versions earlier than 8.0 is not supported."),
262              vers);
263
264         goto error;
265       }
266   }
267
268   {
269     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
270
271     r->integer_datetimes = ( 0 == strcasecmp (dt, "on"));
272   }
273
274 #if USE_SSL
275   if ( PQgetssl (r->conn) == NULL)
276 #endif
277     {
278       if (! info->allow_clear)
279         {
280           msg (ME, _("Connection is unencrypted, "
281                      "but unencrypted connections have not been permitted."));
282           goto error;
283         }
284     }
285
286   r->postgres_epoch =
287     calendar_gregorian_to_offset (2000, 1, 1, NULL, NULL);
288
289
290   /* Create the dictionary and populate it */
291   *dict = r->dict = dict_create ();
292
293   {
294     const int enc = PQclientEncoding (r->conn);
295
296     /* According to section 22.2 of the Postgresql manual
297        a value of zero (SQL_ASCII) indicates
298        "a declaration of ignorance about the encoding".
299        Accordingly, we don't set the dictionary's encoding
300        if we find this value.
301     */
302     if ( enc != 0 )
303       dict_set_encoding (r->dict, pg_encoding_to_char (enc));
304   }
305
306   /*
307     select count (*) from (select * from medium) stupid_sql_standard;
308   */
309   ds_init_cstr (&query,
310                 "BEGIN READ ONLY ISOLATION LEVEL SERIALIZABLE; "
311                 "DECLARE  pspp BINARY CURSOR FOR ");
312
313   ds_put_substring (&query, info->sql.ss);
314
315   qres = PQexec (r->conn, ds_cstr (&query));
316   ds_destroy (&query);
317   if ( PQresultStatus (qres) != PGRES_COMMAND_OK )
318     {
319       msg (ME, _("Error from psql source: %s."),
320            PQresultErrorMessage (qres));
321       goto error;
322     }
323
324   PQclear (qres);
325
326
327   /* Now use the count() function to find the total number of cases
328      that this query returns.
329      Doing this incurs some overhead.  The server has to iterate every
330      case in order to find this number.  However, it's performed on the
331      server side, and in all except the most huge databases the extra
332      overhead will be worth the effort.
333      On the other hand, most PSPP functions don't need to know this.
334      The GUI is the notable exception.
335   */
336   ds_init_cstr (&query, "SELECT count (*) FROM (");
337   ds_put_substring (&query, info->sql.ss);
338   ds_put_cstr (&query, ") stupid_sql_standard");
339
340   qres = PQexec (r->conn, ds_cstr (&query));
341   ds_destroy (&query);
342   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
343     {
344       msg (ME, _("Error from psql source: %s."),
345            PQresultErrorMessage (qres));
346       goto error;
347     }
348   n_cases = atol (PQgetvalue (qres, 0, 0));
349   PQclear (qres);
350
351   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
352   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
353     {
354       msg (ME, _("Error from psql source: %s."),
355            PQresultErrorMessage (qres));
356       goto error;
357     }
358
359   n_tuples = PQntuples (qres);
360   n_fields = PQnfields (qres);
361
362   r->proto = NULL;
363   r->vmap = NULL;
364   r->vmapsize = 0;
365
366   for (i = 0 ; i < n_fields ; ++i )
367     {
368       struct variable *var;
369       struct fmt_spec fmt = {FMT_F, 8, 2};
370       Oid type = PQftype (qres, i);
371       int width = 0;
372       int length ;
373
374       /* If there are no data then make a finger in the air 
375          guess at the contents */
376       if ( n_tuples > 0 )
377         length = PQgetlength (qres, 0, i);
378       else 
379         length = MAX_SHORT_STRING;
380
381       switch (type)
382         {
383         case BOOLOID:
384         case OIDOID:
385         case INT2OID:
386         case INT4OID:
387         case INT8OID:
388         case FLOAT4OID:
389         case FLOAT8OID:
390           fmt.type = FMT_F;
391           break;
392         case CASHOID:
393           fmt.type = FMT_DOLLAR;
394           break;
395         case CHAROID:
396           fmt.type = FMT_A;
397           width = length > 0 ? length : 1;
398           fmt.d = 0;
399           fmt.w = 1;
400           break;
401         case TEXTOID:
402         case VARCHAROID:
403         case BPCHAROID:
404           fmt.type = FMT_A;
405           width = (info->str_width == -1) ?
406             ROUND_UP (length, MAX_SHORT_STRING) : info->str_width;
407           fmt.w = width;
408           fmt.d = 0;
409           break;
410         case BYTEAOID:
411           fmt.type = FMT_AHEX;
412           width = length > 0 ? length : MAX_SHORT_STRING;
413           fmt.w = width * 2;
414           fmt.d = 0;
415           break;
416         case INTERVALOID:
417           fmt.type = FMT_DTIME;
418           width = 0;
419           fmt.d = 0;
420           fmt.w = 13;
421           break;
422         case DATEOID:
423           fmt.type = FMT_DATE;
424           width = 0;
425           fmt.w = 11;
426           fmt.d = 0;
427           break;
428         case TIMEOID:
429         case TIMETZOID:
430           fmt.type = FMT_TIME;
431           width = 0;
432           fmt.w = 11;
433           fmt.d = 0;
434           break;
435         case TIMESTAMPOID:
436         case TIMESTAMPTZOID:
437           fmt.type = FMT_DATETIME;
438           fmt.d = 0;
439           fmt.w = 22;
440           width = 0;
441           break;
442         case NUMERICOID:
443           fmt.type = FMT_E;
444           fmt.d = 2;
445           fmt.w = 40;
446           width = 0;
447           break;
448         default:
449           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
450           fmt.type = FMT_A;
451           width = length > 0 ? length : MAX_SHORT_STRING;
452           fmt.w = width ;
453           fmt.d = 0;
454           break;
455         }
456
457       if ( width == 0 && fmt_is_string (fmt.type))
458         fmt.w = width = MAX_SHORT_STRING;
459
460
461       var = create_var (r, &fmt, width, PQfname (qres, i), i);
462       if ( type == NUMERICOID && n_tuples > 0)
463         {
464           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
465           struct fmt_spec fmt;
466           int16_t n_digits, weight, dscale;
467           uint16_t sign;
468
469           GET_VALUE (&vptr, n_digits);
470           GET_VALUE (&vptr, weight);
471           GET_VALUE (&vptr, sign);
472           GET_VALUE (&vptr, dscale);
473
474           fmt.d = dscale;
475           fmt.type = FMT_E;
476           fmt.w = fmt_max_output_width (fmt.type) ;
477           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
478           var_set_both_formats (var, &fmt);
479         }
480
481       /* Timezones need an extra variable */
482       switch (type)
483         {
484         case TIMETZOID:
485           {
486             struct string name;
487             ds_init_cstr (&name, var_get_name (var));
488             ds_put_cstr (&name, "-zone");
489             fmt.type = FMT_F;
490             fmt.w = 8;
491             fmt.d = 2;
492
493             create_var (r, &fmt, 0, ds_cstr (&name), -1);
494
495             ds_destroy (&name);
496           }
497           break;
498
499         case INTERVALOID:
500           {
501             struct string name;
502             ds_init_cstr (&name, var_get_name (var));
503             ds_put_cstr (&name, "-months");
504             fmt.type = FMT_F;
505             fmt.w = 3;
506             fmt.d = 0;
507
508             create_var (r, &fmt, 0, ds_cstr (&name), -1);
509
510             ds_destroy (&name);
511           }
512         default:
513           break;
514         }
515     }
516
517   PQclear (qres);
518
519   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
520   if ( PQresultStatus (qres) != PGRES_COMMAND_OK)
521     {
522       PQclear (qres);
523       goto error;
524     }
525   PQclear (qres);
526
527   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
528
529   ds_init_empty (&r->fetch_cmd);
530   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
531
532   reload_cache (r);
533   r->proto = caseproto_ref (dict_get_proto (*dict));
534
535   return casereader_create_sequential
536     (NULL,
537      r->proto,
538      n_cases,
539      &psql_casereader_class, r);
540
541  error:
542   dict_destroy (*dict);
543
544   psql_casereader_destroy (NULL, r);
545   return NULL;
546 }
547
548
549 static void
550 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
551 {
552   struct psql_reader *r = r_;
553   if (r == NULL)
554     return ;
555
556   ds_destroy (&r->fetch_cmd);
557   free (r->vmap);
558   if (r->res) PQclear (r->res);
559   PQfinish (r->conn);
560   caseproto_unref (r->proto);
561
562   free (r);
563 }
564
565
566
567 static struct ccase *
568 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
569 {
570   struct psql_reader *r = r_;
571
572   if ( NULL == r->res || r->tuple >= r->cache_size)
573     {
574       if ( ! reload_cache (r) )
575         return false;
576     }
577
578   return set_value (r);
579 }
580
581 static struct ccase *
582 set_value (struct psql_reader *r)
583 {
584   struct ccase *c;
585   int n_vars;
586   int i;
587
588   assert (r->res);
589
590   n_vars = PQnfields (r->res);
591
592   if ( r->tuple >= PQntuples (r->res))
593     return NULL;
594
595   c = case_create (r->proto);
596   case_set_missing (c);
597
598
599   for (i = 0 ; i < n_vars ; ++i )
600     {
601       Oid type = PQftype (r->res, i);
602       const struct variable *v = r->vmap[i];
603       union value *val = case_data_rw (c, v);
604
605       union value *val1 = NULL;
606
607       switch (type)
608         {
609         case INTERVALOID:
610         case TIMESTAMPTZOID:
611         case TIMETZOID:
612           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_var_cnt (r->dict))
613             {
614               const struct variable *v1 = NULL;
615               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
616
617               val1 = case_data_rw (c, v1);
618             }
619           break;
620         default:
621           break;
622         }
623
624
625       if (PQgetisnull (r->res, r->tuple, i))
626         {
627           value_set_missing (val, var_get_width (v));
628
629           switch (type)
630             {
631             case INTERVALOID:
632             case TIMESTAMPTZOID:
633             case TIMETZOID:
634               val1->f = SYSMIS;
635               break;
636             default:
637               break;
638             }
639         }
640       else
641         {
642           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
643           int length = PQgetlength (r->res, r->tuple, i);
644
645           int var_width = var_get_width (v);
646           switch (type)
647             {
648             case BOOLOID:
649               {
650                 int8_t x;
651                 GET_VALUE (&vptr, x);
652                 val->f = x;
653               }
654               break;
655
656             case OIDOID:
657             case INT2OID:
658               {
659                 int16_t x;
660                 GET_VALUE (&vptr, x);
661                 val->f = x;
662               }
663               break;
664
665             case INT4OID:
666               {
667                 int32_t x;
668                 GET_VALUE (&vptr, x);
669                 val->f = x;
670               }
671               break;
672
673             case INT8OID:
674               {
675                 int64_t x;
676                 GET_VALUE (&vptr, x);
677                 val->f = x;
678               }
679               break;
680
681             case FLOAT4OID:
682               {
683                 float n;
684                 GET_VALUE (&vptr, n);
685                 val->f = n;
686               }
687               break;
688
689             case FLOAT8OID:
690               {
691                 double n;
692                 GET_VALUE (&vptr, n);
693                 val->f = n;
694               }
695               break;
696
697             case CASHOID:
698               {
699                 /* Postgres 8.3 uses 64 bits.
700                    Earlier versions use 32 */
701                 switch (length)
702                   {
703                   case 8:
704                     {
705                       int64_t x;
706                       GET_VALUE (&vptr, x);
707                       val->f = x / 100.0;
708                     }
709                     break;
710                   case 4:
711                     {
712                       int32_t x;
713                       GET_VALUE (&vptr, x);
714                       val->f = x / 100.0;
715                     }
716                     break;
717                   default:
718                     val->f = SYSMIS;
719                     break;
720                   }
721               }
722               break;
723
724             case INTERVALOID:
725               {
726                 if ( r->integer_datetimes )
727                   {
728                     uint32_t months;
729                     uint32_t days;
730                     uint32_t us;
731                     uint32_t things;
732
733                     GET_VALUE (&vptr, things);
734                     GET_VALUE (&vptr, us);
735                     GET_VALUE (&vptr, days);
736                     GET_VALUE (&vptr, months);
737
738                     val->f = us / 1000000.0;
739                     val->f += days * 24 * 3600;
740
741                     val1->f = months;
742                   }
743                 else
744                   {
745                     uint32_t days, months;
746                     double seconds;
747
748                     GET_VALUE (&vptr, seconds);
749                     GET_VALUE (&vptr, days);
750                     GET_VALUE (&vptr, months);
751
752                     val->f = seconds;
753                     val->f += days * 24 * 3600;
754
755                     val1->f = months;
756                   }
757               }
758               break;
759
760             case DATEOID:
761               {
762                 int32_t x;
763
764                 GET_VALUE (&vptr, x);
765
766                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
767               }
768               break;
769
770             case TIMEOID:
771               {
772                 if ( r->integer_datetimes)
773                   {
774                     uint64_t x;
775                     GET_VALUE (&vptr, x);
776                     val->f = x / 1000000.0;
777                   }
778                 else
779                   {
780                     double x;
781                     GET_VALUE (&vptr, x);
782                     val->f = x;
783                   }
784               }
785               break;
786
787             case TIMETZOID:
788               {
789                 int32_t zone;
790                 if ( r->integer_datetimes)
791                   {
792                     uint64_t x;
793
794
795                     GET_VALUE (&vptr, x);
796                     val->f = x / 1000000.0;
797                   }
798                 else
799                   {
800                     double x;
801
802                     GET_VALUE (&vptr, x);
803                     val->f = x ;
804                   }
805
806                 GET_VALUE (&vptr, zone);
807                 val1->f = zone / 3600.0;
808               }
809               break;
810
811             case TIMESTAMPOID:
812             case TIMESTAMPTZOID:
813               {
814                 if ( r->integer_datetimes)
815                   {
816                     int64_t x;
817
818                     GET_VALUE (&vptr, x);
819
820                     x /= 1000000;
821
822                     val->f = (x + r->postgres_epoch * 24 * 3600 );
823                   }
824                 else
825                   {
826                     double x;
827
828                     GET_VALUE (&vptr, x);
829
830                     val->f = (x + r->postgres_epoch * 24 * 3600 );
831                   }
832               }
833               break;
834             case TEXTOID:
835             case VARCHAROID:
836             case BPCHAROID:
837             case BYTEAOID:
838               memcpy (value_str_rw (val, var_width), (char *) vptr,
839                       MIN (length, var_width));
840               break;
841
842             case NUMERICOID:
843               {
844                 double f = 0.0;
845                 int i;
846                 int16_t n_digits, weight, dscale;
847                 uint16_t sign;
848
849                 GET_VALUE (&vptr, n_digits);
850                 GET_VALUE (&vptr, weight);
851                 GET_VALUE (&vptr, sign);
852                 GET_VALUE (&vptr, dscale);
853
854 #if 0
855                 {
856                   struct fmt_spec fmt;
857                   fmt.d = dscale;
858                   fmt.type = FMT_E;
859                   fmt.w = fmt_max_output_width (fmt.type) ;
860                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
861                   var_set_both_formats (v, &fmt);
862                 }
863 #endif
864
865                 for (i = 0 ; i < n_digits;  ++i)
866                   {
867                     uint16_t x;
868                     GET_VALUE (&vptr, x);
869                     f += x * pow (10000, weight--);
870                   }
871
872                 if ( sign == 0x4000)
873                   f *= -1.0;
874
875                 if ( sign == 0xC000)
876                   val->f = SYSMIS;
877                 else
878                   val->f = f;
879               }
880               break;
881
882             default:
883               val->f = SYSMIS;
884               break;
885             }
886         }
887     }
888
889   r->tuple++;
890
891   return c;
892 }
893
894 #endif