vconn: Make errors in vconn names non-fatal errors.
[openvswitch] / lib / vconn-ssl.c
index 30d8caf63194b850845eb07fe11cc86d663a8f8b..a7060d74bb9a91e4f1bd7c80829437a9b4132a1a 100644 (file)
@@ -35,6 +35,7 @@
 #include "dhparams.h"
 #include <assert.h>
 #include <errno.h>
+#include <inttypes.h>
 #include <string.h>
 #include <netinet/tcp.h>
 #include <openssl/err.h>
@@ -45,6 +46,7 @@
 #include "socket-util.h"
 #include "util.h"
 #include "openflow.h"
+#include "packets.h"
 #include "poll-loop.h"
 #include "ofp-print.h"
 #include "socket-util.h"
@@ -76,6 +78,64 @@ struct ssl_vconn
     struct buffer *rxbuf;
     struct buffer *txbuf;
     struct poll_waiter *tx_waiter;
+
+    /* rx_want and tx_want record the result of the last call to SSL_read()
+     * and SSL_write(), respectively:
+     *
+     *    - If the call reported that data needed to be read from the file
+     *      descriptor, the corresponding member is set to SSL_READING.
+     *
+     *    - If the call reported that data needed to be written to the file
+     *      descriptor, the corresponding member is set to SSL_WRITING.
+     *
+     *    - Otherwise, the member is set to SSL_NOTHING, indicating that the
+     *      call completed successfully (or with an error) and that there is no
+     *      need to block.
+     *
+     * These are needed because there is no way to ask OpenSSL what a data read
+     * or write would require without giving it a buffer to receive into or
+     * data to send, respectively.  (Note that the SSL_want() status is
+     * overwritten by each SSL_read() or SSL_write() call, so we can't rely on
+     * its value.)
+     *
+     * A single call to SSL_read() or SSL_write() can perform both reading
+     * and writing and thus invalidate not one of these values but actually
+     * both.  Consider this situation, for example:
+     *
+     *    - SSL_write() blocks on a read, so tx_want gets SSL_READING.
+     *
+     *    - SSL_read() laters succeeds reading from 'fd' and clears out the
+     *      whole receive buffer, so rx_want gets SSL_READING.
+     *
+     *    - Client calls vconn_wait(WAIT_RECV) and vconn_wait(WAIT_SEND) and
+     *      blocks.
+     *
+     *    - Now we're stuck blocking until the peer sends us data, even though
+     *      SSL_write() could now succeed, which could easily be a deadlock
+     *      condition.
+     *
+     * On the other hand, we can't reset both tx_want and rx_want on every call
+     * to SSL_read() or SSL_write(), because that would produce livelock,
+     * e.g. in this situation:
+     *
+     *    - SSL_write() blocks, so tx_want gets SSL_READING or SSL_WRITING.
+     *
+     *    - SSL_read() blocks, so rx_want gets SSL_READING or SSL_WRITING,
+     *      but tx_want gets reset to SSL_NOTHING.
+     *
+     *    - Client calls vconn_wait(WAIT_RECV) and vconn_wait(WAIT_SEND) and
+     *      blocks.
+     *
+     *    - Client wakes up immediately since SSL_NOTHING in tx_want indicates
+     *      that no blocking is necessary.
+     *
+     * The solution we adopt here is to set tx_want to SSL_NOTHING after
+     * calling SSL_read() only if the SSL state of the connection changed,
+     * which indicates that an SSL-level renegotiation made some progress, and
+     * similarly for rx_want and SSL_write().  This prevents both the
+     * deadlock and livelock situations above.
+     */
+    int rx_want, tx_want;
 };
 
 /* SSL context created by ssl_init(). */
@@ -88,13 +148,33 @@ static int ssl_init(void);
 static int do_ssl_init(void);
 static bool ssl_wants_io(int ssl_error);
 static void ssl_close(struct vconn *);
-static int interpret_ssl_error(const char *function, int ret, int error);
-static void ssl_do_tx(int fd, short int revents, void *vconn_);
+static int interpret_ssl_error(const char *function, int ret, int error,
+                               int *want);
+static void ssl_tx_poll_callback(int fd, short int revents, void *vconn_);
 static DH *tmp_dh_callback(SSL *ssl, int is_export UNUSED, int keylength);
 
