Fix locking wrt virNetClientStreamPtr object

The client stream object can be used independently of the
virNetClientPtr object, so must have full locking of its
own and not rely on any caller.

* src/remote/remote_driver.c: Remove locking around stream
  callback
* src/rpc/virnetclientstream.c: Add locking to all APIs
  and callbacks
This commit is contained in:
Daniel P. Berrange 2011-06-28 17:51:49 +01:00
parent 7a779ef6a2
commit 8a4e28743e
2 changed files with 89 additions and 26 deletions

View File

@ -3254,11 +3254,8 @@ static void remoteStreamEventCallback(virNetClientStreamPtr stream ATTRIBUTE_UNU
void *opaque)
{
struct remoteStreamCallbackData *cbdata = opaque;
struct private_data *priv = cbdata->st->conn->privateData;
remoteDriverUnlock(priv);
(cbdata->cb)(cbdata->st, events, cbdata->opaque);
remoteDriverLock(priv);
}

View File

@ -28,6 +28,7 @@
#include "virterror_internal.h"
#include "logging.h"
#include "event.h"
#include "threads.h"
#define VIR_FROM_THIS VIR_FROM_RPC
#define virNetError(code, ...) \
@ -35,6 +36,8 @@
__FUNCTION__, __LINE__, __VA_ARGS__)
struct _virNetClientStream {
virMutex lock;
virNetClientProgramPtr prog;
int proc;
unsigned serial;
@ -53,7 +56,6 @@ struct _virNetClientStream {
size_t incomingOffset;
size_t incomingLength;
virNetClientStreamEventCallback cb;
void *cbOpaque;
virFreeCallback cbFree;
@ -89,7 +91,8 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
virNetClientStreamPtr st = opaque;
int events = 0;
/* XXX we need a mutex on 'st' to protect this callback */
virMutexLock(&st->lock);
if (st->cb &&
(st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
@ -106,12 +109,15 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
virFreeCallback cbFree = st->cbFree;
st->cbDispatch = 1;
virMutexUnlock(&st->lock);
(cb)(st, events, cbOpaque);
virMutexLock(&st->lock);
st->cbDispatch = 0;
if (!st->cb && cbFree)
(cbFree)(cbOpaque);
}
virMutexUnlock(&st->lock);
}
@ -134,30 +140,45 @@ virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
return NULL;
}
virNetClientProgramRef(prog);
st->refs = 1;
st->prog = prog;
st->proc = proc;
st->serial = serial;
if (virMutexInit(&st->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("cannot initialize mutex"));
VIR_FREE(st);
return NULL;
}
virNetClientProgramRef(prog);
return st;
}
void virNetClientStreamRef(virNetClientStreamPtr st)
{
virMutexLock(&st->lock);
st->refs++;
virMutexUnlock(&st->lock);
}
void virNetClientStreamFree(virNetClientStreamPtr st)
{
virMutexLock(&st->lock);
st->refs--;
if (st->refs > 0)
if (st->refs > 0) {
virMutexUnlock(&st->lock);
return;
}
virMutexUnlock(&st->lock);
virResetError(&st->err);
VIR_FREE(st->incoming);
virMutexDestroy(&st->lock);
virNetClientProgramFree(st->prog);
VIR_FREE(st);
}
@ -165,18 +186,24 @@ void virNetClientStreamFree(virNetClientStreamPtr st)
bool virNetClientStreamMatches(virNetClientStreamPtr st,
virNetMessagePtr msg)
{
bool match = false;
virMutexLock(&st->lock);
if (virNetClientProgramMatches(st->prog, msg) &&
st->proc == msg->header.proc &&
st->serial == msg->header.serial)
return 1;
return 0;
match = true;
virMutexUnlock(&st->lock);
return match;
}
bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
{
if (st->err.code == VIR_ERR_OK)
virMutexLock(&st->lock);
if (st->err.code == VIR_ERR_OK) {
virMutexUnlock(&st->lock);
return false;
}
virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
st->err.domain,
@ -188,7 +215,7 @@ bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
st->err.int1,
st->err.int2,
"%s", st->err.message ? st->err.message : _("Unknown error"));
virMutexUnlock(&st->lock);
return true;
}
@ -199,6 +226,8 @@ int virNetClientStreamSetError(virNetClientStreamPtr st,
virNetMessageError err;
int ret = -1;
virMutexLock(&st->lock);
if (st->err.code != VIR_ERR_OK)
VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));
@ -242,6 +271,7 @@ int virNetClientStreamSetError(virNetClientStreamPtr st,
cleanup:
xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
virMutexUnlock(&st->lock);
return ret;
}
@ -249,15 +279,18 @@ cleanup:
int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
virNetMessagePtr msg)
{
size_t avail = st->incomingLength - st->incomingOffset;
size_t need = msg->bufferLength - msg->bufferOffset;
int ret = -1;
size_t need;
virMutexLock(&st->lock);
need = msg->bufferLength - msg->bufferOffset;
size_t avail = st->incomingLength - st->incomingOffset;
if (need > avail) {
size_t extra = need - avail;
if (VIR_REALLOC_N(st->incoming,
st->incomingLength + extra) < 0) {
VIR_DEBUG("Out of memory handling stream data");
return -1;
goto cleanup;
}
st->incomingLength += extra;
}
@ -269,7 +302,12 @@ int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
VIR_DEBUG("Stream incoming data offset %zu length %zu",
st->incomingOffset, st->incomingLength);
return 0;
ret = 0;
cleanup:
virMutexUnlock(&st->lock);
return ret;
}
@ -286,6 +324,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st,
if (!(msg = virNetMessageNew()))
return -1;
virMutexLock(&st->lock);
msg->header.prog = virNetClientProgramGetProgram(st->prog);
msg->header.vers = virNetClientProgramGetVersion(st->prog);
msg->header.status = status;
@ -293,6 +333,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st,
msg->header.serial = st->serial;
msg->header.proc = st->proc;
virMutexUnlock(&st->lock);
if (virNetMessageEncodeHeader(msg) < 0)
goto error;
@ -329,6 +371,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
int rv = -1;
VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d",
st, client, data, nbytes, nonblock);
virMutexLock(&st->lock);
if (!st->incomingOffset) {
virNetMessagePtr msg;
int ret;
@ -351,8 +394,9 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
msg->header.proc = st->proc;
VIR_DEBUG("Dummy packet to wait for stream data");
virMutexUnlock(&st->lock);
ret = virNetClientSend(client, msg, true);
virMutexLock(&st->lock);
virNetMessageFree(msg);
if (ret < 0)
@ -380,6 +424,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
virNetClientStreamEventTimerUpdate(st);
cleanup:
virMutexUnlock(&st->lock);
return rv;
}
@ -390,20 +435,23 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
void *opaque,
virFreeCallback ff)
{
int ret = -1;
virMutexLock(&st->lock);
if (st->cb) {
virNetError(VIR_ERR_INTERNAL_ERROR,
"%s", _("multiple stream callbacks not supported"));
return 1;
goto cleanup;
}
virNetClientStreamRef(st);
st->refs++;
if ((st->cbTimer =
virEventAddTimeout(-1,
virNetClientStreamEventTimer,
st,
virNetClientStreamEventTimerFree)) < 0) {
virNetClientStreamFree(st);
return -1;
st->refs--;
goto cleanup;
}
st->cb = cb;
@ -413,31 +461,45 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
virNetClientStreamEventTimerUpdate(st);
return 0;
ret = 0;
cleanup:
virMutexUnlock(&st->lock);
return ret;
}
int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
int events)
{
int ret = -1;
virMutexLock(&st->lock);
if (!st->cb) {
virNetError(VIR_ERR_INTERNAL_ERROR,
"%s", _("no stream callback registered"));
return -1;
goto cleanup;
}
st->cbEvents = events;
virNetClientStreamEventTimerUpdate(st);
return 0;
ret = 0;
cleanup:
virMutexUnlock(&st->lock);
return ret;
}
int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
{
int ret = -1;
virMutexUnlock(&st->lock);
if (!st->cb) {
virNetError(VIR_ERR_INTERNAL_ERROR,
"%s", _("no stream callback registered"));
return -1;
goto cleanup;
}
if (!st->cbDispatch &&
@ -449,5 +511,9 @@ int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
st->cbEvents = 0;
virEventRemoveTimeout(st->cbTimer);
return 0;
ret = 0;
cleanup:
virMutexUnlock(&st->lock);
return ret;
}