/*
* virnetclientstream.c: generic network RPC client stream
*
* Copyright (C) 2006-2011 Red Hat, Inc.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library. If not, see
* <http://www.gnu.org/licenses/>.
*
* Author: Daniel P. Berrange
*/
#include
#include "virnetclientstream.h"
#include "virnetclient.h"
#include "viralloc.h"
#include "virerror.h"
#include "virlog.h"
#include "virthread.h"
/* Error domain for this file.
 * NOTE(review): presumably picked up by the virReportError() macro used
 * below - confirm against virerror.h. */
#define VIR_FROM_THIS VIR_FROM_RPC
/* NOTE(review): presumably registers the log source used by VIR_DEBUG()
 * in this file under the name "rpc.netclientstream" - confirm against
 * virlog.h. */
VIR_LOG_INIT("rpc.netclientstream");
/*
 * State of one client-side RPC stream.
 *
 * The embedded lockable parent is what virObjectLock()/virObjectUnlock()
 * operate on; the functions in this file take that lock before touching
 * the fields below (virNetClientStreamEOF() being the exception).
 */
struct _virNetClientStream {
virObjectLockable parent;
/* Program the stream belongs to; referenced in virNetClientStreamNew()
 * and unreferenced in virNetClientStreamDispose(). */
virNetClientProgramPtr prog;
/* Procedure number, compared with msg->header.proc when matching. */
int proc;
/* Serial number, compared with msg->header.serial when matching. */
unsigned serial;
/* Error recorded by virNetClientStreamSetError(); code VIR_ERR_OK
 * means no error is pending. */
virError err;
/* XXX this buffer is unbounded if the client
 * app has domain events registered, since packets
 * may be read off wire, while app isn't ready to
 * recv them. Figure out how to address this some
 * time by stopping consuming any incoming data
 * off the socket....
 */
struct iovec *incomingVec; /* Array of queued incoming data pieces */
size_t writeVec; /* Number of entries in incomingVec (pieces produced) */
size_t readVec; /* Index of the next entry to consume (pieces consumed) */
/* Set when an empty packet is queued or an error is recorded. */
bool incomingEOF;
/* Registered event callback (NULL when none) and its companions. */
virNetClientStreamEventCallback cb;
void *cbOpaque; /* Opaque pointer handed to cb */
virFreeCallback cbFree; /* Optional destructor for cbOpaque */
int cbEvents; /* VIR_STREAM_EVENT_* bits the callback asked for */
int cbTimer; /* Value returned by virEventAddTimeout() */
int cbDispatch; /* Non-zero while the timer handler is invoking cb */
};
/* Class handle for stream objects; assigned in virNetClientStreamOnceInit(). */
static virClassPtr virNetClientStreamClass;
/* Forward declaration: the dispose hook is handed to virClassNew() before
 * its definition appears further down. */
static void virNetClientStreamDispose(void *obj);
/*
 * virNetClientStreamOnceInit:
 *
 * Creates the virNetClientStream class as a child of the lockable object
 * class, with virNetClientStreamDispose() as its dispose hook.
 *
 * Returns 0 on success, -1 if the class could not be created.
 */
static int virNetClientStreamOnceInit(void)
{
    virNetClientStreamClass = virClassNew(virClassForObjectLockable(),
                                          "virNetClientStream",
                                          sizeof(virNetClientStream),
                                          virNetClientStreamDispose);

    return virNetClientStreamClass ? 0 : -1;
}
/* NOTE(review): presumably generates the virNetClientStreamInitialize()
 * wrapper called from virNetClientStreamNew(), which runs
 * virNetClientStreamOnceInit() a single time - confirm against virthread.h. */
VIR_ONCE_GLOBAL_INIT(virNetClientStream)
/*
 * virNetClientStreamEventTimerUpdate:
 * @st: stream whose event timer should be re-evaluated
 *
 * Does nothing when no callback is registered.  Otherwise the timer is
 * set to fire immediately (timeout 0) if the callback wants writable
 * events, or wants readable events while queued data or EOF is pending;
 * in every other case the timer is disabled (timeout -1).
 *
 * No locking is done here; every call site in this file invokes it with
 * the stream lock held.
 */