+short int
+want_to_poll_events(int want)
+{
+    switch (want) {
+    case SSL_NOTHING:
+        NOT_REACHED();
+
+    case SSL_READING:
+        return POLLIN;
+
+    case SSL_WRITING:
+        return POLLOUT;
+
+    default:
+        NOT_REACHED();
+    }
+}
+
 static int
 new_ssl_vconn(const char *name, int fd, enum session_type type,
-              enum ssl_state state, struct vconn **vconnp)
+              enum ssl_state state, const struct sockaddr_in *sin,
+              struct vconn **vconnp)
 {
     struct ssl_vconn *sslv;
     SSL *ssl = NULL;
@@ -143,6 +223,7 @@ new_ssl_vconn(const char *name, int fd, enum session_type type,
     sslv = xmalloc(sizeof *sslv);
     sslv->vconn.class = &ssl_vconn_class;
     sslv->vconn.connect_status = EAGAIN;
+    sslv->vconn.ip = sin->sin_addr.s_addr;
     sslv->state = state;
     sslv->type = type;
     sslv->fd = fd;
@@ -150,6 +231,7 @@ new_ssl_vconn(const char *name, int fd, enum session_type type,
     sslv->rxbuf = NULL;
     sslv->txbuf = NULL;
     sslv->tx_waiter = NULL;
+    sslv->rx_want = sslv->tx_want = SSL_NOTHING;
     *vconnp = &sslv->vconn;
     return 0;
 
@@ -188,7 +270,8 @@ ssl_open(const char *name, char *suffix, struct vconn **vconnp)
     host_name = strtok_r(suffix, "::", &save_ptr);
     port_string = strtok_r(NULL, "::", &save_ptr);
     if (!host_name) {
-        fatal(0, "%s: bad peer name format", name);
+        error(0, "%s: bad peer name format", name);
+        return EAFNOSUPPORT;
     }
 
     memset(&sin, 0, sizeof sin);
@@ -216,7 +299,7 @@ ssl_open(const char *name, char *suffix, struct vconn **vconnp)
     if (retval < 0) {
         if (errno == EINPROGRESS) {
             return new_ssl_vconn(name, fd, CLIENT, STATE_TCP_CONNECTING,
-                                 vconnp);
+                                 &sin, vconnp);
         } else {
             int error = errno;
             VLOG_ERR("%s: connect: %s", name, strerror(error));
@@ -225,7 +308,7 @@ ssl_open(const char *name, char *suffix, struct vconn **vconnp)
         }
     } else {
         return new_ssl_vconn(name, fd, CLIENT, STATE_SSL_CONNECTING,
-                             vconnp);
+                             &sin, vconnp);
     }
 }
 
@@ -252,8 +335,9 @@ ssl_connect(struct vconn *vconn)
             if (retval < 0 && ssl_wants_io(error)) {
                 return EAGAIN;
             } else {
+                int unused;
                 interpret_ssl_error((sslv->type == CLIENT ? "SSL_connect"
-                                     : "SSL_accept"), retval, error);
+                                     : "SSL_accept"), retval, error, &unused);
                 shutdown(sslv->fd, SHUT_RDWR);
                 return EPROTO;
             }
@@ -276,8 +360,11 @@ ssl_close(struct vconn *vconn)
 }
 
 static int
-interpret_ssl_error(const char *function, int ret, int error)
+interpret_ssl_error(const char *function, int ret, int error,
+                    int *want)
 {
+    *want = SSL_NOTHING;
+
     switch (error) {
     case SSL_ERROR_NONE:
         VLOG_ERR("%s: unexpected SSL_ERROR_NONE", function);
@@ -288,7 +375,11 @@ interpret_ssl_error(const char *function, int ret, int error)
         break;
 
     case SSL_ERROR_WANT_READ:
+        *want = SSL_READING;
+        return EAGAIN;
+
     case SSL_ERROR_WANT_WRITE:
+        *want = SSL_WRITING;
         return EAGAIN;
 
     case SSL_ERROR_WANT_CONNECT:
@@ -343,6 +434,7 @@ ssl_recv(struct vconn *vconn, struct buffer **bufferp)
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
     struct buffer *rx;
     size_t want_bytes;
+    int old_state;
     ssize_t ret;
 
     if (sslv->rxbuf == NULL) {
@@ -361,13 +453,28 @@ again:
             return EPROTO;
         }
         want_bytes = length - rx->size;
+        if (!want_bytes) {
+            *bufferp = rx;
+            sslv->rxbuf = NULL;
+            return 0;
+        }
     }
-    buffer_reserve_tailroom(rx, want_bytes);
+    buffer_prealloc_tailroom(rx, want_bytes);
 
     /* Behavior of zero-byte SSL_read is poorly defined. */
     assert(want_bytes > 0);
 
+    old_state = SSL_get_state(sslv->ssl);
     ret = SSL_read(sslv->ssl, buffer_tail(rx), want_bytes);
+    if (old_state != SSL_get_state(sslv->ssl)) {
+        sslv->tx_want = SSL_NOTHING;
+        if (sslv->tx_waiter) {
+            poll_cancel(sslv->tx_waiter);
+            ssl_tx_poll_callback(sslv->fd, POLLIN, vconn);
+        }
+    }
+    sslv->rx_want = SSL_NOTHING;
+
     if (ret > 0) {
         rx->size += ret;
         if (ret == want_bytes) {
@@ -391,7 +498,7 @@ again:
                 return EOF;
             }
         } else {
-            return interpret_ssl_error("SSL_read", ret, error);
+            return interpret_ssl_error("SSL_read", ret, error, &sslv->rx_want);
         }
     }
 }
