Make cases simpler, faster, and easier to understand.
[pspp-builds.git] / src / data / psql-reader.c
1 /* PSPP - a program for statistical analysis.
2    Copyright (C) 2008, 2009 Free Software Foundation, Inc.
3
4    This program is free software: you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation, either version 3 of the License, or
7    (at your option) any later version.
8
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13
14    You should have received a copy of the GNU General Public License
15    along with this program.  If not, see <http://www.gnu.org/licenses/>. */
16
17 #include <config.h>
18
19 #include <data/casereader-provider.h>
20 #include <libpspp/message.h>
21 #include <gl/xalloc.h>
22 #include <data/dictionary.h>
23 #include <stdlib.h>
24
25 #include "psql-reader.h"
26 #include "variable.h"
27 #include "format.h"
28 #include "calendar.h"
29
30 #include <inttypes.h>
31 #include <libpspp/str.h>
32
33 #include "gettext.h"
34 #define _(msgid) gettext (msgid)
35 #define N_(msgid) (msgid)
36
37
38 #if !PSQL_SUPPORT
39 struct casereader *
40 psql_open_reader (struct psql_read_info *info UNUSED, struct dictionary **dict UNUSED)
41 {
42   msg (ME, _("Support for reading postgres databases was not compiled into this installation of PSPP"));
43
44   return NULL;
45 }
46
47 #else
48
49 #include <stdint.h>
50 #include <libpq-fe.h>
51
52
53 /* These macros  must be the same as in catalog/pg_types.h from the postgres source */
54 #define BOOLOID            16
55 #define BYTEAOID           17
56 #define CHAROID            18
57 #define NAMEOID            19
58 #define INT8OID            20
59 #define INT2OID            21
60 #define INT4OID            23
61 #define TEXTOID            25
62 #define OIDOID             26
63 #define FLOAT4OID          700
64 #define FLOAT8OID          701
65 #define CASHOID            790
66 #define BPCHAROID          1042
67 #define VARCHAROID         1043
68 #define DATEOID            1082
69 #define TIMEOID            1083
70 #define TIMESTAMPOID       1114
71 #define TIMESTAMPTZOID     1184
72 #define INTERVALOID        1186
73 #define TIMETZOID          1266
74 #define NUMERICOID         1700
75
76 static void psql_casereader_destroy (struct casereader *reader UNUSED, void *r_);
77
78 static struct ccase *psql_casereader_read (struct casereader *, void *);
79
80 static const struct casereader_class psql_casereader_class =
81   {
82     psql_casereader_read,
83     psql_casereader_destroy,
84     NULL,
85     NULL,
86   };
87
88 struct psql_reader
89 {
90   PGconn *conn;
91   PGresult *res;
92   int tuple;
93
94   bool integer_datetimes;
95
96   double postgres_epoch;
97
98   size_t value_cnt;
99   struct dictionary *dict;
100
101   /* An array of ints, which maps psql column numbers into
102      pspp variables */
103   struct variable **vmap;
104   size_t vmapsize;
105
106   struct string fetch_cmd;
107   int cache_size;
108 };
109
110
111 static struct ccase *set_value (struct psql_reader *r);
112
113
114
115 #if WORDS_BIGENDIAN
116 static void
117 data_to_native (const void *in_, void *out_, int len)
118 {
119   int i;
120   const unsigned char *in = in_;
121   unsigned char *out = out_;
122   for (i = 0 ; i < len ; ++i )
123     out[i] = in[i];
124 }
125 #else
126 static void
127 data_to_native (const void *in_, void *out_, int len)
128 {
129   int i;
130   const unsigned char *in = in_;
131   unsigned char *out = out_;
132   for (i = 0 ; i < len ; ++i )
133     out[len - i - 1] = in[i];
134 }
135 #endif
136
137
138 #define GET_VALUE(IN, OUT) do { \
139     size_t sz = sizeof (OUT); \
140     data_to_native (*(IN), &(OUT), sz) ; \
141     (*IN) += sz; \
142 } while (false)
143
144
145 #if 0
146 static void
147 dump (const unsigned char *x, int l)
148 {
149   int i;
150
151   for (i = 0; i < l ; ++i)
152     {
153       printf ("%02x ", x[i]);
154     }
155
156   putchar ('\n');
157
158   for (i = 0; i < l ; ++i)
159     {
160       if ( isprint (x[i]))
161         printf ("%c ", x[i]);
162       else
163         printf ("   ");
164     }
165
166   putchar ('\n');
167 }
168 #endif
169
170 static struct variable *
171 create_var (struct psql_reader *r, const struct fmt_spec *fmt,
172             int width, const char *suggested_name, int col)
173 {
174   unsigned long int vx = 0;
175   struct variable *var;
176   char name[VAR_NAME_LEN + 1];
177
178   r->value_cnt += value_cnt_from_width (width);
179
180   if ( ! dict_make_unique_var_name (r->dict, suggested_name, &vx, name))
181     {
182       msg (ME, _("Cannot create variable name from %s"), suggested_name);
183       return NULL;
184     }
185
186   var = dict_create_var (r->dict, name, width);
187   var_set_both_formats (var, fmt);
188
189   if ( col != -1)
190     {
191       r->vmap = xrealloc (r->vmap, (col + 1) * sizeof (*r->vmap));
192
193       r->vmap[col] = var;
194       r->vmapsize = col + 1;
195     }
196
197   return var;
198 }
199
200
201
202
203 /* Fill the cache */
204 static bool
205 reload_cache (struct psql_reader *r)
206 {
207   PQclear (r->res);
208   r->tuple = 0;
209
210   r->res = PQexec (r->conn, ds_cstr (&r->fetch_cmd));
211
212   if (PQresultStatus (r->res) != PGRES_TUPLES_OK || PQntuples (r->res) < 1)
213     {
214       PQclear (r->res);
215       r->res = NULL;
216       return false;
217     }
218
219   return true;
220 }
221
222
223 struct casereader *
224 psql_open_reader (struct psql_read_info *info, struct dictionary **dict)
225 {
226   int i;
227   int n_fields, n_tuples;
228   PGresult *qres = NULL;
229   casenumber n_cases = CASENUMBER_MAX;
230
231   struct psql_reader *r = xzalloc (sizeof *r);
232   struct string query ;
233
234   r->conn = PQconnectdb (info->conninfo);
235   if ( NULL == r->conn)
236     {
237       msg (ME, _("Memory error whilst opening psql source"));
238       goto error;
239     }
240
241   if ( PQstatus (r->conn) != CONNECTION_OK )
242     {
243       msg (ME, _("Error opening psql source: %s."),
244            PQerrorMessage (r->conn));
245
246       goto error;
247     }
248
249   {
250     int ver_num;
251     const char *vers = PQparameterStatus (r->conn, "server_version");
252
253     sscanf (vers, "%d", &ver_num);
254
255     if ( ver_num < 8)
256       {
257         msg (ME,
258              _("Postgres server is version %s."
259                " Reading from versions earlier than 8.0 is not supported."),
260              vers);
261
262         goto error;
263       }
264   }
265
266   {
267     const char *dt =  PQparameterStatus (r->conn, "integer_datetimes");
268
269     r->integer_datetimes = ( 0 == strcasecmp (dt, "on"));
270   }
271
272 #if USE_SSL
273   if ( PQgetssl (r->conn) == NULL)
274 #endif
275     {
276       if (! info->allow_clear)
277         {
278           msg (ME, _("Connection is unencrypted, "
279                      "but unencrypted connections have not been permitted."));
280           goto error;
281         }
282     }
283
284   r->postgres_epoch =
285     calendar_gregorian_to_offset (2000, 1, 1, NULL, NULL);
286
287
288   /* Create the dictionary and populate it */
289   *dict = r->dict = dict_create ();
290
291   /*
292     select count (*) from (select * from medium) stupid_sql_standard;
293   */
294
295   ds_init_cstr (&query,
296                 "BEGIN READ ONLY ISOLATION LEVEL SERIALIZABLE; "
297                 "DECLARE  pspp BINARY CURSOR FOR ");
298
299   ds_put_substring (&query, info->sql.ss);
300
301   qres = PQexec (r->conn, ds_cstr (&query));
302   ds_destroy (&query);
303   if ( PQresultStatus (qres) != PGRES_COMMAND_OK )
304     {
305       msg (ME, _("Error from psql source: %s."),
306            PQresultErrorMessage (qres));
307       goto error;
308     }
309
310   PQclear (qres);
311
312
313   /* Now use the count() function to find the total number of cases
314      that this query returns.
315      Doing this incurs some overhead.  The server has to iterate every
316      case in order to find this number.  However, it's performed on the
317      server side, and in all except the most huge databases the extra
318      overhead will be worth the effort.
319      On the other hand, most PSPP functions don't need to know this.
320      The GUI is the notable exception.
321   */
322   ds_init_cstr (&query, "SELECT count (*) FROM (");
323   ds_put_substring (&query, info->sql.ss);
324   ds_put_cstr (&query, ") stupid_sql_standard");
325
326   qres = PQexec (r->conn, ds_cstr (&query));
327   ds_destroy (&query);
328   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
329     {
330       msg (ME, _("Error from psql source: %s."),
331            PQresultErrorMessage (qres));
332       goto error;
333     }
334   n_cases = atol (PQgetvalue (qres, 0, 0));
335   PQclear (qres);
336
337   qres = PQexec (r->conn, "FETCH FIRST FROM pspp");
338   if ( PQresultStatus (qres) != PGRES_TUPLES_OK )
339     {
340       msg (ME, _("Error from psql source: %s."),
341            PQresultErrorMessage (qres));
342       goto error;
343     }
344
345   n_tuples = PQntuples (qres);
346   n_fields = PQnfields (qres);
347
348   r->value_cnt = 0;
349   r->vmap = NULL;
350   r->vmapsize = 0;
351
352   for (i = 0 ; i < n_fields ; ++i )
353     {
354       struct variable *var;
355       struct fmt_spec fmt = {FMT_F, 8, 2};
356       Oid type = PQftype (qres, i);
357       int width = 0;
358       int length ;
359
360       /* If there are no data then make a finger in the air 
361          guess at the contents */
362       if ( n_tuples > 0 )
363         length = PQgetlength (qres, 0, i);
364       else 
365         length = MAX_SHORT_STRING;
366
367       switch (type)
368         {
369         case BOOLOID:
370         case OIDOID:
371         case INT2OID:
372         case INT4OID:
373         case INT8OID:
374         case FLOAT4OID:
375         case FLOAT8OID:
376           fmt.type = FMT_F;
377           break;
378         case CASHOID:
379           fmt.type = FMT_DOLLAR;
380           break;
381         case CHAROID:
382           fmt.type = FMT_A;
383           width = length > 0 ? length : 1;
384           fmt.d = 0;
385           fmt.w = 1;
386           break;
387         case TEXTOID:
388         case VARCHAROID:
389         case BPCHAROID:
390           fmt.type = FMT_A;
391           width = (info->str_width == -1) ?
392             ROUND_UP (length, MAX_SHORT_STRING) : info->str_width;
393           fmt.w = width;
394           fmt.d = 0;
395           break;
396         case BYTEAOID:
397           fmt.type = FMT_AHEX;
398           width = length > 0 ? length : MAX_SHORT_STRING;
399           fmt.w = width * 2;
400           fmt.d = 0;
401           break;
402         case INTERVALOID:
403           fmt.type = FMT_DTIME;
404           width = 0;
405           fmt.d = 0;
406           fmt.w = 13;
407           break;
408         case DATEOID:
409           fmt.type = FMT_DATE;
410           width = 0;
411           fmt.w = 11;
412           fmt.d = 0;
413           break;
414         case TIMEOID:
415         case TIMETZOID:
416           fmt.type = FMT_TIME;
417           width = 0;
418           fmt.w = 11;
419           fmt.d = 0;
420           break;
421         case TIMESTAMPOID:
422         case TIMESTAMPTZOID:
423           fmt.type = FMT_DATETIME;
424           fmt.d = 0;
425           fmt.w = 22;
426           width = 0;
427           break;
428         case NUMERICOID:
429           fmt.type = FMT_E;
430           fmt.d = 2;
431           fmt.w = 40;
432           width = 0;
433           break;
434         default:
435           msg (MW, _("Unsupported OID %d.  SYSMIS values will be inserted."), type);
436           fmt.type = FMT_A;
437           width = length > 0 ? length : MAX_SHORT_STRING;
438           fmt.w = width ;
439           fmt.d = 0;
440           break;
441         }
442
443       var = create_var (r, &fmt, width, PQfname (qres, i), i);
444       if ( type == NUMERICOID && n_tuples > 0)
445         {
446           const uint8_t *vptr = (const uint8_t *) PQgetvalue (qres, 0, i);
447           struct fmt_spec fmt;
448           int16_t n_digits, weight, dscale;
449           uint16_t sign;
450
451           GET_VALUE (&vptr, n_digits);
452           GET_VALUE (&vptr, weight);
453           GET_VALUE (&vptr, sign);
454           GET_VALUE (&vptr, dscale);
455
456           fmt.d = dscale;
457           fmt.type = FMT_E;
458           fmt.w = fmt_max_output_width (fmt.type) ;
459           fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
460           var_set_both_formats (var, &fmt);
461         }
462
463       /* Timezones need an extra variable */
464       switch (type)
465         {
466         case TIMETZOID:
467           {
468             struct string name;
469             ds_init_cstr (&name, var_get_name (var));
470             ds_put_cstr (&name, "-zone");
471             fmt.type = FMT_F;
472             fmt.w = 8;
473             fmt.d = 2;
474
475             create_var (r, &fmt, 0, ds_cstr (&name), -1);
476
477             ds_destroy (&name);
478           }
479           break;
480
481         case INTERVALOID:
482           {
483             struct string name;
484             ds_init_cstr (&name, var_get_name (var));
485             ds_put_cstr (&name, "-months");
486             fmt.type = FMT_F;
487             fmt.w = 3;
488             fmt.d = 0;
489
490             create_var (r, &fmt, 0, ds_cstr (&name), -1);
491
492             ds_destroy (&name);
493           }
494         default:
495           break;
496         }
497     }
498
499   PQclear (qres);
500
501   qres = PQexec (r->conn, "MOVE BACKWARD 1 FROM pspp");
502   if ( PQresultStatus (qres) != PGRES_COMMAND_OK)
503     {
504       PQclear (qres);
505       goto error;
506     }
507   PQclear (qres);
508
509   r->cache_size = info->bsize != -1 ? info->bsize: 4096;
510
511   ds_init_empty (&r->fetch_cmd);
512   ds_put_format (&r->fetch_cmd,  "FETCH FORWARD %d FROM pspp", r->cache_size);
513
514   reload_cache (r);
515
516   return casereader_create_sequential
517     (NULL,
518      r->value_cnt,
519      n_cases,
520      &psql_casereader_class, r);
521
522  error:
523   dict_destroy (*dict);
524
525   psql_casereader_destroy (NULL, r);
526   return NULL;
527 }
528
529
530 static void
531 psql_casereader_destroy (struct casereader *reader UNUSED, void *r_)
532 {
533   struct psql_reader *r = r_;
534   if (r == NULL)
535     return ;
536
537   ds_destroy (&r->fetch_cmd);
538   free (r->vmap);
539   if (r->res) PQclear (r->res);
540   PQfinish (r->conn);
541
542   free (r);
543 }
544
545
546
547 static struct ccase *
548 psql_casereader_read (struct casereader *reader UNUSED, void *r_)
549 {
550   struct psql_reader *r = r_;
551
552   if ( NULL == r->res || r->tuple >= r->cache_size)
553     {
554       if ( ! reload_cache (r) )
555         return false;
556     }
557
558   return set_value (r);
559 }
560
561 static struct ccase *
562 set_value (struct psql_reader *r)
563 {
564   struct ccase *c;
565   int n_vars;
566   int i;
567
568   assert (r->res);
569
570   n_vars = PQnfields (r->res);
571
572   if ( r->tuple >= PQntuples (r->res))
573     return NULL;
574
575   c = case_create (r->value_cnt);
576   memset (case_data_rw_idx (c, 0)->s, ' ', MAX_SHORT_STRING * r->value_cnt);
577
578
579   for (i = 0 ; i < n_vars ; ++i )
580     {
581       Oid type = PQftype (r->res, i);
582       const struct variable *v = r->vmap[i];
583       union value *val = case_data_rw (c, v);
584
585       union value *val1 = NULL;
586
587       switch (type)
588         {
589         case INTERVALOID:
590         case TIMESTAMPTZOID:
591         case TIMETZOID:
592           if (i < r->vmapsize && var_get_dict_index(v) + 1 < dict_get_var_cnt (r->dict))
593             {
594               const struct variable *v1 = NULL;
595               v1 = dict_get_var (r->dict, var_get_dict_index (v) + 1);
596
597               val1 = case_data_rw (c, v1);
598             }
599           break;
600         default:
601           break;
602         }
603
604
605       if (PQgetisnull (r->res, r->tuple, i))
606         {
607           value_set_missing (val, var_get_width (v));
608
609           switch (type)
610             {
611             case INTERVALOID:
612             case TIMESTAMPTZOID:
613             case TIMETZOID:
614               val1->f = SYSMIS;
615               break;
616             default:
617               break;
618             }
619         }
620       else
621         {
622           const uint8_t *vptr = (const uint8_t *) PQgetvalue (r->res, r->tuple, i);
623           int length = PQgetlength (r->res, r->tuple, i);
624
625           int var_width = var_get_width (v);
626           switch (type)
627             {
628             case BOOLOID:
629               {
630                 int8_t x;
631                 GET_VALUE (&vptr, x);
632                 val->f = x;
633               }
634               break;
635
636             case OIDOID:
637             case INT2OID:
638               {
639                 int16_t x;
640                 GET_VALUE (&vptr, x);
641                 val->f = x;
642               }
643               break;
644
645             case INT4OID:
646               {
647                 int32_t x;
648                 GET_VALUE (&vptr, x);
649                 val->f = x;
650               }
651               break;
652
653             case INT8OID:
654               {
655                 int64_t x;
656                 GET_VALUE (&vptr, x);
657                 val->f = x;
658               }
659               break;
660
661             case FLOAT4OID:
662               {
663                 float n;
664                 GET_VALUE (&vptr, n);
665                 val->f = n;
666               }
667               break;
668
669             case FLOAT8OID:
670               {
671                 double n;
672                 GET_VALUE (&vptr, n);
673                 val->f = n;
674               }
675               break;
676
677             case CASHOID:
678               {
679                 /* Postgres 8.3 uses 64 bits.
680                    Earlier versions use 32 */
681                 switch (length)
682                   {
683                   case 8:
684                     {
685                       int64_t x;
686                       GET_VALUE (&vptr, x);
687                       val->f = x / 100.0;
688                     }
689                     break;
690                   case 4:
691                     {
692                       int32_t x;
693                       GET_VALUE (&vptr, x);
694                       val->f = x / 100.0;
695                     }
696                     break;
697                   default:
698                     val->f = SYSMIS;
699                     break;
700                   }
701               }
702               break;
703
704             case INTERVALOID:
705               {
706                 if ( r->integer_datetimes )
707                   {
708                     uint32_t months;
709                     uint32_t days;
710                     uint32_t us;
711                     uint32_t things;
712
713                     GET_VALUE (&vptr, things);
714                     GET_VALUE (&vptr, us);
715                     GET_VALUE (&vptr, days);
716                     GET_VALUE (&vptr, months);
717
718                     val->f = us / 1000000.0;
719                     val->f += days * 24 * 3600;
720
721                     val1->f = months;
722                   }
723                 else
724                   {
725                     uint32_t days, months;
726                     double seconds;
727
728                     GET_VALUE (&vptr, seconds);
729                     GET_VALUE (&vptr, days);
730                     GET_VALUE (&vptr, months);
731
732                     val->f = seconds;
733                     val->f += days * 24 * 3600;
734
735                     val1->f = months;
736                   }
737               }
738               break;
739
740             case DATEOID:
741               {
742                 int32_t x;
743
744                 GET_VALUE (&vptr, x);
745
746                 val->f = (x + r->postgres_epoch) * 24 * 3600 ;
747               }
748               break;
749
750             case TIMEOID:
751               {
752                 if ( r->integer_datetimes)
753                   {
754                     uint64_t x;
755                     GET_VALUE (&vptr, x);
756                     val->f = x / 1000000.0;
757                   }
758                 else
759                   {
760                     double x;
761                     GET_VALUE (&vptr, x);
762                     val->f = x;
763                   }
764               }
765               break;
766
767             case TIMETZOID:
768               {
769                 int32_t zone;
770                 if ( r->integer_datetimes)
771                   {
772                     uint64_t x;
773
774
775                     GET_VALUE (&vptr, x);
776                     val->f = x / 1000000.0;
777                   }
778                 else
779                   {
780                     double x;
781
782                     GET_VALUE (&vptr, x);
783                     val->f = x ;
784                   }
785
786                 GET_VALUE (&vptr, zone);
787                 val1->f = zone / 3600.0;
788               }
789               break;
790
791             case TIMESTAMPOID:
792             case TIMESTAMPTZOID:
793               {
794                 if ( r->integer_datetimes)
795                   {
796                     int64_t x;
797
798                     GET_VALUE (&vptr, x);
799
800                     x /= 1000000;
801
802                     val->f = (x + r->postgres_epoch * 24 * 3600 );
803                   }
804                 else
805                   {
806                     double x;
807
808                     GET_VALUE (&vptr, x);
809
810                     val->f = (x + r->postgres_epoch * 24 * 3600 );
811                   }
812               }
813               break;
814             case TEXTOID:
815             case VARCHAROID:
816             case BPCHAROID:
817             case BYTEAOID:
818               memcpy (val->s, (char *) vptr, MIN (length, var_width));
819               break;
820
821             case NUMERICOID:
822               {
823                 double f = 0.0;
824                 int i;
825                 int16_t n_digits, weight, dscale;
826                 uint16_t sign;
827
828                 GET_VALUE (&vptr, n_digits);
829                 GET_VALUE (&vptr, weight);
830                 GET_VALUE (&vptr, sign);
831                 GET_VALUE (&vptr, dscale);
832
833 #if 0
834                 {
835                   struct fmt_spec fmt;
836                   fmt.d = dscale;
837                   fmt.type = FMT_E;
838                   fmt.w = fmt_max_output_width (fmt.type) ;
839                   fmt.d =  MIN (dscale, fmt_max_output_decimals (fmt.type, fmt.w));
840                   var_set_both_formats (v, &fmt);
841                 }
842 #endif
843
844                 for (i = 0 ; i < n_digits;  ++i)
845                   {
846                     uint16_t x;
847                     GET_VALUE (&vptr, x);
848                     f += x * pow (10000, weight--);
849                   }
850
851                 if ( sign == 0x4000)
852                   f *= -1.0;
853
854                 if ( sign == 0xC000)
855                   val->f = SYSMIS;
856                 else
857                   val->f = f;
858               }
859               break;
860
861             default:
862               val->f = SYSMIS;
863               break;
864             }
865         }
866     }
867
868   r->tuple++;
869
870   return c;
871 }
872
873 #endif