static void
virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
{
    bool readPending;
    bool wantRead;
    bool wantWrite;

    if (st->cb == NULL)
        return;

    VIR_DEBUG("Check timer readVec %zu writeVec %zu %d", st->readVec, st->writeVec, st->cbEvents);

    readPending = (st->readVec < st->writeVec) || st->incomingEOF;
    wantRead = (st->cbEvents & VIR_STREAM_EVENT_READABLE) != 0;
    wantWrite = (st->cbEvents & VIR_STREAM_EVENT_WRITABLE) != 0;

    if (!wantWrite && !(readPending && wantRead)) {
        VIR_DEBUG("Disabling event timer");
        virEventUpdateTimeout(st->cbTimer, -1);
        return;
    }

    VIR_DEBUG("Enabling event timer");
    virEventUpdateTimeout(st->cbTimer, 0);
}
/*
 * virNetClientStreamEventTimer:
 * @timer: timer identifier (unused)
 * @opaque: the stream the timer was registered for
 *
 * Timer handler.  Works out which of the requested events are currently
 * deliverable and, if any, invokes the registered callback with the
 * stream lock released.  cbDispatch is raised for the duration of the
 * call; if the callback was removed while it ran, the opaque data is
 * freed here once the callback has returned.
 */
static void
virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
{
    virNetClientStreamPtr st = opaque;
    virNetClientStreamEventCallback cb;
    virFreeCallback cbFree;
    void *cbOpaque;
    int events = 0;

    virObjectLock(st);

    if (st->cb) {
        bool readPending = (st->readVec < st->writeVec) || st->incomingEOF;

        if ((st->cbEvents & VIR_STREAM_EVENT_READABLE) && readPending)
            events |= VIR_STREAM_EVENT_READABLE;
        if (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)
            events |= VIR_STREAM_EVENT_WRITABLE;
    }

    VIR_DEBUG("Got Timer dispatch %d %d readVec %zu writeVec %zu", events, st->cbEvents,
              st->readVec, st->writeVec);

    if (!events) {
        virObjectUnlock(st);
        return;
    }

    /* Snapshot the callback data: it may be cleared while the lock is dropped. */
    cb = st->cb;
    cbOpaque = st->cbOpaque;
    cbFree = st->cbFree;

    st->cbDispatch = 1;
    virObjectUnlock(st);
    (cb)(st, events, cbOpaque);
    virObjectLock(st);
    st->cbDispatch = 0;

    /* Removal during dispatch skips the free; perform it now. */
    if (!st->cb && cbFree)
        (cbFree)(cbOpaque);

    virObjectUnlock(st);
}
/*
 * virNetClientStreamNew:
 * @prog: RPC program the stream is associated with
 * @proc: procedure number of the stream
 * @serial: serial number of the stream
 *
 * Allocates a new stream object and takes a reference on @prog.
 *
 * Returns the new stream, or NULL if class initialization or object
 * allocation fails.
 */
virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
                                            int proc,
                                            unsigned serial)
{
    virNetClientStreamPtr stream;

    if (virNetClientStreamInitialize() < 0)
        return NULL;

    stream = virObjectLockableNew(virNetClientStreamClass);
    if (stream == NULL)
        return NULL;

    virObjectRef(prog);
    stream->prog = prog;
    stream->proc = proc;
    stream->serial = serial;

    return stream;
}
/*
 * virNetClientStreamDispose:
 * @obj: the stream being destroyed
 *
 * Dispose hook registered with the stream class.  Releases the stored
 * error, every data buffer that is still queued, the I/O vector array
 * itself and the reference held on the program.
 */
