From: Ben Pfaff
Date: Thu, 14 Jun 2012 16:46:18 +0000 (-0700)
Subject: socket-util: Add functions for sending fds over Unix domain sockets.
X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?p=openvswitch;a=commitdiff_plain;h=fd94a42c43ff4a0e57a44bdc9ded1b7e1e63faaa

socket-util: Add functions for sending fds over Unix domain sockets.

These will be used in upcoming commits. This commit also adds corresponding
definitions to the "sparse" header, so that sparse still works.

Signed-off-by: Ben Pfaff
---

diff --git a/include/sparse/sys/socket.h b/include/sparse/sys/socket.h
index d7e17ea1..13f61e50 100644
--- a/include/sparse/sys/socket.h
+++ b/include/sparse/sys/socket.h
@@ -47,6 +47,37 @@ struct msghdr {
     int msg_flags;
 };
 
+struct cmsghdr {
+    size_t cmsg_len;            /* Header plus data; see CMSG_LEN(). */
+    int cmsg_level;
+    int cmsg_type;
+    unsigned char cmsg_data[];  /* Payload; what CMSG_DATA() yields. */
+};
+
+#define __CMSG_ALIGNTO sizeof(size_t)   /* Unit that CMSG_ALIGN() rounds to. */
+#define CMSG_ALIGN(LEN) /* LEN rounded up to a multiple of the unit. */ \
+    (((LEN) + __CMSG_ALIGNTO - 1) / __CMSG_ALIGNTO * __CMSG_ALIGNTO)
+#define CMSG_DATA(CMSG) ((CMSG)->cmsg_data)
+#define CMSG_LEN(LEN) (sizeof(struct cmsghdr) + (LEN))  /* Unpadded size. */
+#define CMSG_SPACE(LEN) CMSG_ALIGN(CMSG_LEN(LEN))       /* Padded size. */
+#define CMSG_FIRSTHDR(MSG) /* NULL if 'msg_controllen' is zero. */ \
+    ((MSG)->msg_controllen ? (struct cmsghdr *) (MSG)->msg_control : NULL)
+#define CMSG_NXTHDR(MSG, CMSG) __cmsg_nxthdr(MSG, CMSG)
+
+static inline struct cmsghdr *
+__cmsg_nxthdr(struct msghdr *msg, struct cmsghdr *cmsg)
+{
+    size_t ofs = (char *) cmsg - (char *) msg->msg_control;
+    size_t next_ofs = ofs + CMSG_ALIGN(cmsg->cmsg_len); /* Aligned end. */
+    return (next_ofs < msg->msg_controllen
+            ? 
(void *) ((char *) msg->msg_control + next_ofs)
+            : NULL);           /* Nothing follows 'cmsg'. */
+}
+
+enum {
+    SCM_RIGHTS = 1
+};
+
 enum {
     SOCK_DGRAM,
     SOCK_RAW,
diff --git a/lib/socket-util.c b/lib/socket-util.c
index 939cd90f..5fe9f98b 100644
--- a/lib/socket-util.c
+++ b/lib/socket-util.c
@@ -35,6 +35,7 @@
 #include "dynamic-string.h"
 #include "fatal-signal.h"
 #include "packets.h"
+#include "poll-loop.h"
 #include "util.h"
 #include "vlog.h"
 #if AF_PACKET && __linux__
@@ -889,6 +890,14 @@ xpipe(int fds[2])
     }
 }
 
+void
+xsocketpair(int domain, int type, int protocol, int fds[2])
+{
+    if (socketpair(domain, type, protocol, fds)) {  /* Nonzero on failure. */
+        VLOG_FATAL("failed to create socketpair (%s)", strerror(errno));
+    }
+}
+
 static int
 getsockopt_int(int fd, int level, int option, const char *optname, int *valuep)
 {
@@ -1048,3 +1057,249 @@ describe_fd(int fd)
     }
     return ds_steal_cstr(&string);
 }
+
+/* Returns the sum of the 'iov_len' members of the 'n_iovs' iovecs in 'iovs'.
+ * Overflow is not checked: the caller must keep the sum within SIZE_MAX. */
+size_t
+iovec_len(const struct iovec iovs[], size_t n_iovs)
+{
+    size_t len = 0;
+    size_t i;
+
+    for (i = 0; i < n_iovs; i++) {
+        len += iovs[i].iov_len;
+    }
+    return len;
+}
+
+/* Returns true if all of the 'n_iovs' iovecs in 'iovs' have length zero. */
+bool
+iovec_is_empty(const struct iovec iovs[], size_t n_iovs)
+{
+    size_t i;
+
+    for (i = 0; i < n_iovs; i++) {
+        if (iovs[i].iov_len) {
+            return false;
+        }
+    }
+    return true;        /* Also the result when 'n_iovs' is 0. */
+}
+
+/* Sends the data in the 'n_iovs' iovecs in 'iovs', plus the 'n_fds' file
+ * descriptors in 'fds', on Unix domain socket 'sock'.  If 'n_fds' is nonzero
+ * it must not exceed SOUTIL_MAX_FDS and the iovecs must hold some data.
+ * Returns the number of bytes sent, or -1 with errno set on error. 
*/
+int
+send_iovec_and_fds(int sock,
+                   const struct iovec *iovs, size_t n_iovs,
+                   const int fds[], size_t n_fds)
+{
+    assert(sock >= 0);
+    if (n_fds > 0) {
+        union {                 /* 'cm' lends 'control' cmsghdr alignment. */
+            struct cmsghdr cm;
+            char control[CMSG_SPACE(SOUTIL_MAX_FDS * sizeof *fds)];
+        } cmsg;
+        struct msghdr msg;
+
+        assert(!iovec_is_empty(iovs, n_iovs));  /* Fds must accompany data. */
+        assert(n_fds <= SOUTIL_MAX_FDS);        /* 'control' holds no more. */
+
+        memset(&cmsg, 0, sizeof cmsg);
+        cmsg.cm.cmsg_len = CMSG_LEN(n_fds * sizeof *fds);
+        cmsg.cm.cmsg_level = SOL_SOCKET;
+        cmsg.cm.cmsg_type = SCM_RIGHTS;
+        memcpy(CMSG_DATA(&cmsg.cm), fds, n_fds * sizeof *fds);
+
+        msg.msg_name = NULL;
+        msg.msg_namelen = 0;
+        msg.msg_iov = (struct iovec *) iovs;    /* Cast only drops 'const'. */
+        msg.msg_iovlen = n_iovs;
+        msg.msg_control = &cmsg.cm;
+        msg.msg_controllen = CMSG_SPACE(n_fds * sizeof *fds);
+        msg.msg_flags = 0;
+
+        return sendmsg(sock, &msg, 0);
+    } else {
+        return writev(sock, iovs, n_iovs);      /* No fds: data only. */
+    }
+}
+
+/* Sends the 'n_iovs' iovecs of data in 'iovs' and the 'n_fds' file descriptors
+ * in 'fds' on Unix domain socket 'sock'.  If 'skip_bytes' is nonzero, then the
+ * first 'skip_bytes' bytes of data in the iovecs are not sent, and neither are
+ * the file descriptors.  The function continues to retry sending until an
+ * error (other than EINTR) occurs or all the data and fds are sent.
+ *
+ * Returns 0 if all the data and fds were successfully sent, otherwise a
+ * positive errno value.  Regardless of success, stores the number of bytes
+ * sent (always at least 'skip_bytes') in '*bytes_sent'.  (If at least one byte
+ * is sent, then all the fds have been sent.)
+ *
+ * 'skip_bytes' must be less than or equal to iovec_len(iovs, n_iovs). 
*/
+int
+send_iovec_and_fds_fully(int sock,
+                         const struct iovec iovs[], size_t n_iovs,
+                         const int fds[], size_t n_fds,
+                         size_t skip_bytes, size_t *bytes_sent)
+{
+    *bytes_sent = 0;
+    while (n_iovs > 0) {
+        int retval;
+
+        if (skip_bytes) {
+            retval = skip_bytes;        /* Count skipped bytes as sent. */
+            skip_bytes = 0;
+        } else if (!*bytes_sent) {      /* Nothing sent yet: include fds. */
+            retval = send_iovec_and_fds(sock, iovs, n_iovs, fds, n_fds);
+        } else {                        /* Fds went with the first bytes. */
+            retval = writev(sock, iovs, n_iovs);
+        }
+
+        if (retval > 0) {
+            *bytes_sent += retval;
+            while (retval > 0) {        /* Step past what was consumed. */
+                const uint8_t *base = iovs->iov_base;
+                size_t len = iovs->iov_len;
+
+                if (retval < len) {     /* Stopped mid-iovec: finish it. */
+                    size_t sent;
+                    int error;
+
+                    error = write_fully(sock, base + retval, len - retval,
+                                        &sent);
+                    *bytes_sent += sent;
+                    retval += sent;
+                    if (error) {
+                        return error;
+                    }
+                }
+                retval -= len;
+                iovs++;
+                n_iovs--;
+            }
+        } else if (retval == 0) {
+            if (iovec_is_empty(iovs, n_iovs)) {
+                break;                  /* Nothing was left to send. */
+            }
+            VLOG_WARN("send returned 0");
+            return EPROTO;
+        } else if (errno != EINTR) {    /* On EINTR, loop and retry. */
+            return errno;
+        }
+    }
+
+    return 0;
+}
+
+/* Sends the 'n_iovs' iovecs of data in 'iovs' and the 'n_fds' file descriptors
+ * in 'fds' on Unix domain socket 'sock'.  The function continues to retry
+ * sending until an error (other than EAGAIN or EINTR) occurs or all the data
+ * and fds are sent.  Upon EAGAIN, the function blocks until the socket is
+ * ready for more data.
+ *
+ * Returns 0 if all the data and fds were successfully sent, otherwise a
+ * positive errno value. */
+int
+send_iovec_and_fds_fully_block(int sock,
+                               const struct iovec iovs[], size_t n_iovs,
+                               const int fds[], size_t n_fds)
+{
+    size_t sent = 0;    /* Bytes sent so far; passed back as 'skip_bytes'. */
+
+    for (;;) {
+        int error;
+
+        error = send_iovec_and_fds_fully(sock, iovs, n_iovs,
+                                         fds, n_fds, sent, &sent);
+        if (error != EAGAIN) {
+            return error;       /* Done (0), or an error other than EAGAIN. */
+        }
+        poll_fd_wait(sock, POLLOUT);
+        poll_block();
+    }
+}
+
+/* Attempts to receive from Unix domain socket 'sock' up to 'size' bytes of
+ * data into 'data' and up to SOUTIL_MAX_FDS file descriptors into 'fds'. 
+ *
+ *   - Upon success, returns the number of bytes of data copied into 'data',
+ *     stores any received fds in 'fds', and stores their count in '*n_fdsp'.
+ *
+ *   - On failure, returns a negative errno value (-EPROTO if the control
+ *     messages are not as expected) and stores 0 in '*n_fdsp'.
+ *
+ *   - On EOF, returns 0 and stores 0 in '*n_fdsp'. */
+int
+recv_data_and_fds(int sock,
+                  void *data, size_t size,
+                  int fds[SOUTIL_MAX_FDS], size_t *n_fdsp)
+{
+    union {                     /* 'cm' lends 'control' cmsghdr alignment. */
+        struct cmsghdr cm;
+        char control[CMSG_SPACE(SOUTIL_MAX_FDS * sizeof *fds)];
+    } cmsg;
+    struct msghdr msg;
+    int retval;
+    struct cmsghdr *p;
+    size_t i;
+
+    *n_fdsp = 0;
+
+    do {
+        struct iovec iov;
+
+        iov.iov_base = data;
+        iov.iov_len = size;
+
+        msg.msg_name = NULL;
+        msg.msg_namelen = 0;
+        msg.msg_iov = &iov;
+        msg.msg_iovlen = 1;
+        msg.msg_control = &cmsg.cm;
+        msg.msg_controllen = sizeof cmsg.control;
+        msg.msg_flags = 0;
+
+        retval = recvmsg(sock, &msg, 0);
+    } while (retval < 0 && errno == EINTR);     /* Retry if interrupted. */
+    if (retval <= 0) {
+        return retval < 0 ? -errno : 0;         /* Error, or EOF. */
+    }
+
+    for (p = CMSG_FIRSTHDR(&msg); p; p = CMSG_NXTHDR(&msg, p)) {
+        if (p->cmsg_level != SOL_SOCKET || p->cmsg_type != SCM_RIGHTS) {
+            VLOG_ERR("unexpected control message %d:%d",
+                     p->cmsg_level, p->cmsg_type);
+            goto error;
+        } else if (*n_fdsp) {
+            VLOG_ERR("multiple SCM_RIGHTS received");
+            goto error;     /* NOTE(review): fds in this message stay open. */
+        } else {
+            size_t n_fds = (p->cmsg_len - CMSG_LEN(0)) / sizeof *fds;
+            const int *fds_data = (const int *) CMSG_DATA(p);
+
+            assert(n_fds > 0);
+            if (n_fds > SOUTIL_MAX_FDS) {
+                VLOG_ERR("%zu fds received but only %d supported",
+                         n_fds, SOUTIL_MAX_FDS);
+                for (i = 0; i < n_fds; i++) {
+                    close(fds_data[i]);         /* Never copied to 'fds'. */
+                }
+                goto error;
+            }
+
+            *n_fdsp = n_fds;
+            memcpy(fds, fds_data, n_fds * sizeof *fds);
+        }
+    }
+
+    return retval;
+
+error:
+    for (i = 0; i < *n_fdsp; i++) {     /* Close the fds already accepted. */
+        close(fds[i]);
+    }
+    *n_fdsp = 0;
+    return -EPROTO;     /* Negative, so it cannot be read as a byte count. */
+}
diff --git a/lib/socket-util.h b/lib/socket-util.h
index e2e0d9a2..a0e7970a 100644
--- a/lib/socket-util.h
+++ b/lib/socket-util.h
@@ -71,4 +71,28 @@ char *describe_fd(int fd);
  * in 
