Ensure client is marked for close in all error paths

Currently if the keepalive timer triggers, the 'markClose'
flag is set on the virNetClient. A controlled shutdown will
then be performed. If an I/O error occurs during read or
write of the connection an error is raised back to the
caller, but the connection isn't marked for close. This
patch ensures that all I/O error scenarios always result
in the connection being marked for close.

Signed-off-by: Daniel P. Berrange <berrange@redhat.com>
This commit is contained in:
Daniel P. Berrange 2012-07-19 11:21:54 +01:00
parent 6ed5a1b9bd
commit e5a1bee07a

View File

@ -101,6 +101,7 @@ struct _virNetClient {
virKeepAlivePtr keepalive;
bool wantClose;
int closeReason;
};
@ -108,6 +109,8 @@ static void virNetClientIOEventLoopPassTheBuck(virNetClientPtr client,
virNetClientCallPtr thiscall);
static int virNetClientQueueNonBlocking(virNetClientPtr client,
virNetMessagePtr msg);
static void virNetClientCloseInternal(virNetClientPtr client,
int reason);
static void virNetClientLock(virNetClientPtr client)
@ -261,7 +264,7 @@ virNetClientKeepAliveStop(virNetClientPtr client)
static void
virNetClientKeepAliveDeadCB(void *opaque)
{
virNetClientClose(opaque);
virNetClientCloseInternal(opaque, VIR_CONNECT_CLOSE_REASON_KEEPALIVE);
}
static int
@ -483,17 +486,27 @@ void virNetClientFree(virNetClientPtr client)
}
static void
virNetClientMarkClose(virNetClientPtr client,
int reason)
{
VIR_DEBUG("client=%p, reason=%d", client, reason);
virNetSocketRemoveIOCallback(client->sock);
client->wantClose = true;
client->closeReason = reason;
}
static void
virNetClientCloseLocked(virNetClientPtr client)
{
virKeepAlivePtr ka;
VIR_DEBUG("client=%p, sock=%p", client, client->sock);
VIR_DEBUG("client=%p, sock=%p, reason=%d", client, client->sock, client->closeReason);
if (!client->sock)
return;
virNetSocketRemoveIOCallback(client->sock);
virNetSocketFree(client->sock);
client->sock = NULL;
virNetTLSSessionFree(client->tls);
@ -518,16 +531,21 @@ virNetClientCloseLocked(virNetClientPtr client)
}
}
void virNetClientClose(virNetClientPtr client)
static void virNetClientCloseInternal(virNetClientPtr client,
int reason)
{
VIR_DEBUG("client=%p", client);
if (!client)
return;
if (!client->sock ||
client->wantClose)
return;
virNetClientLock(client);
client->wantClose = true;
virNetClientMarkClose(client, reason);
/* If there is a thread polling for data on the socket, wake the thread up
* otherwise try to pass the buck to a possibly waiting thread. If no
@ -548,6 +566,12 @@ void virNetClientClose(virNetClientPtr client)
}
void virNetClientClose(virNetClientPtr client)
{
virNetClientCloseInternal(client, VIR_CONNECT_CLOSE_REASON_CLIENT);
}
#if HAVE_SASL
void virNetClientSetSASLSession(virNetClientPtr client,
virNetSASLSessionPtr sasl)
@ -1351,7 +1375,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
}
if (virKeepAliveTrigger(client->keepalive, &msg)) {
client->wantClose = true;
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_KEEPALIVE);
} else if (msg && virNetClientQueueNonBlocking(client, msg) < 0) {
VIR_WARN("Could not queue keepalive request");
virNetMessageFree(msg);
@ -1374,18 +1398,23 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) {
virReportSystemError(errno, "%s",
_("read on wakeup fd failed"));
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
goto error;
}
}
if (fds[0].revents & POLLOUT) {
if (virNetClientIOHandleOutput(client) < 0)
if (virNetClientIOHandleOutput(client) < 0) {
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
goto error;
}
}
if (fds[0].revents & POLLIN) {
if (virNetClientIOHandleInput(client) < 0)
if (virNetClientIOHandleInput(client) < 0) {
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
goto error;
}
}
/* Iterate through waiting calls and if any are
@ -1410,6 +1439,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
}
if (fds[0].revents & (POLLHUP | POLLERR)) {
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_EOF);
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
_("received hangup / error event on socket"));
goto error;
@ -1441,6 +1471,9 @@ static void virNetClientIOUpdateCallback(virNetClientPtr client,
{
int events = 0;
if (client->wantClose)
return;
if (enableCallback) {
events |= VIR_EVENT_HANDLE_READABLE;
virNetClientCallMatchPredicate(client->waitDispatch,
@ -1623,6 +1656,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
virNetClientLock(client);
VIR_DEBUG("client=%p wantclose=%d", client, client ? client->wantClose : false);
if (!client->sock)
goto done;
@ -1635,18 +1670,21 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) {
VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or "
"VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__);
virNetSocketRemoveIOCallback(sock);
virNetClientMarkClose(client,
(events & VIR_EVENT_HANDLE_HANGUP) ?
VIR_CONNECT_CLOSE_REASON_EOF :
VIR_CONNECT_CLOSE_REASON_ERROR);
goto done;
}
if (events & VIR_EVENT_HANDLE_WRITABLE) {
if (virNetClientIOHandleOutput(client) < 0)
virNetSocketRemoveIOCallback(sock);
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
}
if (events & VIR_EVENT_HANDLE_READABLE) {
if (virNetClientIOHandleInput(client) < 0)
virNetSocketRemoveIOCallback(sock);
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
}
/* Remove completed calls or signal their threads. */
@ -1656,6 +1694,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
virNetClientIOUpdateCallback(client, true);
done:
if (client->wantClose)
virNetClientCloseLocked(client);
virNetClientUnlock(client);
}