void virNetClientStreamDispose(void *obj)
{
    virNetClientStreamPtr st = obj;
    size_t i;

    virResetError(&st->err);

    /* Entries in [readVec, writeVec) were queued but never fully consumed,
     * so the stream still owns their buffers; freeing only the array
     * would leak them. */
    for (i = st->readVec; i < st->writeVec; i++)
        VIR_FREE(st->incomingVec[i].iov_base);
    VIR_FREE(st->incomingVec);

    virObjectUnref(st->prog);
}
/*
 * virNetClientStreamMatches:
 * @st: stream to test
 * @msg: incoming message
 *
 * Returns true when @msg belongs to the stream's program and carries the
 * same procedure and serial numbers as @st, false otherwise.
 */
bool virNetClientStreamMatches(virNetClientStreamPtr st,
                               virNetMessagePtr msg)
{
    bool matched;

    virObjectLock(st);
    matched = virNetClientProgramMatches(st->prog, msg) &&
              st->proc == msg->header.proc &&
              st->serial == msg->header.serial;
    virObjectUnlock(st);

    return matched;
}
/*
 * virNetClientStreamRaiseError:
 * @st: stream to inspect
 *
 * If an error is stored on the stream, re-raises it through
 * virRaiseErrorFull(), substituting "Unknown error" when no message
 * text was recorded.
 *
 * Returns true if an error was raised, false if none was stored.
 */
bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
{
    bool raised = false;

    virObjectLock(st);

    if (st->err.code != VIR_ERR_OK) {
        virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
                          st->err.domain,
                          st->err.code,
                          st->err.level,
                          st->err.str1,
                          st->err.str2,
                          st->err.str3,
                          st->err.int1,
                          st->err.int2,
                          "%s", st->err.message ? st->err.message : _("Unknown error"));
        raised = true;
    }

    virObjectUnlock(st);
    return raised;
}
/*
 * virNetClientStreamSetError:
 * @st: stream to record the error on
 * @msg: message whose payload is decoded as a virNetMessageError
 *
 * Replaces any error stored on @st with the one decoded from @msg,
 * marks the stream as having reached EOF and refreshes the event timer.
 *
 * Returns 0 on success, or -1 if the payload cannot be decoded.  In the
 * failure case the previously stored error has already been reset and
 * the EOF flag is left untouched.
 */
int virNetClientStreamSetError(virNetClientStreamPtr st,
virNetMessagePtr msg)
{
virNetMessageError err;
int ret = -1;
virObjectLock(st);
if (st->err.code != VIR_ERR_OK)
VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));
virResetError(&st->err);
/* Zero the decode target so the xdr_free() at cleanup is safe even when
 * decoding fails part way. */
memset(&err, 0, sizeof(err));
if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
goto cleanup;
/* A remote RPC error whose text starts with "unknown procedure" is
 * stored as VIR_ERR_NO_SUPPORT instead of the original code. */
if (err.domain == VIR_FROM_REMOTE &&
err.code == VIR_ERR_RPC &&
err.level == VIR_ERR_ERROR &&
err.message &&
STRPREFIX(*err.message, "unknown procedure")) {
st->err.code = VIR_ERR_NO_SUPPORT;
} else {
st->err.code = err.code;
}
/* The optional strings are moved, not copied: each pointer is taken over
 * by st->err and cleared in the decoded struct.
 * NOTE(review): presumably this keeps xdr_free() below from releasing
 * them - confirm. */
if (err.message) {
st->err.message = *err.message;
*err.message = NULL;
}
st->err.domain = err.domain;
st->err.level = err.level;
if (err.str1) {
st->err.str1 = *err.str1;
*err.str1 = NULL;
}
if (err.str2) {
st->err.str2 = *err.str2;
*err.str2 = NULL;
}
if (err.str3) {
st->err.str3 = *err.str3;
*err.str3 = NULL;
}
st->err.int1 = err.int1;
st->err.int2 = err.int2;
/* An error ends the incoming side of the stream. */
st->incomingEOF = true;
virNetClientStreamEventTimerUpdate(st);
ret = 0;
cleanup:
xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
virObjectUnlock(st);
return ret;
}
/*
 * virNetClientStreamQueuePacket:
 * @st: stream to queue data on
 * @msg: message whose unread payload (bufferOffset..bufferLength) is queued
 *
 * Copies the remaining payload of @msg into freshly allocated buffers of
 * at most 1 MiB each and appends them to the stream's incoming I/O
 * vector.  A message with no remaining payload marks the stream as EOF.
 * On success the event timer is refreshed.
 *
 * Queuing is all-or-nothing: if an allocation or append fails part way
 * through a multi-piece message, the pieces already appended by this
 * call are removed again, so a reader never sees a truncated packet.
 *
 * Returns 0 on success, -1 on allocation failure.
 */
