Add mutex locking and reference counting to virNetSocket

Remove the need for a virNetSocket object to be protected by
locks from the object using it, by introducing its own native
locking and reference counting

* src/rpc/virnetsocket.c: Add locking & reference counting
This commit is contained in:
Daniel P. Berrange 2011-07-19 14:00:24 +01:00
parent 06c0d1841c
commit 6198f3a1d7

View File

@ -40,6 +40,7 @@
#include "logging.h" #include "logging.h"
#include "files.h" #include "files.h"
#include "event.h" #include "event.h"
#include "threads.h"
#define VIR_FROM_THIS VIR_FROM_RPC #define VIR_FROM_THIS VIR_FROM_RPC
@ -49,6 +50,9 @@
struct _virNetSocket { struct _virNetSocket {
virMutex lock;
int refs;
int fd; int fd;
int watch; int watch;
pid_t pid; pid_t pid;
@ -122,6 +126,14 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
return NULL; return NULL;
} }
if (virMutexInit(&sock->lock) < 0) {
virReportSystemError(errno, "%s",
_("Unable to initialize mutex"));
VIR_FREE(sock);
return NULL;
}
sock->refs = 1;
if (localAddr) if (localAddr)
sock->localAddr = *localAddr; sock->localAddr = *localAddr;
if (remoteAddr) if (remoteAddr)
@ -627,6 +639,13 @@ void virNetSocketFree(virNetSocketPtr sock)
if (!sock) if (!sock)
return; return;
virMutexLock(&sock->lock);
sock->refs--;
if (sock->refs > 0) {
virMutexUnlock(&sock->lock);
return;
}
VIR_DEBUG("sock=%p fd=%d", sock, sock->fd); VIR_DEBUG("sock=%p fd=%d", sock, sock->fd);
if (sock->watch > 0) { if (sock->watch > 0) {
virEventRemoveHandle(sock->watch); virEventRemoveHandle(sock->watch);
@ -657,27 +676,41 @@ void virNetSocketFree(virNetSocketPtr sock)
VIR_FREE(sock->localAddrStr); VIR_FREE(sock->localAddrStr);
VIR_FREE(sock->remoteAddrStr); VIR_FREE(sock->remoteAddrStr);
virMutexUnlock(&sock->lock);
virMutexDestroy(&sock->lock);
VIR_FREE(sock); VIR_FREE(sock);
} }
int virNetSocketGetFD(virNetSocketPtr sock) int virNetSocketGetFD(virNetSocketPtr sock)
{ {
return sock->fd; int fd;
virMutexLock(&sock->lock);
fd = sock->fd;
virMutexUnlock(&sock->lock);
return fd;
} }
bool virNetSocketIsLocal(virNetSocketPtr sock) bool virNetSocketIsLocal(virNetSocketPtr sock)
{ {
bool isLocal = false;
virMutexLock(&sock->lock);
if (sock->localAddr.data.sa.sa_family == AF_UNIX) if (sock->localAddr.data.sa.sa_family == AF_UNIX)
return true; isLocal = true;
return false; virMutexUnlock(&sock->lock);
return isLocal;
} }
int virNetSocketGetPort(virNetSocketPtr sock) int virNetSocketGetPort(virNetSocketPtr sock)
{ {
return virSocketGetPort(&sock->localAddr); int port;
virMutexLock(&sock->lock);
port = virSocketGetPort(&sock->localAddr);
virMutexUnlock(&sock->lock);
return port;
} }
@ -688,15 +721,19 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock,
{ {
struct ucred cr; struct ucred cr;
unsigned int cr_len = sizeof (cr); unsigned int cr_len = sizeof (cr);
virMutexLock(&sock->lock);
if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) { if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) {
virReportSystemError(errno, "%s", virReportSystemError(errno, "%s",
_("Failed to get client socket identity")); _("Failed to get client socket identity"));
virMutexUnlock(&sock->lock);
return -1; return -1;
} }
*pid = cr.pid; *pid = cr.pid;
*uid = cr.uid; *uid = cr.uid;
virMutexUnlock(&sock->lock);
return 0; return 0;
} }
#else #else
@ -715,7 +752,11 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock ATTRIBUTE_UNUSED,
int virNetSocketSetBlocking(virNetSocketPtr sock, int virNetSocketSetBlocking(virNetSocketPtr sock,
bool blocking) bool blocking)
{ {
return virSetBlocking(sock->fd, blocking); int ret;
virMutexLock(&sock->lock);
ret = virSetBlocking(sock->fd, blocking);
virMutexUnlock(&sock->lock);
return ret;
} }
@ -751,6 +792,7 @@ static ssize_t virNetSocketTLSSessionRead(char *buf,
void virNetSocketSetTLSSession(virNetSocketPtr sock, void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess) virNetTLSSessionPtr sess)
{ {
virMutexLock(&sock->lock);
virNetTLSSessionFree(sock->tlsSession); virNetTLSSessionFree(sock->tlsSession);
sock->tlsSession = sess; sock->tlsSession = sess;
virNetTLSSessionSetIOCallbacks(sess, virNetTLSSessionSetIOCallbacks(sess,
@ -758,6 +800,7 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetSocketTLSSessionRead, virNetSocketTLSSessionRead,
sock); sock);
virNetTLSSessionRef(sess); virNetTLSSessionRef(sess);
virMutexUnlock(&sock->lock);
} }
@ -765,20 +808,25 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
void virNetSocketSetSASLSession(virNetSocketPtr sock, void virNetSocketSetSASLSession(virNetSocketPtr sock,
virNetSASLSessionPtr sess) virNetSASLSessionPtr sess)
{ {
virMutexLock(&sock->lock);
virNetSASLSessionFree(sock->saslSession); virNetSASLSessionFree(sock->saslSession);
sock->saslSession = sess; sock->saslSession = sess;
virNetSASLSessionRef(sess); virNetSASLSessionRef(sess);
virMutexUnlock(&sock->lock);
} }
#endif #endif
bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED) bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
{ {
bool hasCached = false;
virMutexLock(&sock->lock);
#if HAVE_SASL #if HAVE_SASL
if (sock->saslDecoded) if (sock->saslDecoded)
return true; hasCached = true;
#endif #endif
return false; virMutexUnlock(&sock->lock);
return hasCached;
} }
@ -965,39 +1013,54 @@ static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size
ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
{ {
ssize_t ret;
virMutexLock(&sock->lock);
#if HAVE_SASL #if HAVE_SASL
if (sock->saslSession) if (sock->saslSession)
return virNetSocketReadSASL(sock, buf, len); ret = virNetSocketReadSASL(sock, buf, len);
else else
#endif #endif
return virNetSocketReadWire(sock, buf, len); ret = virNetSocketReadWire(sock, buf, len);
virMutexUnlock(&sock->lock);
return ret;
} }
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
{ {
ssize_t ret;
virMutexLock(&sock->lock);
#if HAVE_SASL #if HAVE_SASL
if (sock->saslSession) if (sock->saslSession)
return virNetSocketWriteSASL(sock, buf, len); ret = virNetSocketWriteSASL(sock, buf, len);
else else
#endif #endif
return virNetSocketWriteWire(sock, buf, len); ret = virNetSocketWriteWire(sock, buf, len);
virMutexUnlock(&sock->lock);
return ret;
} }
int virNetSocketListen(virNetSocketPtr sock) int virNetSocketListen(virNetSocketPtr sock)
{ {
virMutexLock(&sock->lock);
if (listen(sock->fd, 30) < 0) { if (listen(sock->fd, 30) < 0) {
virReportSystemError(errno, "%s", _("Unable to listen on socket")); virReportSystemError(errno, "%s", _("Unable to listen on socket"));
virMutexUnlock(&sock->lock);
return -1; return -1;
} }
virMutexUnlock(&sock->lock);
return 0; return 0;
} }
int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock)
{ {
int fd; int fd = -1;
virSocketAddr localAddr; virSocketAddr localAddr;
virSocketAddr remoteAddr; virSocketAddr remoteAddr;
int ret = -1;
virMutexLock(&sock->lock);
*clientsock = NULL; *clientsock = NULL;
@ -1007,30 +1070,35 @@ int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock)
remoteAddr.len = sizeof(remoteAddr.data.stor); remoteAddr.len = sizeof(remoteAddr.data.stor);
if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) { if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) {
if (errno == ECONNABORTED || if (errno == ECONNABORTED ||
errno == EAGAIN) errno == EAGAIN) {
return 0; ret = 0;
goto cleanup;
}
virReportSystemError(errno, "%s", virReportSystemError(errno, "%s",
_("Unable to accept client")); _("Unable to accept client"));
return -1; goto cleanup;
} }
localAddr.len = sizeof(localAddr.data); localAddr.len = sizeof(localAddr.data);
if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
virReportSystemError(errno, "%s", _("Unable to get local socket name")); virReportSystemError(errno, "%s", _("Unable to get local socket name"));
VIR_FORCE_CLOSE(fd); goto cleanup;
return -1;
} }
if (!(*clientsock = virNetSocketNew(&localAddr, if (!(*clientsock = virNetSocketNew(&localAddr,
&remoteAddr, &remoteAddr,
true, true,
fd, -1, 0))) { fd, -1, 0)))
VIR_FORCE_CLOSE(fd); goto cleanup;
return -1;
}
return 0; fd = -1;
ret = 0;
cleanup:
VIR_FORCE_CLOSE(fd);
virMutexUnlock(&sock->lock);
return ret;
} }
@ -1040,18 +1108,30 @@ static void virNetSocketEventHandle(int watch ATTRIBUTE_UNUSED,
void *opaque) void *opaque)
{ {
virNetSocketPtr sock = opaque; virNetSocketPtr sock = opaque;
virNetSocketIOFunc func;
void *eopaque;
sock->func(sock, events, sock->opaque); virMutexLock(&sock->lock);
func = sock->func;
eopaque = sock->opaque;
virMutexUnlock(&sock->lock);
if (func)
func(sock, events, eopaque);
} }
int virNetSocketAddIOCallback(virNetSocketPtr sock, int virNetSocketAddIOCallback(virNetSocketPtr sock,
int events, int events,
virNetSocketIOFunc func, virNetSocketIOFunc func,
void *opaque) void *opaque)
{ {
int ret = -1;
virMutexLock(&sock->lock);
if (sock->watch > 0) { if (sock->watch > 0) {
VIR_DEBUG("Watch already registered on socket %p", sock); VIR_DEBUG("Watch already registered on socket %p", sock);
return -1; goto cleanup;
} }
if ((sock->watch = virEventAddHandle(sock->fd, if ((sock->watch = virEventAddHandle(sock->fd,
@ -1060,32 +1140,44 @@ int virNetSocketAddIOCallback(virNetSocketPtr sock,
sock, sock,
NULL)) < 0) { NULL)) < 0) {
VIR_DEBUG("Failed to register watch on socket %p", sock); VIR_DEBUG("Failed to register watch on socket %p", sock);
return -1; goto cleanup;
} }
sock->func = func; sock->func = func;
sock->opaque = opaque; sock->opaque = opaque;
return 0; ret = 0;
cleanup:
virMutexUnlock(&sock->lock);
return ret;
} }
void virNetSocketUpdateIOCallback(virNetSocketPtr sock, void virNetSocketUpdateIOCallback(virNetSocketPtr sock,
int events) int events)
{ {
virMutexLock(&sock->lock);
if (sock->watch <= 0) { if (sock->watch <= 0) {
VIR_DEBUG("Watch not registered on socket %p", sock); VIR_DEBUG("Watch not registered on socket %p", sock);
virMutexUnlock(&sock->lock);
return; return;
} }
virEventUpdateHandle(sock->watch, events); virEventUpdateHandle(sock->watch, events);
virMutexUnlock(&sock->lock);
} }
void virNetSocketRemoveIOCallback(virNetSocketPtr sock) void virNetSocketRemoveIOCallback(virNetSocketPtr sock)
{ {
virMutexLock(&sock->lock);
if (sock->watch <= 0) { if (sock->watch <= 0) {
VIR_DEBUG("Watch not registered on socket %p", sock); VIR_DEBUG("Watch not registered on socket %p", sock);
virMutexUnlock(&sock->lock);
return; return;
} }
virEventRemoveHandle(sock->watch); virEventRemoveHandle(sock->watch);
sock->watch = 0;
virMutexUnlock(&sock->lock);
} }