mirror of
https://gitlab.com/libvirt/libvirt.git
synced 2025-03-20 07:59:00 +00:00
Fix sending/receiving of FDs when stream returns EAGAIN
The code calling sendfd/recvfd was mistakenly assuming those calls would never block. They can in fact return EAGAIN and this is causing us to drop the client connection when blocking ocurrs while sending/receiving FDs. Fixing this is a little hairy on the incoming side, since at the point where we see the EAGAIN, we already thought we had finished receiving all data for the packet. So we play a little trick to reset bufferOffset again and go back into polling for more data. * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Update virNetSocketSendFD/RecvFD to return 0 on EAGAIN, or 1 on success * src/rpc/virnetclient.c: Move decoding of header & fds out of virNetClientCallDispatch and into virNetClientIOHandleInput. Handling blocking when sending/receiving FDs * src/rpc/virnetmessage.h: Add a 'donefds' field to track how many FDs we've sent / received * src/rpc/virnetserverclient.c: Handling blocking when sending/receiving FDs
This commit is contained in:
parent
4d970fd293
commit
b2c6231647
@ -694,10 +694,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
|
||||
static int
|
||||
virNetClientCallDispatch(virNetClientPtr client)
|
||||
{
|
||||
size_t i;
|
||||
if (virNetMessageDecodeHeader(&client->msg) < 0)
|
||||
return -1;
|
||||
|
||||
PROBE(RPC_CLIENT_MSG_RX,
|
||||
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
|
||||
client, client->msg.bufferLength,
|
||||
@ -706,15 +702,7 @@ virNetClientCallDispatch(virNetClientPtr client)
|
||||
|
||||
switch (client->msg.header.type) {
|
||||
case VIR_NET_REPLY: /* Normal RPC replies */
|
||||
return virNetClientCallDispatchReply(client);
|
||||
|
||||
case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
|
||||
if (virNetMessageDecodeNumFDs(&client->msg) < 0)
|
||||
return -1;
|
||||
for (i = 0 ; i < client->msg.nfds ; i++) {
|
||||
if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
|
||||
return -1;
|
||||
}
|
||||
return virNetClientCallDispatchReply(client);
|
||||
|
||||
case VIR_NET_MESSAGE: /* Async notifications */
|
||||
@ -737,22 +725,29 @@ static ssize_t
|
||||
virNetClientIOWriteMessage(virNetClientPtr client,
|
||||
virNetClientCallPtr thecall)
|
||||
{
|
||||
ssize_t ret;
|
||||
ssize_t ret = 0;
|
||||
|
||||
ret = virNetSocketWrite(client->sock,
|
||||
thecall->msg->buffer + thecall->msg->bufferOffset,
|
||||
thecall->msg->bufferLength - thecall->msg->bufferOffset);
|
||||
if (ret <= 0)
|
||||
return ret;
|
||||
if (thecall->msg->bufferOffset < thecall->msg->bufferLength) {
|
||||
ret = virNetSocketWrite(client->sock,
|
||||
thecall->msg->buffer + thecall->msg->bufferOffset,
|
||||
thecall->msg->bufferLength - thecall->msg->bufferOffset);
|
||||
if (ret <= 0)
|
||||
return ret;
|
||||
|
||||
thecall->msg->bufferOffset += ret;
|
||||
thecall->msg->bufferOffset += ret;
|
||||
}
|
||||
|
||||
if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
|
||||
size_t i;
|
||||
for (i = 0 ; i < thecall->msg->nfds ; i++) {
|
||||
if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0)
|
||||
for (i = thecall->msg->donefds ; i < thecall->msg->nfds ; i++) {
|
||||
int rv;
|
||||
if ((rv = virNetSocketSendFD(client->sock, thecall->msg->fds[i])) < 0)
|
||||
return -1;
|
||||
if (rv == 0) /* Blocking */
|
||||
return 0;
|
||||
thecall->msg->donefds++;
|
||||
}
|
||||
thecall->msg->donefds = 0;
|
||||
thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
|
||||
if (thecall->expectReply)
|
||||
thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
|
||||
@ -821,12 +816,16 @@ virNetClientIOHandleInput(virNetClientPtr client)
|
||||
* EAGAIN
|
||||
*/
|
||||
for (;;) {
|
||||
ssize_t ret = virNetClientIOReadMessage(client);
|
||||
ssize_t ret;
|
||||
|
||||
if (ret < 0)
|
||||
return -1;
|
||||
if (ret == 0)
|
||||
return 0; /* Blocking on read */
|
||||
if (client->msg.nfds == 0) {
|
||||
ret = virNetClientIOReadMessage(client);
|
||||
|
||||
if (ret < 0)
|
||||
return -1;
|
||||
if (ret == 0)
|
||||
return 0; /* Blocking on read */
|
||||
}
|
||||
|
||||
/* Check for completion of our goal */
|
||||
if (client->msg.bufferOffset == client->msg.bufferLength) {
|
||||
@ -842,6 +841,33 @@ virNetClientIOHandleInput(virNetClientPtr client)
|
||||
* next iteration.
|
||||
*/
|
||||
} else {
|
||||
if (virNetMessageDecodeHeader(&client->msg) < 0)
|
||||
return -1;
|
||||
|
||||
if (client->msg.header.type == VIR_NET_REPLY_WITH_FDS) {
|
||||
size_t i;
|
||||
if (virNetMessageDecodeNumFDs(&client->msg) < 0)
|
||||
return -1;
|
||||
|
||||
for (i = client->msg.donefds ; i < client->msg.nfds ; i++) {
|
||||
int rv;
|
||||
if ((rv = virNetSocketRecvFD(client->sock, &(client->msg.fds[i]))) < 0)
|
||||
return -1;
|
||||
if (rv == 0) /* Blocking */
|
||||
break;
|
||||
client->msg.donefds++;
|
||||
}
|
||||
|
||||
if (client->msg.donefds < client->msg.nfds) {
|
||||
/* Because DecodeHeader/NumFDs reset bufferOffset, we
|
||||
* put it back to what it was, so everything works
|
||||
* again next time we run this method
|
||||
*/
|
||||
client->msg.bufferOffset = client->msg.bufferLength;
|
||||
return 0; /* Blocking on more fds */
|
||||
}
|
||||
}
|
||||
|
||||
ret = virNetClientCallDispatch(client);
|
||||
client->msg.bufferOffset = client->msg.bufferLength = 0;
|
||||
/*
|
||||
@ -1257,6 +1283,7 @@ int virNetClientSend(virNetClientPtr client,
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
msg->donefds = 0;
|
||||
if (msg->bufferLength)
|
||||
call->mode = VIR_NET_CLIENT_MODE_WAIT_TX;
|
||||
else
|
||||
|
@ -48,6 +48,7 @@ struct _virNetMessage {
|
||||
|
||||
size_t nfds;
|
||||
int *fds;
|
||||
size_t donefds;
|
||||
|
||||
virNetMessagePtr next;
|
||||
};
|
||||
|
@ -771,9 +771,11 @@ static ssize_t virNetServerClientRead(virNetServerClientPtr client)
|
||||
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
|
||||
{
|
||||
readmore:
|
||||
if (virNetServerClientRead(client) < 0) {
|
||||
client->wantClose = true;
|
||||
return; /* Error */
|
||||
if (client->rx->nfds == 0) {
|
||||
if (virNetServerClientRead(client) < 0) {
|
||||
client->wantClose = true;
|
||||
return; /* Error */
|
||||
}
|
||||
}
|
||||
|
||||
if (client->rx->bufferOffset < client->rx->bufferLength)
|
||||
@ -794,7 +796,7 @@ readmore:
|
||||
goto readmore;
|
||||
} else {
|
||||
/* Grab the completed message */
|
||||
virNetMessagePtr msg = virNetMessageQueueServe(&client->rx);
|
||||
virNetMessagePtr msg = client->rx;
|
||||
virNetServerClientFilterPtr filter;
|
||||
size_t i;
|
||||
|
||||
@ -805,20 +807,40 @@ readmore:
|
||||
return;
|
||||
}
|
||||
|
||||
/* Now figure out if we need to read more data to get some
|
||||
* file descriptors */
|
||||
if (msg->header.type == VIR_NET_CALL_WITH_FDS &&
|
||||
virNetMessageDecodeNumFDs(msg) < 0) {
|
||||
virNetMessageFree(msg);
|
||||
client->wantClose = true;
|
||||
return;
|
||||
return; /* Error */
|
||||
}
|
||||
for (i = 0 ; i < msg->nfds ; i++) {
|
||||
if ((msg->fds[i] = virNetSocketRecvFD(client->sock)) < 0) {
|
||||
|
||||
/* Try getting the file descriptors (may fail if blocking) */
|
||||
for (i = msg->donefds ; i < msg->nfds ; i++) {
|
||||
int rv;
|
||||
if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
|
||||
virNetMessageFree(msg);
|
||||
client->wantClose = true;
|
||||
return;
|
||||
}
|
||||
if (rv == 0) /* Blocking */
|
||||
break;
|
||||
msg->donefds++;
|
||||
}
|
||||
|
||||
/* Need to poll() until FDs arrive */
|
||||
if (msg->donefds < msg->nfds) {
|
||||
/* Because DecodeHeader/NumFDs reset bufferOffset, we
|
||||
* put it back to what it was, so everything works
|
||||
* again next time we run this method
|
||||
*/
|
||||
client->rx->bufferOffset = client->rx->bufferLength;
|
||||
return;
|
||||
}
|
||||
|
||||
/* Definitely finished reading, so remove from queue */
|
||||
virNetMessageQueueServe(&client->rx);
|
||||
PROBE(RPC_SERVER_CLIENT_MSG_RX,
|
||||
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
|
||||
client, msg->bufferLength,
|
||||
@ -912,25 +934,30 @@ static void
|
||||
virNetServerClientDispatchWrite(virNetServerClientPtr client)
|
||||
{
|
||||
while (client->tx) {
|
||||
ssize_t ret;
|
||||
|
||||
ret = virNetServerClientWrite(client);
|
||||
if (ret < 0) {
|
||||
client->wantClose = true;
|
||||
return;
|
||||
if (client->tx->bufferOffset < client->tx->bufferLength) {
|
||||
ssize_t ret;
|
||||
ret = virNetServerClientWrite(client);
|
||||
if (ret < 0) {
|
||||
client->wantClose = true;
|
||||
return;
|
||||
}
|
||||
if (ret == 0)
|
||||
return; /* Would block on write EAGAIN */
|
||||
}
|
||||
if (ret == 0)
|
||||
return; /* Would block on write EAGAIN */
|
||||
|
||||
if (client->tx->bufferOffset == client->tx->bufferLength) {
|
||||
virNetMessagePtr msg;
|
||||
size_t i;
|
||||
|
||||
for (i = 0 ; i < client->tx->nfds ; i++) {
|
||||
if (virNetSocketSendFD(client->sock, client->tx->fds[i]) < 0) {
|
||||
for (i = client->tx->donefds ; i < client->tx->nfds ; i++) {
|
||||
int rv;
|
||||
if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
|
||||
client->wantClose = true;
|
||||
return;
|
||||
}
|
||||
if (rv == 0) /* Blocking */
|
||||
return;
|
||||
client->tx->donefds++;
|
||||
}
|
||||
|
||||
#if HAVE_SASL
|
||||
@ -1041,6 +1068,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client,
|
||||
msg->bufferLength, msg->bufferOffset);
|
||||
virNetServerClientLock(client);
|
||||
|
||||
msg->donefds = 0;
|
||||
if (client->sock && !client->wantClose) {
|
||||
PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
|
||||
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
|
||||
|
@ -1142,6 +1142,9 @@ ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Returns 1 if an FD was sent, 0 if it would block, -1 on error
|
||||
*/
|
||||
int virNetSocketSendFD(virNetSocketPtr sock, int fd)
|
||||
{
|
||||
int ret = -1;
|
||||
@ -1154,12 +1157,15 @@ int virNetSocketSendFD(virNetSocketPtr sock, int fd)
|
||||
PROBE(RPC_SOCKET_SEND_FD,
|
||||
"sock=%p fd=%d", sock, fd);
|
||||
if (sendfd(sock->fd, fd) < 0) {
|
||||
virReportSystemError(errno,
|
||||
_("Failed to send file descriptor %d"),
|
||||
fd);
|
||||
if (errno == EAGAIN)
|
||||
ret = 0;
|
||||
else
|
||||
virReportSystemError(errno,
|
||||
_("Failed to send file descriptor %d"),
|
||||
fd);
|
||||
goto cleanup;
|
||||
}
|
||||
ret = 0;
|
||||
ret = 1;
|
||||
|
||||
cleanup:
|
||||
virMutexUnlock(&sock->lock);
|
||||
@ -1167,9 +1173,15 @@ cleanup:
|
||||
}
|
||||
|
||||
|
||||
int virNetSocketRecvFD(virNetSocketPtr sock)
|
||||
/*
|
||||
* Returns 1 if an FD was read, 0 if it would block, -1 on error
|
||||
*/
|
||||
int virNetSocketRecvFD(virNetSocketPtr sock, int *fd)
|
||||
{
|
||||
int ret = -1;
|
||||
|
||||
*fd = -1;
|
||||
|
||||
if (!virNetSocketHasPassFD(sock)) {
|
||||
virNetError(VIR_ERR_INTERNAL_ERROR,
|
||||
_("Receiving file descriptors is not supported on this socket"));
|
||||
@ -1177,13 +1189,17 @@ int virNetSocketRecvFD(virNetSocketPtr sock)
|
||||
}
|
||||
virMutexLock(&sock->lock);
|
||||
|
||||
if ((ret = recvfd(sock->fd, O_CLOEXEC)) < 0) {
|
||||
virReportSystemError(errno, "%s",
|
||||
_("Failed to recv file descriptor"));
|
||||
if ((*fd = recvfd(sock->fd, O_CLOEXEC)) < 0) {
|
||||
if (errno == EAGAIN)
|
||||
ret = 0;
|
||||
else
|
||||
virReportSystemError(errno, "%s",
|
||||
_("Failed to recv file descriptor"));
|
||||
goto cleanup;
|
||||
}
|
||||
PROBE(RPC_SOCKET_RECV_FD,
|
||||
"sock=%p fd=%d", sock, ret);
|
||||
"sock=%p fd=%d", sock, *fd);
|
||||
ret = 1;
|
||||
|
||||
cleanup:
|
||||
virMutexUnlock(&sock->lock);
|
||||
|
@ -97,7 +97,7 @@ ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
|
||||
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
|
||||
|
||||
int virNetSocketSendFD(virNetSocketPtr sock, int fd);
|
||||
int virNetSocketRecvFD(virNetSocketPtr sock);
|
||||
int virNetSocketRecvFD(virNetSocketPtr sock, int *fd);
|
||||
|
||||
void virNetSocketSetTLSSession(virNetSocketPtr sock,
|
||||
virNetTLSSessionPtr sess);
|
||||
|
Loading…
x
Reference in New Issue
Block a user