From 4e006b844f40e70a1ffdafe0115cd4277ad9916d Mon Sep 17 00:00:00 2001 From: Eric Blake Date: Mon, 1 Aug 2011 13:41:38 -0600 Subject: [PATCH] rpc: avoid libvirtd crash on unexpected client close Steps to reproduce this problem (vm1 is not running): for i in `seq 50`; do virsh managedsave vm1& done; killall virsh Pre-patch, virNetServerClientClose could end up setting client->sock to NULL prior to other cleanup functions trying to use client->sock. This fixes things by checking for NULL in more places, and by deferring the cleanup until after all queued messages have been served. * src/rpc/virnetserverclient.c (virNetServerClientRegisterEvent) (virNetServerClientGetFD, virNetServerClientIsSecure) (virNetServerClientLocalAddrString) (virNetServerClientRemoteAddrString): Check for closed socket. (virNetServerClientClose): Rearrange close sequence. Analysis from Wen Congyang. --- src/rpc/virnetserverclient.c | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index 3c0dba8d65..2f6c040762 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -177,7 +177,8 @@ static int virNetServerClientRegisterEvent(virNetServerClientPtr client) client->refs++; VIR_DEBUG("Registering client event callback %d", mode); - if (virNetSocketAddIOCallback(client->sock, + if (!client->sock || + virNetSocketAddIOCallback(client->sock, mode, virNetServerClientDispatchEvent, client, @@ -386,9 +387,10 @@ int virNetServerClientGetTLSKeySize(virNetServerClientPtr client) int virNetServerClientGetFD(virNetServerClientPtr client) { - int fd = 0; + int fd = -1; virNetServerClientLock(client); - fd = virNetSocketGetFD(client->sock); + if (client->sock) + fd = virNetSocketGetFD(client->sock); virNetServerClientUnlock(client); return fd; } @@ -396,9 +398,10 @@ int virNetServerClientGetFD(virNetServerClientPtr client) int virNetServerClientGetLocalIdentity(virNetServerClientPtr client, uid_t *uid, pid_t *pid) { - int ret; + int ret = -1; virNetServerClientLock(client); - ret = virNetSocketGetLocalIdentity(client->sock, uid, pid); + if (client->sock) + ret = virNetSocketGetLocalIdentity(client->sock, uid, pid); virNetServerClientUnlock(client); return ret; } @@ -413,7 +416,7 @@ bool virNetServerClientIsSecure(virNetServerClientPtr client) if (client->sasl) secure = true; #endif - if (virNetSocketIsLocal(client->sock)) + if (client->sock && virNetSocketIsLocal(client->sock)) secure = true; virNetServerClientUnlock(client); return secure; @@ -502,12 +505,16 @@ void virNetServerClientSetDispatcher(virNetServerClientPtr client, const char *virNetServerClientLocalAddrString(virNetServerClientPtr client) { + if (!client->sock) + return NULL; return virNetSocketLocalAddrString(client->sock); } const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client) { + if (!client->sock) + return NULL; return virNetSocketRemoteAddrString(client->sock); } @@ -570,10 +577,7 @@ void virNetServerClientClose(virNetServerClientPtr client) virNetTLSSessionFree(client->tls); client->tls = NULL; } - if (client->sock) { - virNetSocketFree(client->sock); - client->sock = NULL; - } + client->wantClose = true; while (client->rx) { virNetMessagePtr msg @@ -586,6 +590,11 @@ void virNetServerClientClose(virNetServerClientPtr client) virNetMessageFree(msg); } + if (client->sock) { + virNetSocketFree(client->sock); + client->sock = NULL; + } + virNetServerClientUnlock(client); }