@@ -405,84 +512,83 @@ ssl_clear_txbuf(struct ssl_vconn *sslv)
 }
 
 static void
-ssl_register_tx_waiter(struct vconn *vconn) 
+ssl_register_tx_waiter(struct vconn *vconn)
 {
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    short int events = SSL_want_read(sslv->ssl) ? POLLIN : POLLOUT;
-    sslv->tx_waiter = poll_fd_callback(sslv->fd, events, ssl_do_tx, vconn);
+    sslv->tx_waiter = poll_fd_callback(sslv->fd,
+                                       want_to_poll_events(sslv->tx_want),
+                                       ssl_tx_poll_callback, vconn);
+}
+
+static int
+ssl_do_tx(struct vconn *vconn)
+{
+    struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
+
+    for (;;) {
+        int old_state = SSL_get_state(sslv->ssl);
+        int ret = SSL_write(sslv->ssl, sslv->txbuf->data, sslv->txbuf->size);
+        if (old_state != SSL_get_state(sslv->ssl)) {
+            sslv->rx_want = SSL_NOTHING;
+        }
+        sslv->tx_want = SSL_NOTHING;
+        if (ret > 0) {
+            buffer_pull(sslv->txbuf, ret);
+            if (sslv->txbuf->size == 0) {
+                return 0;
+            }
+        } else {
+            int ssl_error = SSL_get_error(sslv->ssl, ret);
+            if (ssl_error == SSL_ERROR_ZERO_RETURN) {
+                VLOG_WARN("SSL_write: connection closed");
+                return EPIPE;
+            } else {
+                return interpret_ssl_error("SSL_write", ret, ssl_error,
+                                           &sslv->tx_want);
+            }
+        }
+    }
 }
 
 static void
