rpc: avoid libvirtd crash on unexpected client close

Steps to reproduce this problem (vm1 is not running):
for i in `seq 50`; do virsh managedsave vm1& done; killall virsh

Pre-patch, virNetServerClientClose could end up setting client->sock
to NULL prior to other cleanup functions trying to use client->sock.
This fixes things by checking for NULL in more places, and by deferring
the cleanup until after all queued messages have been served.

* src/rpc/virnetserverclient.c (virNetServerClientRegisterEvent)
(virNetServerClientGetFD, virNetServerClientIsSecure)
(virNetServerClientLocalAddrString)
(virNetServerClientRemoteAddrString): Check for closed socket.
(virNetServerClientClose): Rearrange close sequence.
Analysis from Wen Congyang.
This commit is contained in:
Eric Blake 2011-08-01 13:41:38 -06:00
parent 22da8c941c
commit 4e006b844f

View File

@ -177,7 +177,8 @@ static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
client->refs++; client->refs++;
VIR_DEBUG("Registering client event callback %d", mode); VIR_DEBUG("Registering client event callback %d", mode);
if (virNetSocketAddIOCallback(client->sock, if (!client->sock ||
virNetSocketAddIOCallback(client->sock,
mode, mode,
virNetServerClientDispatchEvent, virNetServerClientDispatchEvent,
client, client,
@ -386,8 +387,9 @@ int virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
int virNetServerClientGetFD(virNetServerClientPtr client) int virNetServerClientGetFD(virNetServerClientPtr client)
{ {
int fd = 0; int fd = -1;
virNetServerClientLock(client); virNetServerClientLock(client);
if (client->sock)
fd = virNetSocketGetFD(client->sock); fd = virNetSocketGetFD(client->sock);
virNetServerClientUnlock(client); virNetServerClientUnlock(client);
return fd; return fd;
@ -396,8 +398,9 @@ int virNetServerClientGetFD(virNetServerClientPtr client)
int virNetServerClientGetLocalIdentity(virNetServerClientPtr client, int virNetServerClientGetLocalIdentity(virNetServerClientPtr client,
uid_t *uid, pid_t *pid) uid_t *uid, pid_t *pid)
{ {
int ret; int ret = -1;
virNetServerClientLock(client); virNetServerClientLock(client);
if (client->sock)
ret = virNetSocketGetLocalIdentity(client->sock, uid, pid); ret = virNetSocketGetLocalIdentity(client->sock, uid, pid);
virNetServerClientUnlock(client); virNetServerClientUnlock(client);
return ret; return ret;
@ -413,7 +416,7 @@ bool virNetServerClientIsSecure(virNetServerClientPtr client)
if (client->sasl) if (client->sasl)
secure = true; secure = true;
#endif #endif
if (virNetSocketIsLocal(client->sock)) if (client->sock && virNetSocketIsLocal(client->sock))
secure = true; secure = true;
virNetServerClientUnlock(client); virNetServerClientUnlock(client);
return secure; return secure;
@ -502,12 +505,16 @@ void virNetServerClientSetDispatcher(virNetServerClientPtr client,
const char *virNetServerClientLocalAddrString(virNetServerClientPtr client) const char *virNetServerClientLocalAddrString(virNetServerClientPtr client)
{ {
if (!client->sock)
return NULL;
return virNetSocketLocalAddrString(client->sock); return virNetSocketLocalAddrString(client->sock);
} }
const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client) const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client)
{ {
if (!client->sock)
return NULL;
return virNetSocketRemoteAddrString(client->sock); return virNetSocketRemoteAddrString(client->sock);
} }
@ -570,10 +577,7 @@ void virNetServerClientClose(virNetServerClientPtr client)
virNetTLSSessionFree(client->tls); virNetTLSSessionFree(client->tls);
client->tls = NULL; client->tls = NULL;
} }
if (client->sock) { client->wantClose = true;
virNetSocketFree(client->sock);
client->sock = NULL;
}
while (client->rx) { while (client->rx) {
virNetMessagePtr msg virNetMessagePtr msg
@ -586,6 +590,11 @@ void virNetServerClientClose(virNetServerClientPtr client)
virNetMessageFree(msg); virNetMessageFree(msg);
} }
if (client->sock) {
virNetSocketFree(client->sock);
client->sock = NULL;
}
virNetServerClientUnlock(client); virNetServerClientUnlock(client);
} }