Ensure client is marked for close in all error paths

Currently if the keepalive timer triggers, the 'markClose'
flag is set on the virNetClient. A controlled shutdown will
then be performed. If an I/O error occurs during read or
write of the connection an error is raised back to the
caller, but the connection isn't marked for close. This
patch ensures that all I/O error scenarios always result
in the connection being marked for close.

Signed-off-by: Daniel P. Berrange <berrange@redhat.com>
This commit is contained in:
Daniel P. Berrange 2012-07-19 11:21:54 +01:00
parent 6ed5a1b9bd
commit e5a1bee07a

View File

@ -101,6 +101,7 @@ struct _virNetClient {
virKeepAlivePtr keepalive; virKeepAlivePtr keepalive;
bool wantClose; bool wantClose;
int closeReason;
}; };
@ -108,6 +109,8 @@ static void virNetClientIOEventLoopPassTheBuck(virNetClientPtr client,
virNetClientCallPtr thiscall); virNetClientCallPtr thiscall);
static int virNetClientQueueNonBlocking(virNetClientPtr client, static int virNetClientQueueNonBlocking(virNetClientPtr client,
virNetMessagePtr msg); virNetMessagePtr msg);
static void virNetClientCloseInternal(virNetClientPtr client,
int reason);
static void virNetClientLock(virNetClientPtr client) static void virNetClientLock(virNetClientPtr client)
@ -261,7 +264,7 @@ virNetClientKeepAliveStop(virNetClientPtr client)
static void static void
virNetClientKeepAliveDeadCB(void *opaque) virNetClientKeepAliveDeadCB(void *opaque)
{ {
virNetClientClose(opaque); virNetClientCloseInternal(opaque, VIR_CONNECT_CLOSE_REASON_KEEPALIVE);
} }
static int static int
@ -483,17 +486,27 @@ void virNetClientFree(virNetClientPtr client)
} }
static void
virNetClientMarkClose(virNetClientPtr client,
int reason)
{
VIR_DEBUG("client=%p, reason=%d", client, reason);
virNetSocketRemoveIOCallback(client->sock);
client->wantClose = true;
client->closeReason = reason;
}
static void static void
virNetClientCloseLocked(virNetClientPtr client) virNetClientCloseLocked(virNetClientPtr client)
{ {
virKeepAlivePtr ka; virKeepAlivePtr ka;
VIR_DEBUG("client=%p, sock=%p", client, client->sock); VIR_DEBUG("client=%p, sock=%p, reason=%d", client, client->sock, client->closeReason);
if (!client->sock) if (!client->sock)
return; return;
virNetSocketRemoveIOCallback(client->sock);
virNetSocketFree(client->sock); virNetSocketFree(client->sock);
client->sock = NULL; client->sock = NULL;
virNetTLSSessionFree(client->tls); virNetTLSSessionFree(client->tls);
@ -518,16 +531,21 @@ virNetClientCloseLocked(virNetClientPtr client)
} }
} }
void virNetClientClose(virNetClientPtr client) static void virNetClientCloseInternal(virNetClientPtr client,
int reason)
{ {
VIR_DEBUG("client=%p", client); VIR_DEBUG("client=%p", client);
if (!client) if (!client)
return; return;
if (!client->sock ||
client->wantClose)
return;
virNetClientLock(client); virNetClientLock(client);
client->wantClose = true; virNetClientMarkClose(client, reason);
/* If there is a thread polling for data on the socket, wake the thread up /* If there is a thread polling for data on the socket, wake the thread up
* otherwise try to pass the buck to a possibly waiting thread. If no * otherwise try to pass the buck to a possibly waiting thread. If no
@ -548,6 +566,12 @@ void virNetClientClose(virNetClientPtr client)
} }
void virNetClientClose(virNetClientPtr client)
{
virNetClientCloseInternal(client, VIR_CONNECT_CLOSE_REASON_CLIENT);
}
#if HAVE_SASL #if HAVE_SASL
void virNetClientSetSASLSession(virNetClientPtr client, void virNetClientSetSASLSession(virNetClientPtr client,
virNetSASLSessionPtr sasl) virNetSASLSessionPtr sasl)
@ -1351,7 +1375,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
} }
if (virKeepAliveTrigger(client->keepalive, &msg)) { if (virKeepAliveTrigger(client->keepalive, &msg)) {
client->wantClose = true; virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_KEEPALIVE);
} else if (msg && virNetClientQueueNonBlocking(client, msg) < 0) { } else if (msg && virNetClientQueueNonBlocking(client, msg) < 0) {
VIR_WARN("Could not queue keepalive request"); VIR_WARN("Could not queue keepalive request");
virNetMessageFree(msg); virNetMessageFree(msg);
@ -1374,18 +1398,23 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) { if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) {
virReportSystemError(errno, "%s", virReportSystemError(errno, "%s",
_("read on wakeup fd failed")); _("read on wakeup fd failed"));
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
goto error; goto error;
} }
} }
if (fds[0].revents & POLLOUT) { if (fds[0].revents & POLLOUT) {
if (virNetClientIOHandleOutput(client) < 0) if (virNetClientIOHandleOutput(client) < 0) {
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
goto error; goto error;
}
} }
if (fds[0].revents & POLLIN) { if (fds[0].revents & POLLIN) {
if (virNetClientIOHandleInput(client) < 0) if (virNetClientIOHandleInput(client) < 0) {
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
goto error; goto error;
}
} }
/* Iterate through waiting calls and if any are /* Iterate through waiting calls and if any are
@ -1410,6 +1439,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
} }
if (fds[0].revents & (POLLHUP | POLLERR)) { if (fds[0].revents & (POLLHUP | POLLERR)) {
virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_EOF);
virReportError(VIR_ERR_INTERNAL_ERROR, "%s", virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
_("received hangup / error event on socket")); _("received hangup / error event on socket"));
goto error; goto error;
@ -1441,6 +1471,9 @@ static void virNetClientIOUpdateCallback(virNetClientPtr client,
{ {
int events = 0; int events = 0;
if (client->wantClose)
return;
if (enableCallback) { if (enableCallback) {
events |= VIR_EVENT_HANDLE_READABLE; events |= VIR_EVENT_HANDLE_READABLE;
virNetClientCallMatchPredicate(client->waitDispatch, virNetClientCallMatchPredicate(client->waitDispatch,
@ -1623,6 +1656,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
virNetClientLock(client); virNetClientLock(client);
VIR_DEBUG("client=%p wantclose=%d", client, client ? client->wantClose : false);
if (!client->sock) if (!client->sock)
goto done; goto done;
@ -1635,18 +1670,21 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) { if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) {
VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or " VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or "
"VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__); "VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__);
virNetSocketRemoveIOCallback(sock); virNetClientMarkClose(client,
(events & VIR_EVENT_HANDLE_HANGUP) ?
VIR_CONNECT_CLOSE_REASON_EOF :
VIR_CONNECT_CLOSE_REASON_ERROR);
goto done; goto done;
} }
if (events & VIR_EVENT_HANDLE_WRITABLE) { if (events & VIR_EVENT_HANDLE_WRITABLE) {
if (virNetClientIOHandleOutput(client) < 0) if (virNetClientIOHandleOutput(client) < 0)
virNetSocketRemoveIOCallback(sock); virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
} }
if (events & VIR_EVENT_HANDLE_READABLE) { if (events & VIR_EVENT_HANDLE_READABLE) {
if (virNetClientIOHandleInput(client) < 0) if (virNetClientIOHandleInput(client) < 0)
virNetSocketRemoveIOCallback(sock); virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
} }
/* Remove completed calls or signal their threads. */ /* Remove completed calls or signal their threads. */
@ -1656,6 +1694,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
virNetClientIOUpdateCallback(client, true); virNetClientIOUpdateCallback(client, true);
done: done:
if (client->wantClose)
virNetClientCloseLocked(client);
virNetClientUnlock(client); virNetClientUnlock(client);
} }