<netinet/ip.h> is used. */
 #define DSCP_DEFAULT (IPTOS_PREC_INTERNETCONTROL >> 2)
 
+/* Maximum number of fds that send_iovec_and_fds() or recv_data_and_fds() can
+ * pass across a Unix domain socket in one call. */
+#define SOUTIL_MAX_FDS 8
+
+/* Helpers for arrays of 'struct iovec'. */
+size_t iovec_len(const struct iovec *iovs, size_t n_iovs);
+bool iovec_is_empty(const struct iovec *iovs, size_t n_iovs);
+
+/* Functions particularly useful for Unix domain sockets. */
+void xsocketpair(int domain, int type, int protocol, int fds[2]);
+int send_iovec_and_fds(int sock,           /* Bytes sent, or -1 and errno. */
+                       const struct iovec *iovs, size_t n_iovs,
+                       const int fds[], size_t n_fds);
+int send_iovec_and_fds_fully(int sock,     /* 0, or a positive errno value. */
+                             const struct iovec *iovs, size_t n_iovs,
+                             const int fds[], size_t n_fds,
+                             size_t skip_bytes, size_t *bytes_sent);
+int send_iovec_and_fds_fully_block(int sock,       /* 0, or positive errno. */
+                                   const struct iovec *iovs, size_t n_iovs,
+                                   const int fds[], size_t n_fds);
+int recv_data_and_fds(int sock,    /* Bytes, 0 on EOF, or negative errno. */
+                      void *data, size_t size,
+                      int fds[SOUTIL_MAX_FDS], size_t *n_fdsp);
+
 #endif /* socket-util.h */