Prevent deadlock in OpenSSL.
authorBen Pfaff <blp@nicira.com>
Fri, 30 May 2008 23:09:00 +0000 (16:09 -0700)
committerBen Pfaff <blp@nicira.com>
Wed, 4 Jun 2008 16:37:08 +0000 (09:37 -0700)
lib/vconn-ssl.c

index 507e032debb6ef93ad9f5e7974e9b86b362aab73..99b82699f7e27548ac9132b8a4f0d39fb043935a 100644 (file)
@@ -76,6 +76,64 @@ struct ssl_vconn
     struct buffer *rxbuf;
     struct buffer *txbuf;
     struct poll_waiter *tx_waiter;
+
+    /* rx_want and tx_want record the result of the last call to SSL_read()
+     * and SSL_write(), respectively:
+     *
+     *    - If the call reported that data needed to be read from the file
+     *      descriptor, the corresponding member is set to SSL_READING.
+     *
+     *    - If the call reported that data needed to be written to the file
+     *      descriptor, the corresponding member is set to SSL_WRITING.
+     *
+     *    - Otherwise, the member is set to SSL_NOTHING, indicating that the
+     *      call completed successfully (or with an error) and that there is no
+     *      need to block.
+     *
+     * These are needed because there is no way to ask OpenSSL what a data read
+     * or write would require without giving it a buffer to receive into or
+     * data to send, respectively.  (Note that the SSL_want() status is
+     * overwritten by each SSL_read() or SSL_write() call, so we can't rely on
+     * its value.)
+     *
+     * A single call to SSL_read() or SSL_write() can perform both reading
+     * and writing and thus invalidate not one of these values but actually
+     * both.  Consider this situation, for example:
+     *
+     *    - SSL_write() blocks on a read, so tx_want gets SSL_READING.
+     *
+     *    - SSL_read() laters succeeds reading from 'fd' and clears out the
+     *      whole receive buffer, so rx_want gets SSL_READING.
+     *
+     *    - Client calls vconn_wait(WAIT_RECV) and vconn_wait(WAIT_SEND) and
+     *      blocks.
+     *
+     *    - Now we're stuck blocking until the peer sends us data, even though
+     *      SSL_write() could now succeed, which could easily be a deadlock
+     *      condition.
+     *
+     * On the other hand, we can't reset both tx_want and rx_want on every call
+     * to SSL_read() or SSL_write(), because that would produce livelock,
+     * e.g. in this situation:
+     *
+     *    - SSL_write() blocks, so tx_want gets SSL_READING or SSL_WRITING.
+     *
+     *    - SSL_read() blocks, so rx_want gets SSL_READING or SSL_WRITING,
+     *      but tx_want gets reset to SSL_NOTHING.
+     *
+     *    - Client calls vconn_wait(WAIT_RECV) and vconn_wait(WAIT_SEND) and
+     *      blocks.
+     *
+     *    - Client wakes up immediately since SSL_NOTHING in tx_want indicates
+     *      that no blocking is necessary.
+     *
+     * The solution we adopt here is to set tx_want to SSL_NOTHING after
+     * calling SSL_read() only if the SSL state of the connection changed,
+     * which indicates that an SSL-level renegotiation made some progress, and
+     * similarly for rx_want and SSL_write().  This prevents both the
+     * deadlock and livelock situations above.
+     */
+    int rx_want, tx_want;
 };
 
 /* SSL context created by ssl_init(). */
@@ -88,10 +146,29 @@ static int ssl_init(void);
 static int do_ssl_init(void);
 static bool ssl_wants_io(int ssl_error);
 static void ssl_close(struct vconn *);
-static int interpret_ssl_error(const char *function, int ret, int error);
-static void ssl_do_tx(int fd, short int revents, void *vconn_);
+static int interpret_ssl_error(const char *function, int ret, int error,
+                               int *want);
+static void ssl_tx_poll_callback(int fd, short int revents, void *vconn_);
 static DH *tmp_dh_callback(SSL *ssl, int is_export UNUSED, int keylength);
 
+short int
+want_to_poll_events(int want)
+{
+    switch (want) {
+    case SSL_NOTHING:
+        NOT_REACHED();
+
+    case SSL_READING:
+        return POLLIN;
+
+    case SSL_WRITING:
+        return POLLOUT;
+
+    default:
+        NOT_REACHED();
+    }
+}
+
 static int
 new_ssl_vconn(const char *name, int fd, enum session_type type,
               enum ssl_state state, struct vconn **vconnp)