int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
                                  virNetMessagePtr msg)
{
    int ret = -1;
    struct iovec iov;
    char *base;
    size_t piece, pieces, length, offset = 0, size = 1024*1024;
    size_t firstNew;
    size_t i;

    virObjectLock(st);

    /* Remember where this packet starts so a failure can be rolled back. */
    firstNew = st->writeVec;

    length = msg->bufferLength - msg->bufferOffset;

    if (length == 0) {
        st->incomingEOF = true;
        goto end;
    }

    pieces = VIR_DIV_UP(length, size);
    for (piece = 0; piece < pieces; piece++) {
        /* Only the final piece can be shorter than the chunk size. */
        if (size > length - offset)
            size = length - offset;

        if (VIR_ALLOC_N(base, size) < 0) {
            VIR_DEBUG("Allocation failed");
            goto cleanup;
        }

        memcpy(base, msg->buffer + msg->bufferOffset + offset, size);
        iov.iov_base = base;
        iov.iov_len = size;
        offset += size;

        if (VIR_APPEND_ELEMENT(st->incomingVec, st->writeVec, iov) < 0) {
            VIR_DEBUG("Append failed");
            VIR_FREE(base);
            goto cleanup;
        }
        VIR_DEBUG("Wrote piece of vector. readVec %zu, writeVec %zu size %zu",
                  st->readVec, st->writeVec, size);
    }

 end:
    virNetClientStreamEventTimerUpdate(st);
    ret = 0;

 cleanup:
    if (ret < 0 && st->writeVec > firstNew) {
        /* Drop the pieces of this packet that did make it into the vector. */
        for (i = firstNew; i < st->writeVec; i++)
            VIR_FREE(st->incomingVec[i].iov_base);
        VIR_SHRINK_N(st->incomingVec, st->writeVec, st->writeVec - firstNew);
    }
    VIR_DEBUG("Stream incoming data readVec %zu writeVec %zu EOF %d",
              st->readVec, st->writeVec, st->incomingEOF);
    virObjectUnlock(st);
    return ret;
}
/*
 * virNetClientStreamSendPacket:
 * @st: stream the packet belongs to
 * @client: client connection used for sending
 * @status: stream status placed in the message header
 * @data: payload, used only when @status is VIR_NET_CONTINUE
 * @nbytes: length of @data
 *
 * Builds a VIR_NET_STREAM message for the stream and sends it.  Data
 * packets (VIR_NET_CONTINUE) are sent without waiting for a reply; any
 * other status is sent with an empty payload and waits for a reply.
 *
 * Returns @nbytes on success, -1 on failure.
 */
