diff --git a/udp.c b/udp.c
index 2178082..097e263 100644
--- a/udp.c
+++ b/udp.c
@@ -644,9 +644,11 @@ static void udp_sock_handler_splice(struct ctx *c, union epoll_ref ref,
 void udp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
 		      struct timespec *now)
 {
-	int i, iov_in_msg, msg_i = 0;
+	int iov_in_msg, msg_i = 0, ret;
+	ssize_t n, msglen, missing = 0;
+	struct mmsghdr *tap_mmh;
 	struct msghdr *cur_mh;
-	ssize_t n, msglen;
+	unsigned int i;
 
 	if (events == EPOLLERR)
 		return;
@@ -664,7 +666,7 @@ void udp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
 		cur_mh = &udp6_l2_mh_tap[msg_i].msg_hdr;
 		cur_mh->msg_iov = &udp6_l2_iov_tap[0];
 		msg_i = msglen = iov_in_msg = 0;
-		/* TODO: Explicit AVX2 vectorisation of this loop */
+
 		for (i = 0; i < n; i++) {
 			struct udp6_l2_buf_t *b = &udp6_l2_buf[i];
 			size_t ip_len, iov_len;
@@ -725,7 +727,7 @@ void udp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
 			udp6_l2_iov_tap[i].iov_len = iov_len;
 
 			/* With bigger messages, qemu closes the connection. */
-			if (iov_in_msg && msglen + iov_len > SHRT_MAX) {
+			if (iov_in_msg && msglen + iov_len > USHRT_MAX) {
 				cur_mh->msg_iovlen = iov_in_msg;
 
 				cur_mh = &udp6_l2_mh_tap[++msg_i].msg_hdr;
@@ -737,14 +739,7 @@ void udp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
 			iov_in_msg++;
 		}
 
-		if (c->mode == MODE_PASTA)
-			return;
-
-		cur_mh->msg_iovlen = iov_in_msg;
-
-		sendmmsg(c->fd_tap, udp6_l2_mh_tap, msg_i + 1,
-			 MSG_NOSIGNAL | MSG_DONTWAIT);
-		pcapmm(udp6_l2_mh_tap, msg_i + 1);
+		tap_mmh = udp6_l2_mh_tap;
 	} else {
 		n = recvmmsg(ref.s, udp4_l2_mh_sock, UDP_TAP_FRAMES, 0, NULL);
 		if (n <= 0)
@@ -753,7 +748,7 @@ void udp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
 		cur_mh = &udp4_l2_mh_tap[msg_i].msg_hdr;
 		cur_mh->msg_iov = &udp4_l2_iov_tap[0];
 		msg_i = msglen = iov_in_msg = 0;
-		/* TODO: Explicit AVX2 vectorisation of this loop */
+
 		for (i = 0; i < n; i++) {
 			struct udp4_l2_buf_t *b = &udp4_l2_buf[i];
 			size_t ip_len, iov_len;
@@ -801,7 +796,7 @@ void udp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
 			udp4_l2_iov_tap[i].iov_len = iov_len;
 
 			/* With bigger messages, qemu closes the connection. */
-			if (iov_in_msg && msglen + iov_len > SHRT_MAX) {
+			if (iov_in_msg && msglen + iov_len > USHRT_MAX) {
 				cur_mh->msg_iovlen = iov_in_msg;
 
 				cur_mh = &udp4_l2_mh_tap[++msg_i].msg_hdr;
@@ -813,15 +808,63 @@ void udp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
 			iov_in_msg++;
 		}
 
-		if (c->mode == MODE_PASTA)
-			return;
-
-		cur_mh->msg_iovlen = iov_in_msg;
-
-		sendmmsg(c->fd_tap, udp4_l2_mh_tap, msg_i + 1,
-			 MSG_NOSIGNAL | MSG_DONTWAIT);
-		pcapmm(udp4_l2_mh_tap, msg_i + 1);
+		tap_mmh = udp4_l2_mh_tap;
 	}
+
+	if (c->mode == MODE_PASTA)
+		return;
+
+	cur_mh->msg_iovlen = iov_in_msg;
+	ret = sendmmsg(c->fd_tap, tap_mmh, msg_i + 1,
+		       MSG_NOSIGNAL | MSG_DONTWAIT);
+	if (ret <= 0)
+		return;
+
+	/* If we lose some messages to sendmmsg() here, fine, it's UDP. However,
+	 * the last message needs to be delivered completely, otherwise qemu
+	 * will fail to reassemble the next message and close the connection. Go
+	 * through headers from the last sent message, counting bytes, and, if
+	 * and as soon as we see more bytes than sendmmsg() sent, re-send the
+	 * rest with a blocking call.
+	 *
+	 * In pictures, given this example:
+	 *
+	 *                                     iov #0  iov #1  iov #2  iov #3
+	 * tap_mmh[ret - 1].msg_hdr:           ....    ......  .....   ......
+	 * tap_mmh[ret - 1].msg_len:     7     ....    ...
+	 *
+	 * when 'msglen' reaches:        10                 ^
+	 * and 'missing' below is:       3                ---
+	 *
+	 * re-send everything from here:                  ^--  -----   ------
+	 */
+	cur_mh = &tap_mmh[ret - 1].msg_hdr;
+	for (i = 0, msglen = 0; i < cur_mh->msg_iovlen; i++) {
+		if (missing <= 0) {
+			msglen += cur_mh->msg_iov[i].iov_len;
+			missing = msglen - tap_mmh[ret - 1].msg_len;
+		}
+
+		/* Only resend if this iov was cut short by sendmmsg() */
+		if (missing > 0) {
+			uint8_t **iov_base;
+			int first_offset;
+
+			iov_base = (uint8_t **)&cur_mh->msg_iov[i].iov_base;
+			first_offset = cur_mh->msg_iov[i].iov_len - missing;
+			*iov_base += first_offset;
+			cur_mh->msg_iov[i].iov_len = missing;
+
+			/* Skip the i iovs already sent in full */
+			cur_mh->msg_iov = &cur_mh->msg_iov[i];
+			cur_mh->msg_iovlen -= i;
+
+			sendmsg(c->fd_tap, cur_mh, MSG_NOSIGNAL);
+			break;
+		}
+	}
+
+	pcapmm(tap_mmh, ret);
 }
 
 /**