@@ -150,6 +227,7 @@ new_ssl_vconn(const char *name, int fd, enum session_type type,
     sslv->rxbuf = NULL;
     sslv->txbuf = NULL;
     sslv->tx_waiter = NULL;
+    sslv->rx_want = sslv->tx_want = SSL_NOTHING;
     *vconnp = &sslv->vconn;
     return 0;
 
@@ -253,7 +331,7 @@ ssl_connect(struct vconn *vconn)
                 return EAGAIN;
             } else {
                 interpret_ssl_error((sslv->type == CLIENT ? "SSL_connect"
-                                     : "SSL_accept"), retval, error);
+                                     : "SSL_accept"), retval, error, NULL);
                 shutdown(sslv->fd, SHUT_RDWR);
                 return EPROTO;
             }
@@ -276,8 +354,11 @@ ssl_close(struct vconn *vconn)
 }
 
 static int
-interpret_ssl_error(const char *function, int ret, int error)
+interpret_ssl_error(const char *function, int ret, int error,
+                    int *want)
 {
+    *want = SSL_NOTHING;
+
     switch (error) {
     case SSL_ERROR_NONE:
         VLOG_ERR("%s: unexpected SSL_ERROR_NONE", function);
@@ -288,7 +369,11 @@ interpret_ssl_error(const char *function, int ret, int error)
         break;
 
     case SSL_ERROR_WANT_READ:
+        *want = SSL_READING;
+        return EAGAIN;
+
     case SSL_ERROR_WANT_WRITE:
+        *want = SSL_WRITING;
         return EAGAIN;
 
     case SSL_ERROR_WANT_CONNECT:
@@ -343,6 +428,7 @@ ssl_recv(struct vconn *vconn, struct buffer **bufferp)
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
     struct buffer *rx;
     size_t want_bytes;
+    int old_state;
     ssize_t ret;
 
     if (sslv->rxbuf == NULL) {
@@ -372,7 +458,17 @@ again:
     /* Behavior of zero-byte SSL_read is poorly defined. */
     assert(want_bytes > 0);
 
+    old_state = SSL_get_state(sslv->ssl);
     ret = SSL_read(sslv->ssl, buffer_tail(rx), want_bytes);
+    if (old_state != SSL_get_state(sslv->ssl)) {
+        sslv->tx_want = SSL_NOTHING;
+        if (sslv->tx_waiter) {
+            poll_cancel(sslv->tx_waiter);
+            ssl_tx_poll_callback(sslv->fd, POLLIN, vconn);
+        }
+    }
+    sslv->rx_want = SSL_NOTHING;
+
     if (ret > 0) {
         rx->size += ret;
         if (ret == want_bytes) {
@@ -396,7 +492,7 @@ again:
                 return EOF;
             }
         } else {
-            return interpret_ssl_error("SSL_read", ret, error);
+            return interpret_ssl_error("SSL_read", ret, error, &sslv->rx_want);
         }
     }
 }
@@ -410,81 +506,80 @@ ssl_clear_txbuf(struct ssl_vconn *sslv)
 }
 
 static void
-ssl_register_tx_waiter(struct vconn *vconn) 
+ssl_register_tx_waiter(struct vconn *vconn)
 {
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    short int events = SSL_want_read(sslv->ssl) ? POLLIN : POLLOUT;
-    sslv->tx_waiter = poll_fd_callback(sslv->fd, events, ssl_do_tx, vconn);
+    sslv->tx_waiter = poll_fd_callback(sslv->fd,
+                                       want_to_poll_events(sslv->tx_want),
+                                       ssl_tx_poll_callback, vconn);
+}
+
+static int
+ssl_do_tx(struct vconn *vconn)
+{
+    struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
+
+    for (;;) {
+        int old_state = SSL_get_state(sslv->ssl);
+        int ret = SSL_write(sslv->ssl, sslv->txbuf->data, sslv->txbuf->size);
+        if (old_state != SSL_get_state(sslv->ssl)) {
+            sslv->rx_want = SSL_NOTHING;
+        }
+        sslv->tx_want = SSL_NOTHING;
+        if (ret > 0) {
+            buffer_pull(sslv->txbuf, ret);
+            if (sslv->txbuf->size == 0) {
+                return 0;
+            }
+        } else {
+            int ssl_error = SSL_get_error(sslv->ssl, ret);
+            if (ssl_error == SSL_ERROR_ZERO_RETURN) {
+                VLOG_WARN("SSL_write: connection closed");
+                return EPIPE;
+            } else {
+                return interpret_ssl_error("SSL_write", ret, ssl_error,
+                                           &sslv->tx_want);
+            }
+        }
+    }
 }
 
 static void