int virNetClientStreamSendPacket(virNetClientStreamPtr st,
                                 virNetClientPtr client,
                                 int status,
                                 const char *data,
                                 size_t nbytes)
{
    virNetMessagePtr msg;
    int ret = -1;

    VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes);

    msg = virNetMessageNew(false);
    if (msg == NULL)
        return -1;

    /* Header fields sourced from the stream are read under its lock. */
    virObjectLock(st);
    msg->header.prog = virNetClientProgramGetProgram(st->prog);
    msg->header.vers = virNetClientProgramGetVersion(st->prog);
    msg->header.status = status;
    msg->header.type = VIR_NET_STREAM;
    msg->header.serial = st->serial;
    msg->header.proc = st->proc;
    virObjectUnlock(st);

    if (virNetMessageEncodeHeader(msg) < 0)
        goto cleanup;

    /* Data packets are async fire&forget, but OK/ERROR packets
     * need a synchronous confirmation
     */
    if (status == VIR_NET_CONTINUE) {
        if (virNetMessageEncodePayloadRaw(msg, data, nbytes) < 0 ||
            virNetClientSendNoReply(client, msg) < 0)
            goto cleanup;
    } else {
        if (virNetMessageEncodePayloadRaw(msg, NULL, 0) < 0 ||
            virNetClientSendWithReply(client, msg) < 0)
            goto cleanup;
    }

    ret = nbytes;

 cleanup:
    virNetMessageFree(msg);
    return ret;
}
/*
 * virNetClientStreamRecvPacket:
 * @st: stream to read from
 * @client: client connection used to wait for more data
 * @data: destination buffer
 * @nbytes: capacity of @data
 * @nonblock: when true, never wait for data to arrive
 *
 * Copies up to @nbytes of queued incoming data into @data, consuming it
 * from the stream's I/O vector.  If nothing is queued and EOF has not
 * been seen, a non-blocking caller gets -2 straight away; a blocking
 * caller sends a dummy stream packet and waits on the client, with the
 * stream lock released for the duration of that wait.
 *
 * Returns the number of bytes copied (0 when nothing is queued, e.g. at
 * EOF), -2 when @nonblock is set and no data is available, or -1 on
 * error.
 */
int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
virNetClientPtr client,
char *data,
size_t nbytes,
bool nonblock)
{
int ret = -1;
size_t partial, offset;
virObjectLock(st);
VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d",
st, client, data, nbytes, nonblock);
/* Nothing queued and the stream is still open: wait or bail out. */
if ((st->readVec >= st->writeVec) && !st->incomingEOF) {
virNetMessagePtr msg;
int rv;
if (nonblock) {
VIR_DEBUG("Non-blocking mode and no data available");
ret = -2;
goto cleanup;
}
if (!(msg = virNetMessageNew(false)))
goto cleanup;
msg->header.prog = virNetClientProgramGetProgram(st->prog);
msg->header.vers = virNetClientProgramGetVersion(st->prog);
msg->header.type = VIR_NET_STREAM;
msg->header.serial = st->serial;
msg->header.proc = st->proc;
msg->header.status = VIR_NET_CONTINUE;
VIR_DEBUG("Dummy packet to wait for stream data");
/* The lock is dropped around the wait and re-taken afterwards. */
virObjectUnlock(st);
rv = virNetClientSendWithReplyStream(client, msg, st);
virObjectLock(st);
virNetMessageFree(msg);
if (rv < 0)
goto cleanup;
}
/* offset: bytes written to data so far; partial: room left in data. */
offset = 0;
partial = nbytes;
while (st->incomingVec && (st->readVec < st->writeVec)) {
struct iovec *iov = st->incomingVec + st->readVec;
if (!iov || !iov->iov_base) {
virReportError(VIR_ERR_INTERNAL_ERROR,
"%s", _("NULL pointer encountered"));
goto cleanup;
}
/* The piece is larger than the remaining room: take what fits, slide
 * the unread tail to the front of the piece's buffer and stop. */
if (partial < iov->iov_len) {
memcpy(data+offset, iov->iov_base, partial);
memmove(iov->iov_base, (char*)iov->iov_base+partial,
iov->iov_len-partial);
iov->iov_len -= partial;
offset += partial;
VIR_DEBUG("Consumed %zu, left %zu", partial, iov->iov_len);
break;
}
/* The whole piece fits: copy it, free its buffer and advance readVec. */
memcpy(data+offset, iov->iov_base, iov->iov_len);
VIR_DEBUG("Consumed %zu. Moving to next piece", iov->iov_len);
partial -= iov->iov_len;
offset += iov->iov_len;
VIR_FREE(iov->iov_base);
iov->iov_len = 0;
st->readVec++;
VIR_DEBUG("Read piece of vector. read %zu, readVec %zu, writeVec %zu",
offset, st->readVec, st->writeVec);
}
/* Shrink the I/O Vector buffer to free up memory. Do the
shrinking only when there is selected amount or more buffers to
free so it doesn't constantly memmove() and realloc() buffers.
*/
if (st->readVec >= 16) {
/* Move the unconsumed entries to the front, then drop readVec entries
 * from the end of the array and restart reading at index 0. */
memmove(st->incomingVec, st->incomingVec + st->readVec,
sizeof(*st->incomingVec)*(st->writeVec - st->readVec));
VIR_SHRINK_N(st->incomingVec, st->writeVec, st->readVec);
VIR_DEBUG("shrink removed %zu, left %zu", st->readVec, st->writeVec);
st->readVec = 0;
}
ret = offset;
virNetClientStreamEventTimerUpdate(st);
cleanup:
virObjectUnlock(st);
return ret;
}
/*
 * virNetClientStreamEventAddCallback:
 * @st: stream to watch
 * @events: VIR_STREAM_EVENT_* bits of interest
 * @cb: callback to invoke
 * @opaque: data passed to @cb
 * @ff: optional destructor for @opaque
 *
 * Registers the single event callback of a stream.  A disabled timer is
 * created to drive dispatch; the timer holds its own reference on @st,
 * released through virObjectFreeCallback.
 *
 * Returns 0 on success, -1 if a callback is already registered or the
 * timer cannot be created.
 */