-ssl_do_tx(int fd UNUSED, short int revents UNUSED, void *vconn_)
+ssl_tx_poll_callback(int fd UNUSED, short int revents UNUSED, void *vconn_)
 {
     struct vconn *vconn = vconn_;
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    int ret = SSL_write(sslv->ssl, sslv->txbuf->data, sslv->txbuf->size);
-    if (ret > 0) {
-        buffer_pull(sslv->txbuf, ret);
-        if (sslv->txbuf->size == 0) {
-            ssl_clear_txbuf(sslv);
-            return;
-        }
+    int error = ssl_do_tx(vconn);
+    if (error != EAGAIN) {
+        ssl_clear_txbuf(sslv);
     } else {
-        int error = SSL_get_error(sslv->ssl, ret);
-        if (error == SSL_ERROR_ZERO_RETURN) {
-            /* Connection closed (EOF). */
-            VLOG_WARN("SSL_write: connection close");
-        } else if (interpret_ssl_error("SSL_write", ret, error) != EAGAIN) {
-            ssl_clear_txbuf(sslv);
-            return;
-        }
+        ssl_register_tx_waiter(vconn);
     }
-    ssl_register_tx_waiter(vconn);
 }
 
 static int
 ssl_send(struct vconn *vconn, struct buffer *buffer)
 {
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    ssize_t ret;
 
     if (sslv->txbuf) {
         return EAGAIN;
-    }
+    } else {
+        int error;
 
-    ret = SSL_write(sslv->ssl, buffer->data, buffer->size);
-    if (ret > 0) {
-        if (ret == buffer->size) {
-            buffer_delete(buffer);
-        } else {
-            sslv->txbuf = buffer;
-            buffer_pull(buffer, ret);
+        sslv->txbuf = buffer;
+        error = ssl_do_tx(vconn);
+        switch (error) {
+        case 0:
+            ssl_clear_txbuf(sslv);
+            return 0;
+        case EAGAIN:
             ssl_register_tx_waiter(vconn);
-        }
-        return 0;
-    } else {
-        int error = SSL_get_error(sslv->ssl, ret);
-        if (error == SSL_ERROR_ZERO_RETURN) {
-            /* Connection closed (EOF). */
-            VLOG_WARN("SSL_write: connection close");
-            return EPIPE;
-        } else {
-            return interpret_ssl_error("SSL_write", ret, error);
+            return 0;
+        default:
+            sslv->txbuf = NULL;
+            return error;
         }
     }
 }
 
-static bool
-ssl_needs_wait(struct ssl_vconn *sslv) 
-{
-    if (SSL_want_read(sslv->ssl)) {
-        poll_fd_wait(sslv->fd, POLLIN);
-        return true;
-    } else if (SSL_want_write(sslv->ssl)) {
-        poll_fd_wait(sslv->fd, POLLOUT);
-        return true;
-    } else {
-        return false;
-    }
-}
-
 static void
 ssl_wait(struct vconn *vconn, enum vconn_wait_type wait)
 {
@@ -492,26 +598,39 @@ ssl_wait(struct vconn *vconn, enum vconn_wait_type wait)
     case WAIT_CONNECT:
         if (vconn_connect(vconn) != EAGAIN) {
             poll_immediate_wake();
-        } else if (sslv->state == STATE_TCP_CONNECTING) {
-            poll_fd_wait(sslv->fd, POLLOUT);
-        } else if (!ssl_needs_wait(sslv)) {
-            NOT_REACHED();
+        } else {
+            switch (sslv->state) {
+            case STATE_TCP_CONNECTING:
+                poll_fd_wait(sslv->fd, POLLOUT);
+                break;
+
+            case STATE_SSL_CONNECTING:
+                /* ssl_connect() called SSL_accept() or SSL_connect(), which
+                 * set up the status that we test here. */
+                poll_fd_wait(sslv->fd,
+                             want_to_poll_events(SSL_want(sslv->ssl)));
+                break;
+
+            default:
+                NOT_REACHED();
+            }
         }
         break;
 
     case WAIT_RECV:
-        if (!ssl_needs_wait(sslv)) {
-            if (SSL_pending(sslv->ssl)) {
-                poll_immediate_wake();
-            } else {
-                poll_fd_wait(sslv->fd, POLLIN);
-            }
+        if (sslv->rx_want != SSL_NOTHING) {
+            poll_fd_wait(sslv->fd, want_to_poll_events(sslv->rx_want));
+        } else {
+            poll_immediate_wake();
         }
         break;
 
     case WAIT_SEND:
-        if (!sslv->txbuf && !ssl_needs_wait(sslv)) {
-            poll_fd_wait(sslv->fd, POLLOUT);
+        if (!sslv->txbuf) {
+            /* We have room in our tx queue. */
+            poll_immediate_wake();
+        } else {
+            /* The call to ssl_tx_poll_callback() will wake us up. */
         }
         break;
 
@@ -619,10 +738,13 @@ static int
 pssl_accept(struct vconn *vconn, struct vconn **new_vconnp)
 {
     struct pssl_vconn *pssl = pssl_vconn_cast(vconn);
+    struct sockaddr_in sin;
+    socklen_t sin_len = sizeof sin;
+    char name[128];
     int new_fd;
     int error;
 
-    new_fd = accept(pssl->fd, NULL, NULL);
+    new_fd = accept(pssl->fd, &sin, &sin_len);
     if (new_fd < 0) {
         int error = errno;
         if (error != EAGAIN) {
@@ -637,8 +759,12 @@ pssl_accept(struct vconn *vconn, struct vconn **new_vconnp)
         return error;
     }
 
-    return new_ssl_vconn("ssl" /* FIXME */, new_fd,
-                         SERVER, STATE_SSL_CONNECTING, new_vconnp);
+    sprintf(name, "ssl:"IP_FMT, IP_ARGS(&sin.sin_addr));
+    if (sin.sin_port != htons(OFP_SSL_PORT)) {
+        sprintf(strchr(name, '\0'), ":%"PRIu16, ntohs(sin.sin_port));
+    }
+    return new_ssl_vconn(name, new_fd, SERVER, STATE_SSL_CONNECTING, &sin,
+                         new_vconnp);
 }
 
 static void
@@ -742,6 +868,13 @@ tmp_dh_callback(SSL *ssl, int is_export UNUSED, int keylength)
     return NULL;
 }
 
+/* Returns true if SSL is at least partially configured. */
+bool
+vconn_ssl_is_configured(void) 
+{
+    return has_private_key || has_certificate || has_ca_cert;
+}
+
 void
 vconn_ssl_set_private_key_file(const char *file_name)
 {
@@ -790,7 +923,7 @@ vconn_ssl_set_ca_cert_file(const char *file_name)
 
     /* Set up CAs for OpenSSL to trust in verifying the peer's certificate. */
     if (SSL_CTX_load_verify_locations(ctx, file_name, NULL) != 1) {
-        VLOG_ERR("SSL_load_verify_locations: %s",
+        VLOG_ERR("SSL_CTX_load_verify_locations: %s",
                  ERR_error_string(ERR_get_error(), NULL));
         return;
     }