-ssl_do_tx(int fd UNUSED, short int revents UNUSED, void *vconn_)
+ssl_tx_poll_callback(int fd UNUSED, short int revents UNUSED, void *vconn_)
 {
     struct vconn *vconn = vconn_;
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    int ret = SSL_write(sslv->ssl, sslv->txbuf->data, sslv->txbuf->size);
-    if (ret > 0) {
-        buffer_pull(sslv->txbuf, ret);
-        if (sslv->txbuf->size == 0) {
-            ssl_clear_txbuf(sslv);
-            return;
-        }
+    int error = ssl_do_tx(vconn);
+    if (error != EAGAIN) {
+        ssl_clear_txbuf(sslv);
     } else {
-        int error = SSL_get_error(sslv->ssl, ret);
-        if (error == SSL_ERROR_ZERO_RETURN) {
-            /* Connection closed (EOF). */
-            VLOG_WARN("SSL_write: connection close");
-        } else if (interpret_ssl_error("SSL_write", ret, error) != EAGAIN) {
-            ssl_clear_txbuf(sslv);
-            return;
-        }
+        ssl_register_tx_waiter(vconn);
     }
-    ssl_register_tx_waiter(vconn);
 }
 
 static int
 ssl_send(struct vconn *vconn, struct buffer *buffer)
 {
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    ssize_t ret;
 
     if (sslv->txbuf) {
         return EAGAIN;
-    }
+    } else {
+        int error;
 
-    ret = SSL_write(sslv->ssl, buffer->data, buffer->size);
-    if (ret > 0) {
-        if (ret == buffer->size) {
-            buffer_delete(buffer);
-        } else {
-            sslv->txbuf = buffer;
-            buffer_pull(buffer, ret);
+        sslv->txbuf = buffer;
+        error = ssl_do_tx(vconn);
+        switch (error) {
+        case 0:
+            ssl_clear_txbuf(sslv);
+            return 0;
+        case EAGAIN:
             ssl_register_tx_waiter(vconn);
+            return 0;
+        default:
+            sslv->txbuf = NULL;
+            return error;
         }
-        return 0;
-    } else {
-        int error = SSL_get_error(sslv->ssl, ret);
-        if (error == SSL_ERROR_ZERO_RETURN) {
-            /* Connection closed (EOF). */
-            VLOG_WARN("SSL_write: connection close");
-            return EPIPE;
-        } else {
-            return interpret_ssl_error("SSL_write", ret, error);
-        }
-    }
-}
-
-static bool
-ssl_needs_wait(struct ssl_vconn *sslv) 
-{
-    if (SSL_want_read(sslv->ssl)) {
-        poll_fd_wait(sslv->fd, POLLIN);
-        return true;
-    } else if (SSL_want_write(sslv->ssl)) {
-        poll_fd_wait(sslv->fd, POLLOUT);
-        return true;
-    } else {
-        return false;
     }
 }
 
@@ -497,26 +592,39 @@ ssl_wait(struct vconn *vconn, enum vconn_wait_type wait)
     case WAIT_CONNECT:
         if (vconn_connect(vconn) != EAGAIN) {
             poll_immediate_wake();
-        } else if (sslv->state == STATE_TCP_CONNECTING) {
-            poll_fd_wait(sslv->fd, POLLOUT);
-        } else if (!ssl_needs_wait(sslv)) {
-            NOT_REACHED();
+        } else {
+            switch (sslv->state) {
+            case STATE_TCP_CONNECTING:
+                poll_fd_wait(sslv->fd, POLLOUT);
+                break;
+
+            case STATE_SSL_CONNECTING:
+                /* ssl_connect() called SSL_accept() or SSL_connect(), which
+                 * set up the status that we test here. */
+                poll_fd_wait(sslv->fd,
+                             want_to_poll_events(SSL_want(sslv->ssl)));
+                break;
+
+            default:
+                NOT_REACHED();
+            }
         }
         break;
 
     case WAIT_RECV:
-        if (!ssl_needs_wait(sslv)) {
-            if (SSL_pending(sslv->ssl)) {
-                poll_immediate_wake();
-            } else {
-                poll_fd_wait(sslv->fd, POLLIN);
-            }
+        if (sslv->rx_want != SSL_NOTHING) {
+            poll_fd_wait(sslv->fd, want_to_poll_events(sslv->rx_want));
+        } else {
+            poll_immediate_wake();
         }
         break;
 
     case WAIT_SEND:
-        if (!sslv->txbuf && !ssl_needs_wait(sslv)) {
-            poll_fd_wait(sslv->fd, POLLOUT);
+        if (!sslv->txbuf) {
+            /* We have room in our tx queue. */
+            poll_immediate_wake();
+        } else {
+            /* The call to ssl_tx_poll_callback() will wake us up. */
         }
         break;