Fix sending/receiving of FDs when stream returns EAGAIN

The code calling sendfd/recvfd was mistakenly assuming those
calls would never block. They can in fact return EAGAIN and
this is causing us to drop the client connection when blocking
ocurrs while sending/receiving FDs.

Fixing this is a little hairy on the incoming side, since at
the point where we see the EAGAIN, we already thought we had
finished receiving all data for the packet. So we play a little
trick to reset bufferOffset again and go back into polling for
more data.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Update
  virNetSocketSendFD/RecvFD to return 0 on EAGAIN, or 1
  on success
* src/rpc/virnetclient.c: Move decoding of header & fds
  out of virNetClientCallDispatch and into virNetClientIOHandleInput.
  Handling blocking when sending/receiving FDs
* src/rpc/virnetmessage.h: Add a 'donefds' field to track
  how many FDs we've sent / received
* src/rpc/virnetserverclient.c: Handling blocking when
  sending/receiving FDs
This commit is contained in:
Daniel P. Berrange 2011-11-04 16:02:14 +00:00
parent 4d970fd293
commit b2c6231647
5 changed files with 125 additions and 53 deletions

View File

@ -694,10 +694,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
static int
virNetClientCallDispatch(virNetClientPtr client)
{
size_t i;
if (virNetMessageDecodeHeader(&client->msg) < 0)
return -1;
PROBE(RPC_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
client, client->msg.bufferLength,
@ -706,15 +702,7 @@ virNetClientCallDispatch(virNetClientPtr client)
switch (client->msg.header.type) {
case VIR_NET_REPLY: /* Normal RPC replies */
return virNetClientCallDispatchReply(client);
case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
if (virNetMessageDecodeNumFDs(&client->msg) < 0)
return -1;
for (i = 0 ; i < client->msg.nfds ; i++) {
if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
return -1;
}
return virNetClientCallDispatchReply(client);
case VIR_NET_MESSAGE: /* Async notifications */
@ -737,22 +725,29 @@ static ssize_t
virNetClientIOWriteMessage(virNetClientPtr client,
virNetClientCallPtr thecall)
{
ssize_t ret;
ssize_t ret = 0;
ret = virNetSocketWrite(client->sock,
thecall->msg->buffer + thecall->msg->bufferOffset,
thecall->msg->bufferLength - thecall->msg->bufferOffset);
if (ret <= 0)
return ret;
if (thecall->msg->bufferOffset < thecall->msg->bufferLength) {
ret = virNetSocketWrite(client->sock,
thecall->msg->buffer + thecall->msg->bufferOffset,
thecall->msg->bufferLength - thecall->msg->bufferOffset);
if (ret <= 0)
return ret;
thecall->msg->bufferOffset += ret;
thecall->msg->bufferOffset += ret;
}
if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
size_t i;
for (i = 0 ; i < thecall->msg->nfds ; i++) {
if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0)
for (i = thecall->msg->donefds ; i < thecall->msg->nfds ; i++) {
int rv;
if ((rv = virNetSocketSendFD(client->sock, thecall->msg->fds[i])) < 0)
return -1;
if (rv == 0) /* Blocking */
return 0;
thecall->msg->donefds++;
}
thecall->msg->donefds = 0;
thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
if (thecall->expectReply)
thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
@ -821,12 +816,16 @@ virNetClientIOHandleInput(virNetClientPtr client)
* EAGAIN
*/
for (;;) {
ssize_t ret = virNetClientIOReadMessage(client);
ssize_t ret;
if (ret < 0)
return -1;
if (ret == 0)
return 0; /* Blocking on read */
if (client->msg.nfds == 0) {
ret = virNetClientIOReadMessage(client);
if (ret < 0)
return -1;
if (ret == 0)
return 0; /* Blocking on read */
}
/* Check for completion of our goal */
if (client->msg.bufferOffset == client->msg.bufferLength) {
@ -842,6 +841,33 @@ virNetClientIOHandleInput(virNetClientPtr client)
* next iteration.
*/
} else {
if (virNetMessageDecodeHeader(&client->msg) < 0)
return -1;
if (client->msg.header.type == VIR_NET_REPLY_WITH_FDS) {
size_t i;
if (virNetMessageDecodeNumFDs(&client->msg) < 0)
return -1;
for (i = client->msg.donefds ; i < client->msg.nfds ; i++) {
int rv;
if ((rv = virNetSocketRecvFD(client->sock, &(client->msg.fds[i]))) < 0)
return -1;
if (rv == 0) /* Blocking */
break;
client->msg.donefds++;
}
if (client->msg.donefds < client->msg.nfds) {
/* Because DecodeHeader/NumFDs reset bufferOffset, we
* put it back to what it was, so everything works
* again next time we run this method
*/
client->msg.bufferOffset = client->msg.bufferLength;
return 0; /* Blocking on more fds */
}
}
ret = virNetClientCallDispatch(client);
client->msg.bufferOffset = client->msg.bufferLength = 0;
/*
@ -1257,6 +1283,7 @@ int virNetClientSend(virNetClientPtr client,
goto cleanup;
}
msg->donefds = 0;
if (msg->bufferLength)
call->mode = VIR_NET_CLIENT_MODE_WAIT_TX;
else

View File

@ -48,6 +48,7 @@ struct _virNetMessage {
size_t nfds;
int *fds;
size_t donefds;
virNetMessagePtr next;
};

View File

@ -771,9 +771,11 @@ static ssize_t virNetServerClientRead(virNetServerClientPtr client)
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{
readmore:
if (virNetServerClientRead(client) < 0) {
client->wantClose = true;
return; /* Error */
if (client->rx->nfds == 0) {
if (virNetServerClientRead(client) < 0) {
client->wantClose = true;
return; /* Error */
}
}
if (client->rx->bufferOffset < client->rx->bufferLength)
@ -794,7 +796,7 @@ readmore:
goto readmore;
} else {
/* Grab the completed message */
virNetMessagePtr msg = virNetMessageQueueServe(&client->rx);
virNetMessagePtr msg = client->rx;
virNetServerClientFilterPtr filter;
size_t i;
@ -805,20 +807,40 @@ readmore:
return;
}
/* Now figure out if we need to read more data to get some
* file descriptors */
if (msg->header.type == VIR_NET_CALL_WITH_FDS &&
virNetMessageDecodeNumFDs(msg) < 0) {
virNetMessageFree(msg);
client->wantClose = true;
return;
return; /* Error */
}
for (i = 0 ; i < msg->nfds ; i++) {
if ((msg->fds[i] = virNetSocketRecvFD(client->sock)) < 0) {
/* Try getting the file descriptors (may fail if blocking) */
for (i = msg->donefds ; i < msg->nfds ; i++) {
int rv;
if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
virNetMessageFree(msg);
client->wantClose = true;
return;
}
if (rv == 0) /* Blocking */
break;
msg->donefds++;
}
/* Need to poll() until FDs arrive */
if (msg->donefds < msg->nfds) {
/* Because DecodeHeader/NumFDs reset bufferOffset, we
* put it back to what it was, so everything works
* again next time we run this method
*/
client->rx->bufferOffset = client->rx->bufferLength;
return;
}
/* Definitely finished reading, so remove from queue */
virNetMessageQueueServe(&client->rx);
PROBE(RPC_SERVER_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
client, msg->bufferLength,
@ -912,25 +934,30 @@ static void
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
while (client->tx) {
ssize_t ret;
ret = virNetServerClientWrite(client);
if (ret < 0) {
client->wantClose = true;
return;
if (client->tx->bufferOffset < client->tx->bufferLength) {
ssize_t ret;
ret = virNetServerClientWrite(client);
if (ret < 0) {
client->wantClose = true;
return;
}
if (ret == 0)
return; /* Would block on write EAGAIN */
}
if (ret == 0)
return; /* Would block on write EAGAIN */
if (client->tx->bufferOffset == client->tx->bufferLength) {
virNetMessagePtr msg;
size_t i;
for (i = 0 ; i < client->tx->nfds ; i++) {
if (virNetSocketSendFD(client->sock, client->tx->fds[i]) < 0) {
for (i = client->tx->donefds ; i < client->tx->nfds ; i++) {
int rv;
if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
client->wantClose = true;
return;
}
if (rv == 0) /* Blocking */
return;
client->tx->donefds++;
}
#if HAVE_SASL
@ -1041,6 +1068,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client,
msg->bufferLength, msg->bufferOffset);
virNetServerClientLock(client);
msg->donefds = 0;
if (client->sock && !client->wantClose) {
PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",

View File

@ -1142,6 +1142,9 @@ ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
}
/*
* Returns 1 if an FD was sent, 0 if it would block, -1 on error
*/
int virNetSocketSendFD(virNetSocketPtr sock, int fd)
{
int ret = -1;
@ -1154,12 +1157,15 @@ int virNetSocketSendFD(virNetSocketPtr sock, int fd)
PROBE(RPC_SOCKET_SEND_FD,
"sock=%p fd=%d", sock, fd);
if (sendfd(sock->fd, fd) < 0) {
virReportSystemError(errno,
_("Failed to send file descriptor %d"),
fd);
if (errno == EAGAIN)
ret = 0;
else
virReportSystemError(errno,
_("Failed to send file descriptor %d"),
fd);
goto cleanup;
}
ret = 0;
ret = 1;
cleanup:
virMutexUnlock(&sock->lock);
@ -1167,9 +1173,15 @@ cleanup:
}
int virNetSocketRecvFD(virNetSocketPtr sock)
/*
* Returns 1 if an FD was read, 0 if it would block, -1 on error
*/
int virNetSocketRecvFD(virNetSocketPtr sock, int *fd)
{
int ret = -1;
*fd = -1;
if (!virNetSocketHasPassFD(sock)) {
virNetError(VIR_ERR_INTERNAL_ERROR,
_("Receiving file descriptors is not supported on this socket"));
@ -1177,13 +1189,17 @@ int virNetSocketRecvFD(virNetSocketPtr sock)
}
virMutexLock(&sock->lock);
if ((ret = recvfd(sock->fd, O_CLOEXEC)) < 0) {
virReportSystemError(errno, "%s",
_("Failed to recv file descriptor"));
if ((*fd = recvfd(sock->fd, O_CLOEXEC)) < 0) {
if (errno == EAGAIN)
ret = 0;
else
virReportSystemError(errno, "%s",
_("Failed to recv file descriptor"));
goto cleanup;
}
PROBE(RPC_SOCKET_RECV_FD,
"sock=%p fd=%d", sock, ret);
"sock=%p fd=%d", sock, *fd);
ret = 1;
cleanup:
virMutexUnlock(&sock->lock);

View File

@ -97,7 +97,7 @@ ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
int virNetSocketSendFD(virNetSocketPtr sock, int fd);
int virNetSocketRecvFD(virNetSocketPtr sock);
int virNetSocketRecvFD(virNetSocketPtr sock, int *fd);
void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess);