diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c index 0e7e423866..a2587462a8 100644 --- a/src/rpc/virnetclient.c +++ b/src/rpc/virnetclient.c @@ -101,6 +101,7 @@ struct _virNetClient { virKeepAlivePtr keepalive; bool wantClose; + int closeReason; }; @@ -108,6 +109,8 @@ static void virNetClientIOEventLoopPassTheBuck(virNetClientPtr client, virNetClientCallPtr thiscall); static int virNetClientQueueNonBlocking(virNetClientPtr client, virNetMessagePtr msg); +static void virNetClientCloseInternal(virNetClientPtr client, + int reason); static void virNetClientLock(virNetClientPtr client) @@ -261,7 +264,7 @@ virNetClientKeepAliveStop(virNetClientPtr client) static void virNetClientKeepAliveDeadCB(void *opaque) { - virNetClientClose(opaque); + virNetClientCloseInternal(opaque, VIR_CONNECT_CLOSE_REASON_KEEPALIVE); } static int @@ -483,17 +486,27 @@ void virNetClientFree(virNetClientPtr client) } +static void +virNetClientMarkClose(virNetClientPtr client, + int reason) +{ + VIR_DEBUG("client=%p, reason=%d", client, reason); + virNetSocketRemoveIOCallback(client->sock); + client->wantClose = true; + client->closeReason = reason; +} + + static void virNetClientCloseLocked(virNetClientPtr client) { virKeepAlivePtr ka; - VIR_DEBUG("client=%p, sock=%p", client, client->sock); + VIR_DEBUG("client=%p, sock=%p, reason=%d", client, client->sock, client->closeReason); if (!client->sock) return; - virNetSocketRemoveIOCallback(client->sock); virNetSocketFree(client->sock); client->sock = NULL; virNetTLSSessionFree(client->tls); @@ -518,16 +531,21 @@ virNetClientCloseLocked(virNetClientPtr client) } } -void virNetClientClose(virNetClientPtr client) +static void virNetClientCloseInternal(virNetClientPtr client, + int reason) { VIR_DEBUG("client=%p", client); if (!client) return; + if (!client->sock || + client->wantClose) + return; + virNetClientLock(client); - client->wantClose = true; + virNetClientMarkClose(client, reason); /* If there is a thread polling for data on the socket, wake the thread up * otherwise try to pass the buck to a possibly waiting thread. If no @@ -548,6 +566,12 @@ void virNetClientClose(virNetClientPtr client) } +void virNetClientClose(virNetClientPtr client) +{ + virNetClientCloseInternal(client, VIR_CONNECT_CLOSE_REASON_CLIENT); +} + + #if HAVE_SASL void virNetClientSetSASLSession(virNetClientPtr client, virNetSASLSessionPtr sasl) @@ -1351,7 +1375,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client, } if (virKeepAliveTrigger(client->keepalive, &msg)) { - client->wantClose = true; + virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_KEEPALIVE); } else if (msg && virNetClientQueueNonBlocking(client, msg) < 0) { VIR_WARN("Could not queue keepalive request"); virNetMessageFree(msg); @@ -1374,18 +1398,23 @@ static int virNetClientIOEventLoop(virNetClientPtr client, if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) { virReportSystemError(errno, "%s", _("read on wakeup fd failed")); + virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR); goto error; } } if (fds[0].revents & POLLOUT) { - if (virNetClientIOHandleOutput(client) < 0) + if (virNetClientIOHandleOutput(client) < 0) { + virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR); goto error; + } } if (fds[0].revents & POLLIN) { - if (virNetClientIOHandleInput(client) < 0) + if (virNetClientIOHandleInput(client) < 0) { + virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR); goto error; + } } /* Iterate through waiting calls and if any are @@ -1410,6 +1439,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client, } if (fds[0].revents & (POLLHUP | POLLERR)) { + virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_EOF); virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("received hangup / error event on socket")); goto error; @@ -1441,6 +1471,9 @@ static void virNetClientIOUpdateCallback(virNetClientPtr client, { int events = 0; + if (client->wantClose) + return; + if (enableCallback) { events |= VIR_EVENT_HANDLE_READABLE; virNetClientCallMatchPredicate(client->waitDispatch, @@ -1623,6 +1656,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock, virNetClientLock(client); + VIR_DEBUG("client=%p wantclose=%d", client, client ? client->wantClose : false); + if (!client->sock) goto done; @@ -1635,18 +1670,21 @@ void virNetClientIncomingEvent(virNetSocketPtr sock, if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) { VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or " "VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__); - virNetSocketRemoveIOCallback(sock); + virNetClientMarkClose(client, + (events & VIR_EVENT_HANDLE_HANGUP) ? + VIR_CONNECT_CLOSE_REASON_EOF : + VIR_CONNECT_CLOSE_REASON_ERROR); goto done; } if (events & VIR_EVENT_HANDLE_WRITABLE) { if (virNetClientIOHandleOutput(client) < 0) - virNetSocketRemoveIOCallback(sock); + virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR); } if (events & VIR_EVENT_HANDLE_READABLE) { if (virNetClientIOHandleInput(client) < 0) - virNetSocketRemoveIOCallback(sock); + virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR); } /* Remove completed calls or signal their threads. */ @@ -1656,6 +1694,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock, virNetClientIOUpdateCallback(client, true); done: + if (client->wantClose) + virNetClientCloseLocked(client); virNetClientUnlock(client); }