int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
                                       int events,
                                       virNetClientStreamEventCallback cb,
                                       void *opaque,
                                       virFreeCallback ff)
{
    int ret = -1;

    virObjectLock(st);

    if (st->cb != NULL) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("multiple stream callbacks not supported"));
        goto cleanup;
    }

    /* Reference owned by the timer; dropped again if registration fails. */
    virObjectRef(st);
    st->cbTimer = virEventAddTimeout(-1,
                                     virNetClientStreamEventTimer,
                                     st,
                                     virObjectFreeCallback);
    if (st->cbTimer < 0) {
        virObjectUnref(st);
        goto cleanup;
    }

    st->cb = cb;
    st->cbOpaque = opaque;
    st->cbFree = ff;
    st->cbEvents = events;

    virNetClientStreamEventTimerUpdate(st);
    ret = 0;

 cleanup:
    virObjectUnlock(st);
    return ret;
}
/*
 * virNetClientStreamEventUpdateCallback:
 * @st: stream whose callback is being updated
 * @events: new VIR_STREAM_EVENT_* bits of interest
 *
 * Replaces the event mask of the registered callback and re-evaluates
 * the event timer.
 *
 * Returns 0 on success, -1 if no callback is registered.
 */
int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
                                          int events)
{
    int ret = -1;

    virObjectLock(st);

    if (st->cb) {
        st->cbEvents = events;
        virNetClientStreamEventTimerUpdate(st);
        ret = 0;
    } else {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("no stream callback registered"));
    }

    virObjectUnlock(st);
    return ret;
}
/*
 * virNetClientStreamEventRemoveCallback:
 * @st: stream whose callback is being removed
 *
 * Unregisters the event callback and removes its timer.  The opaque data
 * is freed here unless the callback is currently being dispatched, in
 * which case the timer handler frees it once the callback returns.
 *
 * Returns 0 on success, -1 if no callback is registered.
 */
int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
{
    int ret = -1;

    virObjectLock(st);

    if (st->cb == NULL) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("no stream callback registered"));
    } else {
        if (st->cbFree && !st->cbDispatch)
            (st->cbFree)(st->cbOpaque);

        st->cb = NULL;
        st->cbOpaque = NULL;
        st->cbFree = NULL;
        st->cbEvents = 0;
        virEventRemoveTimeout(st->cbTimer);
        ret = 0;
    }

    virObjectUnlock(st);
    return ret;
}
/*
 * virNetClientStreamEOF:
 * @st: stream to query
 *
 * Returns the stream's incomingEOF flag.  The flag is read without
 * taking the stream lock.
 */
bool virNetClientStreamEOF(virNetClientStreamPtr st)
{
    bool atEOF = st->incomingEOF;

    return atEOF;
}