From 434de30da545aea1379324b9098061201dd1529b Mon Sep 17 00:00:00 2001 From: "Daniel P. Berrange" Date: Wed, 1 Dec 2010 16:35:50 +0000 Subject: [PATCH] Introduce generic RPC client objects To facilitate creation of new clients using XDR RPC services, pull alot of the remote driver code into a set of reusable objects. - virNetClient: Encapsulates a socket connection to a remote RPC server. Handles all the network I/O for reading/writing RPC messages. Delegates RPC encoding and decoding to the registered programs - virNetClientProgram: Handles processing and dispatch of RPC messages for a single RPC (program,version). A program can register to receive async events from a client - virNetClientStream: Handles generic I/O stream integration to RPC layer Each new client program now merely needs to define the list of RPC procedures & events it wants and their handlers. It does not need to deal with any of the network I/O functionality at all. --- cfg.mk | 3 + po/POTFILES.in | 3 + src/Makefile.am | 14 +- src/rpc/virnetclient.c | 1166 +++++++++++++++++++++++++++++++++ src/rpc/virnetclient.h | 84 +++ src/rpc/virnetclientprogram.c | 339 ++++++++++ src/rpc/virnetclientprogram.h | 85 +++ src/rpc/virnetclientstream.c | 442 +++++++++++++ src/rpc/virnetclientstream.h | 76 +++ 9 files changed, 2211 insertions(+), 1 deletion(-) create mode 100644 src/rpc/virnetclient.c create mode 100644 src/rpc/virnetclient.h create mode 100644 src/rpc/virnetclientprogram.c create mode 100644 src/rpc/virnetclientprogram.h create mode 100644 src/rpc/virnetclientstream.c create mode 100644 src/rpc/virnetclientstream.h diff --git a/cfg.mk b/cfg.mk index 4b4442ea29..02931f3efc 100644 --- a/cfg.mk +++ b/cfg.mk @@ -126,6 +126,9 @@ useless_free_options = \ --name=virJSONValueFree \ --name=virLastErrFreeData \ --name=virNetMessageFree \ + --name=virNetClientFree \ + --name=virNetClientProgramFree \ + --name=virNetClientStreamFree \ --name=virNetServerFree \ --name=virNetServerClientFree \ --name=virNetServerMDNSFree \ diff --git a/po/POTFILES.in b/po/POTFILES.in index b0db765fa8..6b07386120 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -69,6 +69,9 @@ src/qemu/qemu_monitor_text.c src/qemu/qemu_process.c src/remote/remote_client_bodies.h src/remote/remote_driver.c +src/rpc/virnetclient.c +src/rpc/virnetclientprogram.c +src/rpc/virnetclientstream.c src/rpc/virnetmessage.c src/rpc/virnetsaslcontext.c src/rpc/virnetsocket.c diff --git a/src/Makefile.am b/src/Makefile.am index db7bc808cc..dabfd647c9 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1188,7 +1188,7 @@ else EXTRA_DIST += $(LOCK_DRIVER_SANLOCK_SOURCES) endif -noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la +noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la libvirt-net-rpc-client.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ @@ -1238,6 +1238,18 @@ libvirt_net_rpc_server_la_LDFLAGS = \ libvirt_net_rpc_server_la_LIBADD = \ $(CYGWIN_EXTRA_LIBADD) +libvirt_net_rpc_client_la_SOURCES = \ + rpc/virnetclientprogram.h rpc/virnetclientprogram.c \ + rpc/virnetclientstream.h rpc/virnetclientstream.c \ + rpc/virnetclient.h rpc/virnetclient.c +libvirt_net_rpc_client_la_CFLAGS = \ + $(AM_CFLAGS) +libvirt_net_rpc_client_la_LDFLAGS = \ + $(AM_LDFLAGS) \ + $(CYGWIN_EXTRA_LDFLAGS) \ + $(MINGW_EXTRA_LDFLAGS) +libvirt_net_rpc_client_la_LIBADD = \ + $(CYGWIN_EXTRA_LIBADD) libexec_PROGRAMS = diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c new file mode 100644 index 0000000000..ded1e12615 --- /dev/null +++ b/src/rpc/virnetclient.c @@ -0,0 +1,1166 @@ +/* + * virnetclient.c: generic network RPC client + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange + */ + +#include + +#include +#include +#include +#include + +#include "virnetclient.h" +#include "virnetsocket.h" +#include "memory.h" +#include "threads.h" +#include "files.h" +#include "logging.h" +#include "util.h" +#include "virterror_internal.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +typedef struct _virNetClientCall virNetClientCall; +typedef virNetClientCall *virNetClientCallPtr; + +enum { + VIR_NET_CLIENT_MODE_WAIT_TX, + VIR_NET_CLIENT_MODE_WAIT_RX, + VIR_NET_CLIENT_MODE_COMPLETE, +}; + +struct _virNetClientCall { + int mode; + + virNetMessagePtr msg; + bool expectReply; + + virCond cond; + + virNetClientCallPtr next; +}; + + +struct _virNetClient { + int refs; + + virMutex lock; + + virNetSocketPtr sock; + + virNetTLSSessionPtr tls; + char *hostname; + + virNetClientProgramPtr *programs; + size_t nprograms; + + /* For incoming message packets */ + virNetMessage msg; + +#if HAVE_SASL + virNetSASLSessionPtr sasl; +#endif + + /* Self-pipe to wakeup threads waiting in poll() */ + int wakeupSendFD; + int wakeupReadFD; + + /* List of threads currently waiting for dispatch */ + virNetClientCallPtr waitDispatch; + + size_t nstreams; + virNetClientStreamPtr *streams; +}; + + +static void virNetClientLock(virNetClientPtr client) +{ + virMutexLock(&client->lock); +} + + +static void virNetClientUnlock(virNetClientPtr client) +{ + virMutexUnlock(&client->lock); +} + + +static void virNetClientIncomingEvent(virNetSocketPtr sock, + int events, + void *opaque); + +static virNetClientPtr virNetClientNew(virNetSocketPtr sock, + const char *hostname) +{ + virNetClientPtr client; + int wakeupFD[2] = { -1, -1 }; + + if (pipe2(wakeupFD, O_CLOEXEC) < 0) { + virReportSystemError(errno, "%s", + _("unable to make pipe")); + goto error; + } + + if (VIR_ALLOC(client) < 0) + goto no_memory; + + client->refs = 1; + + if (virMutexInit(&client->lock) < 0) + goto error; + + client->sock = sock; + client->wakeupReadFD = wakeupFD[0]; + client->wakeupSendFD = wakeupFD[1]; + wakeupFD[0] = wakeupFD[1] = -1; + + if (hostname && + !(client->hostname = strdup(hostname))) + goto no_memory; + + /* Set up a callback to listen on the socket data */ + if (virNetSocketAddIOCallback(client->sock, + VIR_EVENT_HANDLE_READABLE, + virNetClientIncomingEvent, + client) < 0) + VIR_DEBUG("Failed to add event watch, disabling events"); + + return client; + +no_memory: + virReportOOMError(); +error: + VIR_FORCE_CLOSE(wakeupFD[0]); + VIR_FORCE_CLOSE(wakeupFD[1]); + virNetClientFree(client); + return NULL; +} + + +virNetClientPtr virNetClientNewUNIX(const char *path, + bool spawnDaemon, + const char *binary) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectUNIX(path, spawnDaemon, binary, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + + +virNetClientPtr virNetClientNewTCP(const char *nodename, + const char *service) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectTCP(nodename, service, &sock) < 0) + return NULL; + + return virNetClientNew(sock, nodename); +} + +virNetClientPtr virNetClientNewSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectSSH(nodename, service, binary, username, noTTY, netcat, path, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + +virNetClientPtr virNetClientNewExternal(const char **cmdargv) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectExternal(cmdargv, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + + +void virNetClientRef(virNetClientPtr client) +{ + virNetClientLock(client); + client->refs++; + virNetClientUnlock(client); +} + + +void virNetClientFree(virNetClientPtr client) +{ + int i; + + if (!client) + return; + + virNetClientLock(client); + client->refs--; + if (client->refs > 0) { + virNetClientUnlock(client); + return; + } + + for (i = 0 ; i < client->nprograms ; i++) + virNetClientProgramFree(client->programs[i]); + VIR_FREE(client->programs); + + VIR_FORCE_CLOSE(client->wakeupSendFD); + VIR_FORCE_CLOSE(client->wakeupReadFD); + + VIR_FREE(client->hostname); + + virNetSocketRemoveIOCallback(client->sock); + virNetSocketFree(client->sock); + virNetTLSSessionFree(client->tls); +#if HAVE_SASL + virNetSASLSessionFree(client->sasl); +#endif + virNetClientUnlock(client); + virMutexDestroy(&client->lock); + + VIR_FREE(client); +} + + +#if HAVE_SASL +void virNetClientSetSASLSession(virNetClientPtr client, + virNetSASLSessionPtr sasl) +{ + virNetClientLock(client); + client->sasl = sasl; + virNetSASLSessionRef(sasl); + virNetSocketSetSASLSession(client->sock, client->sasl); + virNetClientUnlock(client); +} +#endif + + +int virNetClientSetTLSSession(virNetClientPtr client, + virNetTLSContextPtr tls) +{ + int ret; + char buf[1]; + int len; + struct pollfd fds[1]; +#ifdef HAVE_PTHREAD_SIGMASK + sigset_t oldmask, blockedsigs; + + sigemptyset (&blockedsigs); + sigaddset (&blockedsigs, SIGWINCH); + sigaddset (&blockedsigs, SIGCHLD); + sigaddset (&blockedsigs, SIGPIPE); +#endif + + virNetClientLock(client); + + if (!(client->tls = virNetTLSSessionNew(tls, + client->hostname))) + goto error; + + virNetSocketSetTLSSession(client->sock, client->tls); + + for (;;) { + ret = virNetTLSSessionHandshake(client->tls); + + if (ret < 0) + goto error; + if (ret == 0) + break; + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[0].revents = 0; + if (virNetTLSSessionGetHandshakeStatus(client->tls) == + VIR_NET_TLS_HANDSHAKE_RECVING) + fds[0].events = POLLIN; + else + fds[0].events = POLLOUT; + + /* Block SIGWINCH from interrupting poll in curses programs, + * then restore the original signal mask again immediately + * after the call (RHBZ#567931). Same for SIGCHLD and SIGPIPE + * at the suggestion of Paolo Bonzini and Daniel Berrange. + */ +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll: + ret = poll(fds, ARRAY_CARDINALITY(fds), -1); + if (ret < 0 && errno == EAGAIN) + goto repoll; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &oldmask, NULL)); +#endif + } + + ret = virNetTLSContextCheckCertificate(tls, client->tls); + + if (ret < 0) + goto error; + + /* At this point, the server is verifying _our_ certificate, IP address, + * etc. If we make the grade, it will send us a '\1' byte. + */ + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[0].revents = 0; + fds[0].events = POLLIN; + +#ifdef HAVE_PTHREAD_SIGMASK + /* Block SIGWINCH from interrupting poll in curses programs */ + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll2: + ret = poll(fds, ARRAY_CARDINALITY(fds), -1); + if (ret < 0 && errno == EAGAIN) + goto repoll2; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &oldmask, NULL)); +#endif + + len = virNetTLSSessionRead(client->tls, buf, 1); + if (len < 0) { + virReportSystemError(errno, "%s", + _("Unable to read TLS confirmation")); + goto error; + } + if (len != 1 || buf[0] != '\1') { + virNetError(VIR_ERR_RPC, "%s", + _("server verification (of our certificate or IP " + "address) failed")); + goto error; + } + + virNetClientUnlock(client); + return 0; + +error: + virNetTLSSessionFree(client->tls); + client->tls = NULL; + virNetClientUnlock(client); + return -1; +} + +bool virNetClientIsEncrypted(virNetClientPtr client) +{ + bool ret = false; + virNetClientLock(client); + if (client->tls) + ret = true; +#if HAVE_SASL + if (client->sasl) + ret = true; +#endif + virNetClientUnlock(client); + return ret; +} + + +int virNetClientAddProgram(virNetClientPtr client, + virNetClientProgramPtr prog) +{ + virNetClientLock(client); + + if (VIR_EXPAND_N(client->programs, client->nprograms, 1) < 0) + goto no_memory; + + client->programs[client->nprograms-1] = prog; + virNetClientProgramRef(prog); + + virNetClientUnlock(client); + return 0; + +no_memory: + virReportOOMError(); + virNetClientUnlock(client); + return -1; +} + + +int virNetClientAddStream(virNetClientPtr client, + virNetClientStreamPtr st) +{ + virNetClientLock(client); + + if (VIR_EXPAND_N(client->streams, client->nstreams, 1) < 0) + goto no_memory; + + client->streams[client->nstreams-1] = st; + virNetClientStreamRef(st); + + virNetClientUnlock(client); + return 0; + +no_memory: + virReportOOMError(); + virNetClientUnlock(client); + return -1; +} + + +void virNetClientRemoveStream(virNetClientPtr client, + virNetClientStreamPtr st) +{ + virNetClientLock(client); + size_t i; + for (i = 0 ; i < client->nstreams ; i++) { + if (client->streams[i] == st) + break; + } + if (i == client->nstreams) + goto cleanup; + + if (client->nstreams > 1) { + memmove(client->streams + i, + client->streams + i + 1, + sizeof(*client->streams) * + (client->nstreams - (i + 1))); + VIR_SHRINK_N(client->streams, client->nstreams, 1); + } else { + VIR_FREE(client->streams); + client->nstreams = 0; + } + virNetClientStreamFree(st); + +cleanup: + virNetClientUnlock(client); +} + + +const char *virNetClientLocalAddrString(virNetClientPtr client) +{ + return virNetSocketLocalAddrString(client->sock); +} + +const char *virNetClientRemoteAddrString(virNetClientPtr client) +{ + return virNetSocketRemoteAddrString(client->sock); +} + +int virNetClientGetTLSKeySize(virNetClientPtr client) +{ + int ret = 0; + virNetClientLock(client); + if (client->tls) + ret = virNetTLSSessionGetKeySize(client->tls); + virNetClientUnlock(client); + return ret; +} + +static int +virNetClientCallDispatchReply(virNetClientPtr client) +{ + virNetClientCallPtr thecall; + + /* Ok, definitely got an RPC reply now find + out who's been waiting for it */ + thecall = client->waitDispatch; + while (thecall && + !(thecall->msg->header.prog == client->msg.header.prog && + thecall->msg->header.vers == client->msg.header.vers && + thecall->msg->header.serial == client->msg.header.serial)) + thecall = thecall->next; + + if (!thecall) { + virNetError(VIR_ERR_RPC, + _("no call waiting for reply with prog %d vers %d serial %d"), + client->msg.header.prog, client->msg.header.vers, client->msg.header.serial); + return -1; + } + + memcpy(thecall->msg->buffer, client->msg.buffer, sizeof(client->msg.buffer)); + memcpy(&thecall->msg->header, &client->msg.header, sizeof(client->msg.header)); + thecall->msg->bufferLength = client->msg.bufferLength; + thecall->msg->bufferOffset = client->msg.bufferOffset; + + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + + return 0; +} + +static int virNetClientCallDispatchMessage(virNetClientPtr client) +{ + size_t i; + virNetClientProgramPtr prog = NULL; + + for (i = 0 ; i < client->nprograms ; i++) { + if (virNetClientProgramMatches(client->programs[i], + &client->msg)) { + prog = client->programs[i]; + break; + } + } + if (!prog) { + VIR_DEBUG("No program found for event with prog=%d vers=%d", + client->msg.header.prog, client->msg.header.vers); + return -1; + } + + virNetClientProgramDispatch(prog, client, &client->msg); + + return 0; +} + +static int virNetClientCallDispatchStream(virNetClientPtr client) +{ + size_t i; + virNetClientStreamPtr st = NULL; + virNetClientCallPtr thecall; + + /* First identify what stream this packet is directed at */ + for (i = 0 ; i < client->nstreams ; i++) { + if (virNetClientStreamMatches(client->streams[i], + &client->msg)) { + st = client->streams[i]; + break; + } + } + if (!st) { + VIR_DEBUG("No stream found for packet with prog=%d vers=%d serial=%u proc=%u", + client->msg.header.prog, client->msg.header.vers, + client->msg.header.serial, client->msg.header.proc); + return -1; + } + + /* Finish/Abort are synchronous, so also see if there's an + * (optional) call waiting for this stream packet */ + thecall = client->waitDispatch; + while (thecall && + !(thecall->msg->header.prog == client->msg.header.prog && + thecall->msg->header.vers == client->msg.header.vers && + thecall->msg->header.serial == client->msg.header.serial)) + thecall = thecall->next; + + VIR_DEBUG("Found call %p", thecall); + + /* Status is either + * - REMOTE_OK - no payload for streams + * - REMOTE_ERROR - followed by a remote_error struct + * - REMOTE_CONTINUE - followed by a raw data packet + */ + switch (client->msg.header.status) { + case VIR_NET_CONTINUE: { + if (virNetClientStreamQueuePacket(st, &client->msg) < 0) + return -1; + + if (thecall && thecall->expectReply) { + VIR_DEBUG("Got sync data packet completion"); + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } else { + // XXX + //remoteStreamEventTimerUpdate(privst); + } + return 0; + } + + case VIR_NET_OK: + if (thecall && thecall->expectReply) { + VIR_DEBUG("Got a synchronous confirm"); + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } else { + VIR_DEBUG("Got unexpected async stream finish confirmation"); + return -1; + } + return 0; + + case VIR_NET_ERROR: + /* No call, so queue the error against the stream */ + if (virNetClientStreamSetError(st, &client->msg) < 0) + return -1; + + if (thecall && thecall->expectReply) { + VIR_DEBUG("Got a synchronous error"); + /* Raise error now, so that this call will see it immediately */ + virNetClientStreamRaiseError(st); + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } + return 0; + + default: + VIR_WARN("Stream with unexpected serial=%d, proc=%d, status=%d", + client->msg.header.serial, client->msg.header.proc, + client->msg.header.status); + return -1; + } + + return 0; +} + + +static int +virNetClientCallDispatch(virNetClientPtr client) +{ + if (virNetMessageDecodeHeader(&client->msg) < 0) + return -1; + + VIR_DEBUG("Incoming message prog %d vers %d proc %d type %d status %d serial %d", + client->msg.header.prog, client->msg.header.vers, + client->msg.header.proc, client->msg.header.type, + client->msg.header.status, client->msg.header.serial); + + switch (client->msg.header.type) { + case VIR_NET_REPLY: /* Normal RPC replies */ + return virNetClientCallDispatchReply(client); + + case VIR_NET_MESSAGE: /* Async notifications */ + return virNetClientCallDispatchMessage(client); + + case VIR_NET_STREAM: /* Stream protocol */ + return virNetClientCallDispatchStream(client); + + default: + virNetError(VIR_ERR_RPC, + _("got unexpected RPC call prog %d vers %d proc %d type %d"), + client->msg.header.prog, client->msg.header.vers, + client->msg.header.proc, client->msg.header.type); + return -1; + } +} + + +static ssize_t +virNetClientIOWriteMessage(virNetClientPtr client, + virNetClientCallPtr thecall) +{ + ssize_t ret; + + ret = virNetSocketWrite(client->sock, + thecall->msg->buffer + thecall->msg->bufferOffset, + thecall->msg->bufferLength - thecall->msg->bufferOffset); + if (ret <= 0) + return ret; + + thecall->msg->bufferOffset += ret; + + if (thecall->msg->bufferOffset == thecall->msg->bufferLength) { + thecall->msg->bufferOffset = thecall->msg->bufferLength = 0; + if (thecall->expectReply) + thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX; + else + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } + + return ret; +} + + +static ssize_t +virNetClientIOHandleOutput(virNetClientPtr client) +{ + virNetClientCallPtr thecall = client->waitDispatch; + + while (thecall && + thecall->mode != VIR_NET_CLIENT_MODE_WAIT_TX) + thecall = thecall->next; + + if (!thecall) + return -1; /* Shouldn't happen, but you never know... */ + + while (thecall) { + ssize_t ret = virNetClientIOWriteMessage(client, thecall); + if (ret < 0) + return ret; + + if (thecall->mode == VIR_NET_CLIENT_MODE_WAIT_TX) + return 0; /* Blocking write, to back to event loop */ + + thecall = thecall->next; + } + + return 0; /* No more calls to send, all done */ +} + +static ssize_t +virNetClientIOReadMessage(virNetClientPtr client) +{ + size_t wantData; + ssize_t ret; + + /* Start by reading length word */ + if (client->msg.bufferLength == 0) + client->msg.bufferLength = 4; + + wantData = client->msg.bufferLength - client->msg.bufferOffset; + + ret = virNetSocketRead(client->sock, + client->msg.buffer + client->msg.bufferOffset, + wantData); + if (ret <= 0) + return ret; + + client->msg.bufferOffset += ret; + + return ret; +} + + +static ssize_t +virNetClientIOHandleInput(virNetClientPtr client) +{ + /* Read as much data as is available, until we get + * EAGAIN + */ + for (;;) { + ssize_t ret = virNetClientIOReadMessage(client); + + if (ret < 0) + return -1; + if (ret == 0) + return 0; /* Blocking on read */ + + /* Check for completion of our goal */ + if (client->msg.bufferOffset == client->msg.bufferLength) { + if (client->msg.bufferOffset == 4) { + ret = virNetMessageDecodeLength(&client->msg); + if (ret < 0) + return -1; + + /* + * We'll carry on around the loop to immediately + * process the message body, because it has probably + * already arrived. Worst case, we'll get EAGAIN on + * next iteration. + */ + } else { + ret = virNetClientCallDispatch(client); + client->msg.bufferOffset = client->msg.bufferLength = 0; + /* + * We've completed one call, but we don't want to + * spin around the loop forever if there are many + * incoming async events, or replies for other + * thread's RPC calls. We want to get out & let + * any other thread take over as soon as we've + * got our reply. When SASL is active though, we + * may have read more data off the wire than we + * initially wanted & cached it in memory. In this + * case, poll() would not detect that there is more + * ready todo. + * + * So if SASL is active *and* some SASL data is + * already cached, then we'll process that now, + * before returning. + */ + if (ret == 0 && + virNetSocketHasCachedData(client->sock)) + continue; + return ret; + } + } + } +} + + +/* + * Process all calls pending dispatch/receive until we + * get a reply to our own call. Then quit and pass the buck + * to someone else. + */ +static int virNetClientIOEventLoop(virNetClientPtr client, + virNetClientCallPtr thiscall) +{ + struct pollfd fds[2]; + int ret; + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[1].fd = client->wakeupReadFD; + + for (;;) { + virNetClientCallPtr tmp = client->waitDispatch; + virNetClientCallPtr prev; + char ignore; +#ifdef HAVE_PTHREAD_SIGMASK + sigset_t oldmask, blockedsigs; +#endif + int timeout = -1; + + /* If we have existing SASL decoded data we + * don't want to sleep in the poll(), just + * check if any other FDs are also ready + */ + if (virNetSocketHasCachedData(client->sock)) + timeout = 0; + + fds[0].events = fds[0].revents = 0; + fds[1].events = fds[1].revents = 0; + + fds[1].events = POLLIN; + while (tmp) { + if (tmp->mode == VIR_NET_CLIENT_MODE_WAIT_RX) + fds[0].events |= POLLIN; + if (tmp->mode == VIR_NET_CLIENT_MODE_WAIT_TX) + fds[0].events |= POLLOUT; + + tmp = tmp->next; + } + + /* We have to be prepared to receive stream data + * regardless of whether any of the calls waiting + * for dispatch are for streams. + */ + if (client->nstreams) + fds[0].events |= POLLIN; + + /* Release lock while poll'ing so other threads + * can stuff themselves on the queue */ + virNetClientUnlock(client); + + /* Block SIGWINCH from interrupting poll in curses programs, + * then restore the original signal mask again immediately + * after the call (RHBZ#567931). Same for SIGCHLD and SIGPIPE + * at the suggestion of Paolo Bonzini and Daniel Berrange. + */ +#ifdef HAVE_PTHREAD_SIGMASK + sigemptyset (&blockedsigs); + sigaddset (&blockedsigs, SIGWINCH); + sigaddset (&blockedsigs, SIGCHLD); + sigaddset (&blockedsigs, SIGPIPE); + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll: + ret = poll(fds, ARRAY_CARDINALITY(fds), timeout); + if (ret < 0 && errno == EAGAIN) + goto repoll; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_SETMASK, &oldmask, NULL)); +#endif + + virNetClientLock(client); + + /* If we have existing SASL decoded data, pretend + * the socket became readable so we consume it + */ + if (virNetSocketHasCachedData(client->sock)) + fds[0].revents |= POLLIN; + + if (fds[1].revents) { + VIR_DEBUG("Woken up from poll by other thread"); + if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) { + virReportSystemError(errno, "%s", + _("read on wakeup fd failed")); + goto error; + } + } + + if (ret < 0) { + if (errno == EWOULDBLOCK) + continue; + virReportSystemError(errno, + "%s", _("poll on socket failed")); + goto error; + } + + if (fds[0].revents & POLLOUT) { + if (virNetClientIOHandleOutput(client) < 0) + goto error; + } + + if (fds[0].revents & POLLIN) { + if (virNetClientIOHandleInput(client) < 0) + goto error; + } + + /* Iterate through waiting threads and if + * any are complete then tell 'em to wakeup + */ + tmp = client->waitDispatch; + prev = NULL; + while (tmp) { + if (tmp != thiscall && + tmp->mode == VIR_NET_CLIENT_MODE_COMPLETE) { + /* Take them out of the list */ + if (prev) + prev->next = tmp->next; + else + client->waitDispatch = tmp->next; + + /* And wake them up.... + * ...they won't actually wakeup until + * we release our mutex a short while + * later... + */ + VIR_DEBUG("Waking up sleep %p %p", tmp, client->waitDispatch); + virCondSignal(&tmp->cond); + } else { + prev = tmp; + } + tmp = tmp->next; + } + + /* Now see if *we* are done */ + if (thiscall->mode == VIR_NET_CLIENT_MODE_COMPLETE) { + /* We're at head of the list already, so + * remove us + */ + client->waitDispatch = thiscall->next; + VIR_DEBUG("Giving up the buck %p %p", thiscall, client->waitDispatch); + /* See if someone else is still waiting + * and if so, then pass the buck ! */ + if (client->waitDispatch) { + VIR_DEBUG("Passing the buck to %p", client->waitDispatch); + virCondSignal(&client->waitDispatch->cond); + } + return 0; + } + + + if (fds[0].revents & (POLLHUP | POLLERR)) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("received hangup / error event on socket")); + goto error; + } + } + + +error: + client->waitDispatch = thiscall->next; + VIR_DEBUG("Giving up the buck due to I/O error %p %p", thiscall, client->waitDispatch); + /* See if someone else is still waiting + * and if so, then pass the buck ! */ + if (client->waitDispatch) { + VIR_DEBUG("Passing the buck to %p", client->waitDispatch); + virCondSignal(&client->waitDispatch->cond); + } + return -1; +} + + +/* + * This function sends a message to remote server and awaits a reply + * + * NB. This does not free the args structure (not desirable, since you + * often want this allocated on the stack or else it contains strings + * which come from the user). It does however free any intermediate + * results, eg. the error structure if there is one. + * + * NB(2). Make sure to memset (&ret, 0, sizeof ret) before calling, + * else Bad Things will happen in the XDR code. + * + * NB(3) You must have the client lock before calling this + * + * NB(4) This is very complicated. Multiple threads are allowed to + * use the client for RPC at the same time. Obviously only one of + * them can. So if someone's using the socket, other threads are put + * to sleep on condition variables. The existing thread may completely + * send & receive their RPC call/reply while they're asleep. Or it + * may only get around to dealing with sending the call. Or it may + * get around to neither. So upon waking up from slumber, the other + * thread may or may not have more work todo. + * + * We call this dance 'passing the buck' + * + * http://en.wikipedia.org/wiki/Passing_the_buck + * + * "Buck passing or passing the buck is the action of transferring + * responsibility or blame unto another person. It is also used as + * a strategy in power politics when the actions of one country/ + * nation are blamed on another, providing an opportunity for war." + * + * NB(5) Don't Panic! + */ +static int virNetClientIO(virNetClientPtr client, + virNetClientCallPtr thiscall) +{ + int rv = -1; + + VIR_DEBUG("program=%u version=%u serial=%u proc=%d type=%d length=%zu dispatch=%p", + thiscall->msg->header.prog, + thiscall->msg->header.vers, + thiscall->msg->header.serial, + thiscall->msg->header.proc, + thiscall->msg->header.type, + thiscall->msg->bufferLength, + client->waitDispatch); + + /* Check to see if another thread is dispatching */ + if (client->waitDispatch) { + /* Stick ourselves on the end of the wait queue */ + virNetClientCallPtr tmp = client->waitDispatch; + char ignore = 1; + while (tmp && tmp->next) + tmp = tmp->next; + if (tmp) + tmp->next = thiscall; + else + client->waitDispatch = thiscall; + + /* Force other thread to wakeup from poll */ + if (safewrite(client->wakeupSendFD, &ignore, sizeof(ignore)) != sizeof(ignore)) { + if (tmp) + tmp->next = NULL; + else + client->waitDispatch = NULL; + virReportSystemError(errno, "%s", + _("failed to wake up polling thread")); + return -1; + } + + VIR_DEBUG("Going to sleep %p %p", client->waitDispatch, thiscall); + /* Go to sleep while other thread is working... */ + if (virCondWait(&thiscall->cond, &client->lock) < 0) { + if (client->waitDispatch == thiscall) { + client->waitDispatch = thiscall->next; + } else { + tmp = client->waitDispatch; + while (tmp && tmp->next && + tmp->next != thiscall) { + tmp = tmp->next; + } + if (tmp && tmp->next == thiscall) + tmp->next = thiscall->next; + } + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("failed to wait on condition")); + return -1; + } + + VIR_DEBUG("Wokeup from sleep %p %p", client->waitDispatch, thiscall); + /* Two reasons we can be woken up + * 1. Other thread has got our reply ready for us + * 2. Other thread is all done, and it is our turn to + * be the dispatcher to finish waiting for + * our reply + */ + if (thiscall->mode == VIR_NET_CLIENT_MODE_COMPLETE) { + rv = 0; + /* + * We avoided catching the buck and our reply is ready ! + * We've already had 'thiscall' removed from the list + * so just need to (maybe) handle errors & free it + */ + goto cleanup; + } + + /* Grr, someone passed the buck onto us ... */ + + } else { + /* We're first to catch the buck */ + client->waitDispatch = thiscall; + } + + VIR_DEBUG("We have the buck %p %p", client->waitDispatch, thiscall); + /* + * The buck stops here! + * + * At this point we're about to own the dispatch + * process... + */ + + /* + * Avoid needless wake-ups of the event loop in the + * case where this call is being made from a different + * thread than the event loop. These wake-ups would + * cause the event loop thread to be blocked on the + * mutex for the duration of the call + */ + virNetSocketUpdateIOCallback(client->sock, 0); + + rv = virNetClientIOEventLoop(client, thiscall); + + virNetSocketUpdateIOCallback(client->sock, VIR_EVENT_HANDLE_READABLE); + +cleanup: + VIR_DEBUG("All done with our call %p %p %d", client->waitDispatch, thiscall, rv); + return rv; +} + + +void virNetClientIncomingEvent(virNetSocketPtr sock, + int events, + void *opaque) +{ + virNetClientPtr client = opaque; + + virNetClientLock(client); + + /* This should be impossible, but it doesn't hurt to check */ + if (client->waitDispatch) + goto done; + + VIR_DEBUG("Event fired %p %d", sock, events); + + if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) { + VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or " + "VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__); + virNetSocketRemoveIOCallback(sock); + goto done; + } + + if (virNetClientIOHandleInput(client) < 0) + VIR_DEBUG("Something went wrong during async message processing"); + +done: + virNetClientUnlock(client); +} + + +int virNetClientSend(virNetClientPtr client, + virNetMessagePtr msg, + bool expectReply) +{ + virNetClientCallPtr call; + int ret = -1; + + if (VIR_ALLOC(call) < 0) { + virReportOOMError(); + return -1; + } + + virNetClientLock(client); + + if (virCondInit(&call->cond) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("cannot initialize condition variable")); + goto cleanup; + } + + if (msg->bufferLength) + call->mode = VIR_NET_CLIENT_MODE_WAIT_TX; + else + call->mode = VIR_NET_CLIENT_MODE_WAIT_RX; + call->msg = msg; + call->expectReply = expectReply; + + ret = virNetClientIO(client, call); + +cleanup: + ignore_value(virCondDestroy(&call->cond)); + VIR_FREE(call); + virNetClientUnlock(client); + return ret; +} diff --git a/src/rpc/virnetclient.h b/src/rpc/virnetclient.h new file mode 100644 index 0000000000..de0782c240 --- /dev/null +++ b/src/rpc/virnetclient.h @@ -0,0 +1,84 @@ +/* + * virnetclient.h: generic network RPC client + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange + */ + +#ifndef __VIR_NET_CLIENT_H__ +# define __VIR_NET_CLIENT_H__ + +# include "virnettlscontext.h" +# include "virnetmessage.h" +# ifdef HAVE_SASL +# include "virnetsaslcontext.h" +# endif +# include "virnetclientprogram.h" +# include "virnetclientstream.h" + + +virNetClientPtr virNetClientNewUNIX(const char *path, + bool spawnDaemon, + const char *daemon); + +virNetClientPtr virNetClientNewTCP(const char *nodename, + const char *service); + +virNetClientPtr virNetClientNewSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path); + +virNetClientPtr virNetClientNewExternal(const char **cmdargv); + +void virNetClientRef(virNetClientPtr client); + +int virNetClientAddProgram(virNetClientPtr client, + virNetClientProgramPtr prog); + +int virNetClientAddStream(virNetClientPtr client, + virNetClientStreamPtr st); + +void virNetClientRemoveStream(virNetClientPtr client, + virNetClientStreamPtr st); + +int virNetClientSend(virNetClientPtr client, + virNetMessagePtr msg, + bool expectReply); + +# ifdef HAVE_SASL +void virNetClientSetSASLSession(virNetClientPtr client, + virNetSASLSessionPtr sasl); +# endif + +int virNetClientSetTLSSession(virNetClientPtr client, + virNetTLSContextPtr tls); + +bool virNetClientIsEncrypted(virNetClientPtr client); + +const char *virNetClientLocalAddrString(virNetClientPtr client); +const char *virNetClientRemoteAddrString(virNetClientPtr client); + +int virNetClientGetTLSKeySize(virNetClientPtr client); + +void virNetClientFree(virNetClientPtr client); + +#endif /* __VIR_NET_CLIENT_H__ */ diff --git a/src/rpc/virnetclientprogram.c b/src/rpc/virnetclientprogram.c new file mode 100644 index 0000000000..8414ad86e0 --- /dev/null +++ b/src/rpc/virnetclientprogram.c @@ -0,0 +1,339 @@ +/* + * virnetclientprogram.c: generic network RPC client program + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange + */ + +#include + +#include "virnetclientprogram.h" +#include "virnetclient.h" +#include "virnetprotocol.h" + +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetClientProgram { + int refs; + + unsigned program; + unsigned version; + virNetClientProgramEventPtr events; + size_t nevents; + void *eventOpaque; +}; + +virNetClientProgramPtr virNetClientProgramNew(unsigned program, + unsigned version, + virNetClientProgramEventPtr events, + size_t nevents, + void *eventOpaque) +{ + virNetClientProgramPtr prog; + + if (VIR_ALLOC(prog) < 0) { + virReportOOMError(); + return NULL; + } + + prog->refs = 1; + prog->program = program; + prog->version = version; + prog->events = events; + prog->nevents = nevents; + prog->eventOpaque = eventOpaque; + + return prog; +} + + +void virNetClientProgramRef(virNetClientProgramPtr prog) +{ + prog->refs++; +} + + +void virNetClientProgramFree(virNetClientProgramPtr prog) +{ + if (!prog) + return; + + prog->refs--; + if (prog->refs > 0) + return; + + VIR_FREE(prog); +} + + +unsigned virNetClientProgramGetProgram(virNetClientProgramPtr prog) +{ + return prog->program; +} + + +unsigned virNetClientProgramGetVersion(virNetClientProgramPtr prog) +{ + return prog->version; +} + + +int virNetClientProgramMatches(virNetClientProgramPtr prog, + virNetMessagePtr msg) +{ + if (prog->program == msg->header.prog && + prog->version == msg->header.vers) + return 1; + return 0; +} + + +static int +virNetClientProgramDispatchError(virNetClientProgramPtr prog ATTRIBUTE_UNUSED, + virNetMessagePtr msg) +{ + virNetMessageError err; + int ret = -1; + + memset(&err, 0, sizeof(err)); + + if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) + goto cleanup; + + /* Interop for virErrorNumber glitch in 0.8.0, if server is + * 0.7.1 through 0.7.7; see comments in virterror.h. */ + switch (err.code) { + case VIR_WAR_NO_NWFILTER: + /* no way to tell old VIR_WAR_NO_SECRET apart from + * VIR_WAR_NO_NWFILTER, but both are very similar + * warnings, so ignore the difference */ + break; + case VIR_ERR_INVALID_NWFILTER: + case VIR_ERR_NO_NWFILTER: + case VIR_ERR_BUILD_FIREWALL: + /* server was trying to pass VIR_ERR_INVALID_SECRET, + * VIR_ERR_NO_SECRET, or VIR_ERR_CONFIG_UNSUPPORTED */ + if (err.domain != VIR_FROM_NWFILTER) + err.code += 4; + break; + case VIR_WAR_NO_SECRET: + if (err.domain == VIR_FROM_QEMU) + err.code = VIR_ERR_OPERATION_TIMEOUT; + break; + case VIR_ERR_INVALID_SECRET: + if (err.domain == VIR_FROM_XEN) + err.code = VIR_ERR_MIGRATE_PERSIST_FAILED; + break; + default: + /* Nothing to alter. */ + break; + } + + if (err.domain == VIR_FROM_REMOTE && + err.code == VIR_ERR_RPC && + err.level == VIR_ERR_ERROR && + err.message && + STRPREFIX(*err.message, "unknown procedure")) { + virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__, + err.domain, + VIR_ERR_NO_SUPPORT, + err.level, + err.str1 ? *err.str1 : NULL, + err.str2 ? *err.str2 : NULL, + err.str3 ? *err.str3 : NULL, + err.int1, + err.int2, + "%s", *err.message); + } else { + virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__, + err.domain, + err.code, + err.level, + err.str1 ? *err.str1 : NULL, + err.str2 ? *err.str2 : NULL, + err.str3 ? *err.str3 : NULL, + err.int1, + err.int2, + "%s", err.message ? *err.message : _("Unknown error")); + } + + ret = 0; + +cleanup: + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err); + return ret; +} + + +static virNetClientProgramEventPtr virNetClientProgramGetEvent(virNetClientProgramPtr prog, + int procedure) +{ + int i; + + for (i = 0 ; i < prog->nevents ; i++) { + if (prog->events[i].proc == procedure) + return &prog->events[i]; + } + + return NULL; +} + + +int virNetClientProgramDispatch(virNetClientProgramPtr prog, + virNetClientPtr client, + virNetMessagePtr msg) +{ + virNetClientProgramEventPtr event; + char *evdata; + + VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%d proc=%d", + msg->header.prog, msg->header.vers, msg->header.type, + msg->header.status, msg->header.serial, msg->header.proc); + + /* Check version, etc. */ + if (msg->header.prog != prog->program) { + VIR_ERROR(_("program mismatch in event (actual %x, expected %x)"), + msg->header.prog, prog->program); + return -1; + } + + if (msg->header.vers != prog->version) { + VIR_ERROR(_("version mismatch in event (actual %x, expected %x)"), + msg->header.vers, prog->version); + return -1; + } + + if (msg->header.status != VIR_NET_OK) { + VIR_ERROR(_("status mismatch in event (actual %x, expected %x)"), + msg->header.status, VIR_NET_OK); + return -1; + } + + if (msg->header.type != VIR_NET_MESSAGE) { + VIR_ERROR(_("type mismatch in event (actual %x, expected %x)"), + msg->header.type, VIR_NET_MESSAGE); + return -1; + } + + event = virNetClientProgramGetEvent(prog, msg->header.proc); + + if (!event) { + VIR_ERROR(_("No event expected with procedure %x"), + msg->header.proc); + return -1; + } + + if (VIR_ALLOC_N(evdata, event->msg_len) < 0) { + virReportOOMError(); + return -1; + } + + if (virNetMessageDecodePayload(msg, event->msg_filter, evdata) < 0) + goto cleanup; + + event->func(prog, client, evdata, prog->eventOpaque); + + xdr_free(event->msg_filter, evdata); + +cleanup: + VIR_FREE(evdata); + return 0; +} + + +int virNetClientProgramCall(virNetClientProgramPtr prog, + virNetClientPtr client, + unsigned serial, + int proc, + xdrproc_t args_filter, void *args, + xdrproc_t ret_filter, void *ret) +{ + virNetMessagePtr msg; + + if (!(msg = virNetMessageNew())) + return -1; + + msg->header.prog = prog->program; + msg->header.vers = prog->version; + msg->header.status = VIR_NET_OK; + msg->header.type = VIR_NET_CALL; + msg->header.serial = serial; + msg->header.proc = proc; + + if (virNetMessageEncodeHeader(msg) < 0) + goto error; + + if (virNetMessageEncodePayload(msg, args_filter, args) < 0) + goto error; + + if (virNetClientSend(client, msg, true) < 0) + goto error; + + /* None of these 3 should ever happen here, because + * virNetClientSend should have validated the reply, + * but it doesn't hurt to check again. + */ + if (msg->header.type != VIR_NET_REPLY) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected message type %d"), msg->header.type); + goto error; + } + if (msg->header.proc != proc) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected message proc %d != %d"), + msg->header.proc, proc); + goto error; + } + if (msg->header.serial != serial) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected message serial %d != %d"), + msg->header.serial, serial); + goto error; + } + + switch (msg->header.status) { + case VIR_NET_OK: + if (virNetMessageDecodePayload(msg, ret_filter, ret) < 0) + goto error; + break; + + case VIR_NET_ERROR: + virNetClientProgramDispatchError(prog, msg); + goto error; + + default: + virNetError(VIR_ERR_RPC, + _("Unexpected message status %d"), msg->header.status); + goto error; + } + + VIR_FREE(msg); + + return 0; + +error: + VIR_FREE(msg); + return -1; +} diff --git a/src/rpc/virnetclientprogram.h b/src/rpc/virnetclientprogram.h new file mode 100644 index 0000000000..82ae2c66fb --- /dev/null +++ b/src/rpc/virnetclientprogram.h @@ -0,0 +1,85 @@ +/* + * virnetclientprogram.h: generic network RPC client program + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange + */ + +#ifndef __VIR_NET_CLIENT_PROGRAM_H__ +# define __VIR_NET_CLIENT_PROGRAM_H__ + +# include +# include + +# include "virnetmessage.h" + +typedef struct _virNetClient virNetClient; +typedef virNetClient *virNetClientPtr; + +typedef struct _virNetClientProgram virNetClientProgram; +typedef virNetClientProgram *virNetClientProgramPtr; + +typedef struct _virNetClientProgramEvent virNetClientProgramEvent; +typedef virNetClientProgramEvent *virNetClientProgramEventPtr; + +typedef struct _virNetClientProgramErrorHandler virNetClientProgramErrorHander; +typedef virNetClientProgramErrorHander *virNetClientProgramErrorHanderPtr; + + +typedef void (*virNetClientProgramDispatchFunc)(virNetClientProgramPtr prog, + virNetClientPtr client, + void *msg, + void *opaque); + +struct _virNetClientProgramEvent { + int proc; + virNetClientProgramDispatchFunc func; + size_t msg_len; + xdrproc_t msg_filter; +}; + +virNetClientProgramPtr virNetClientProgramNew(unsigned program, + unsigned version, + virNetClientProgramEventPtr events, + size_t nevents, + void *eventOpaque); + +unsigned virNetClientProgramGetProgram(virNetClientProgramPtr prog); +unsigned virNetClientProgramGetVersion(virNetClientProgramPtr prog); + +void virNetClientProgramRef(virNetClientProgramPtr prog); + +void virNetClientProgramFree(virNetClientProgramPtr prog); + +int virNetClientProgramMatches(virNetClientProgramPtr prog, + virNetMessagePtr msg); + +int virNetClientProgramDispatch(virNetClientProgramPtr prog, + virNetClientPtr client, + virNetMessagePtr msg); + +int virNetClientProgramCall(virNetClientProgramPtr prog, + virNetClientPtr client, + unsigned serial, + int proc, + xdrproc_t args_filter, void *args, + xdrproc_t ret_filter, void *ret); + + + +#endif /* __VIR_NET_CLIENT_PROGRAM_H__ */ diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c new file mode 100644 index 0000000000..44c9acfe49 --- /dev/null +++ b/src/rpc/virnetclientstream.c @@ -0,0 +1,442 @@ +/* + * virnetclientstream.c: generic network RPC client stream + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange + */ + +#include + +#include "virnetclientstream.h" +#include "virnetclient.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" +#include "event.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetClientStream { + virNetClientProgramPtr prog; + int proc; + unsigned serial; + int refs; + + virError err; + + /* XXX this buffer is unbounded if the client + * app has domain events registered, since packets + * may be read off wire, while app isn't ready to + * recv them. Figure out how to address this some + * time by stopping consuming any incoming data + * off the socket.... + */ + char *incoming; + size_t incomingOffset; + size_t incomingLength; + + + virNetClientStreamEventCallback cb; + void *cbOpaque; + virFreeCallback cbFree; + int cbEvents; + int cbTimer; + int cbDispatch; +}; + + +static void +virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st) +{ + if (!st->cb) + return; + + VIR_DEBUG("Check timer offset=%zu %d", st->incomingOffset, st->cbEvents); + + if ((st->incomingOffset && + (st->cbEvents & VIR_STREAM_EVENT_READABLE)) || + (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) { + VIR_DEBUG("Enabling event timer"); + virEventUpdateTimeout(st->cbTimer, 0); + } else { + VIR_DEBUG("Disabling event timer"); + virEventUpdateTimeout(st->cbTimer, -1); + } +} + + +static void +virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque) +{ + virNetClientStreamPtr st = opaque; + int events = 0; + + /* XXX we need a mutex on 'st' to protect this callback */ + + if (st->cb && + (st->cbEvents & VIR_STREAM_EVENT_READABLE) && + st->incomingOffset) + events |= VIR_STREAM_EVENT_READABLE; + if (st->cb && + (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) + events |= VIR_STREAM_EVENT_WRITABLE; + + VIR_DEBUG("Got Timer dispatch %d %d offset=%zu", events, st->cbEvents, st->incomingOffset); + if (events) { + virNetClientStreamEventCallback cb = st->cb; + void *cbOpaque = st->cbOpaque; + virFreeCallback cbFree = st->cbFree; + + st->cbDispatch = 1; + (cb)(st, events, cbOpaque); + st->cbDispatch = 0; + + if (!st->cb && cbFree) + (cbFree)(cbOpaque); + } +} + + +static void +virNetClientStreamEventTimerFree(void *opaque) +{ + virNetClientStreamPtr st = opaque; + virNetClientStreamFree(st); +} + + +virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog, + int proc, + unsigned serial) +{ + virNetClientStreamPtr st; + + if (VIR_ALLOC(st) < 0) { + virReportOOMError(); + return NULL; + } + + virNetClientProgramRef(prog); + + st->refs = 1; + st->prog = prog; + st->proc = proc; + st->serial = serial; + + return st; +} + + +void virNetClientStreamRef(virNetClientStreamPtr st) +{ + st->refs++; +} + +void virNetClientStreamFree(virNetClientStreamPtr st) +{ + st->refs--; + if (st->refs > 0) + return; + + virResetError(&st->err); + VIR_FREE(st->incoming); + virNetClientProgramFree(st->prog); + VIR_FREE(st); +} + +bool virNetClientStreamMatches(virNetClientStreamPtr st, + virNetMessagePtr msg) +{ + if (virNetClientProgramMatches(st->prog, msg) && + st->proc == msg->header.proc && + st->serial == msg->header.serial) + return 1; + return 0; +} + + +bool virNetClientStreamRaiseError(virNetClientStreamPtr st) +{ + if (st->err.code == VIR_ERR_OK) + return false; + + virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__, + st->err.domain, + st->err.code, + st->err.level, + st->err.str1, + st->err.str2, + st->err.str3, + st->err.int1, + st->err.int2, + "%s", st->err.message ? st->err.message : _("Unknown error")); + + return true; +} + + +int virNetClientStreamSetError(virNetClientStreamPtr st, + virNetMessagePtr msg) +{ + virNetMessageError err; + int ret = -1; + + if (st->err.code != VIR_ERR_OK) + VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message)); + + virResetError(&st->err); + memset(&err, 0, sizeof(err)); + + if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) + goto cleanup; + + if (err.domain == VIR_FROM_REMOTE && + err.code == VIR_ERR_RPC && + err.level == VIR_ERR_ERROR && + err.message && + STRPREFIX(*err.message, "unknown procedure")) { + st->err.code = VIR_ERR_NO_SUPPORT; + } else { + st->err.code = err.code; + } + st->err.message = *err.message; + *err.message = NULL; + st->err.domain = err.domain; + st->err.level = err.level; + st->err.str1 = *err.str1; + st->err.str2 = *err.str2; + st->err.str3 = *err.str3; + st->err.int1 = err.int1; + st->err.int2 = err.int2; + + ret = 0; + +cleanup: + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err); + return ret; +} + + +int virNetClientStreamQueuePacket(virNetClientStreamPtr st, + virNetMessagePtr msg) +{ + size_t avail = st->incomingLength - st->incomingOffset; + size_t need = msg->bufferLength - msg->bufferOffset; + + if (need > avail) { + size_t extra = need - avail; + if (VIR_REALLOC_N(st->incoming, + st->incomingLength + extra) < 0) { + VIR_DEBUG("Out of memory handling stream data"); + return -1; + } + st->incomingLength += extra; + } + + memcpy(st->incoming + st->incomingOffset, + msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset); + st->incomingOffset += (msg->bufferLength - msg->bufferOffset); + + VIR_DEBUG("Stream incoming data offset %zu length %zu", + st->incomingOffset, st->incomingLength); + return 0; +} + + +int virNetClientStreamSendPacket(virNetClientStreamPtr st, + virNetClientPtr client, + int status, + const char *data, + size_t nbytes) +{ + virNetMessagePtr msg; + bool wantReply; + VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes); + + if (!(msg = virNetMessageNew())) + return -1; + + msg->header.prog = virNetClientProgramGetProgram(st->prog); + msg->header.vers = virNetClientProgramGetVersion(st->prog); + msg->header.status = status; + msg->header.type = VIR_NET_STREAM; + msg->header.serial = st->serial; + msg->header.proc = st->proc; + + if (virNetMessageEncodeHeader(msg) < 0) + goto error; + + /* Data packets are async fire&forget, but OK/ERROR packets + * need a synchronous confirmation + */ + if (status == VIR_NET_CONTINUE) { + if (virNetMessageEncodePayloadRaw(msg, data, nbytes) < 0) + goto error; + wantReply = false; + } else { + if (virNetMessageEncodePayloadRaw(msg, NULL, 0) < 0) + goto error; + wantReply = true; + } + + if (virNetClientSend(client, msg, wantReply) < 0) + goto error; + + + return nbytes; + +error: + VIR_FREE(msg); + return -1; +} + +int virNetClientStreamRecvPacket(virNetClientStreamPtr st, + virNetClientPtr client, + char *data, + size_t nbytes, + bool nonblock) +{ + int rv = -1; + VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d", + st, client, data, nbytes, nonblock); + if (!st->incomingOffset) { + virNetMessagePtr msg; + int ret; + + if (nonblock) { + VIR_DEBUG("Non-blocking mode and no data available"); + rv = -2; + goto cleanup; + } + + if (!(msg = virNetMessageNew())) { + virReportOOMError(); + goto cleanup; + } + + msg->header.prog = virNetClientProgramGetProgram(st->prog); + msg->header.vers = virNetClientProgramGetVersion(st->prog); + msg->header.type = VIR_NET_STREAM; + msg->header.serial = st->serial; + msg->header.proc = st->proc; + + VIR_DEBUG("Dummy packet to wait for stream data"); + ret = virNetClientSend(client, msg, true); + + virNetMessageFree(msg); + + if (ret < 0) + goto cleanup; + } + + VIR_DEBUG("After IO %zu", st->incomingOffset); + if (st->incomingOffset) { + int want = st->incomingOffset; + if (want > nbytes) + want = nbytes; + memcpy(data, st->incoming, want); + if (want < st->incomingOffset) { + memmove(st->incoming, st->incoming + want, st->incomingOffset - want); + st->incomingOffset -= want; + } else { + VIR_FREE(st->incoming); + st->incomingOffset = st->incomingLength = 0; + } + rv = want; + } else { + rv = 0; + } + + virNetClientStreamEventTimerUpdate(st); + +cleanup: + return rv; +} + + +int virNetClientStreamEventAddCallback(virNetClientStreamPtr st, + int events, + virNetClientStreamEventCallback cb, + void *opaque, + virFreeCallback ff) +{ + if (st->cb) { + virNetError(VIR_ERR_INTERNAL_ERROR, + "%s", _("multiple stream callbacks not supported")); + return 1; + } + + virNetClientStreamRef(st); + if ((st->cbTimer = + virEventAddTimeout(-1, + virNetClientStreamEventTimer, + st, + virNetClientStreamEventTimerFree)) < 0) { + virNetClientStreamFree(st); + return -1; + } + + st->cb = cb; + st->cbOpaque = opaque; + st->cbFree = ff; + st->cbEvents = events; + + virNetClientStreamEventTimerUpdate(st); + + return 0; +} + +int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st, + int events) +{ + if (!st->cb) { + virNetError(VIR_ERR_INTERNAL_ERROR, + "%s", _("no stream callback registered")); + return -1; + } + + st->cbEvents = events; + + virNetClientStreamEventTimerUpdate(st); + + return 0; +} + +int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st) +{ + if (!st->cb) { + virNetError(VIR_ERR_INTERNAL_ERROR, + "%s", _("no stream callback registered")); + return -1; + } + + if (!st->cbDispatch && + st->cbFree) + (st->cbFree)(st->cbOpaque); + st->cb = NULL; + st->cbOpaque = NULL; + st->cbFree = NULL; + st->cbEvents = 0; + virEventRemoveTimeout(st->cbTimer); + + return 0; +} diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h new file mode 100644 index 0000000000..6c8d538099 --- /dev/null +++ b/src/rpc/virnetclientstream.h @@ -0,0 +1,76 @@ +/* + * virnetclientstream.h: generic network RPC client stream + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange + */ + +#ifndef __VIR_NET_CLIENT_STREAM_H__ +# define __VIR_NET_CLIENT_STREAM_H__ + +# include "virnetclientprogram.h" + +typedef struct _virNetClientStream virNetClientStream; +typedef virNetClientStream *virNetClientStreamPtr; + +typedef void (*virNetClientStreamEventCallback)(virNetClientStreamPtr stream, + int events, void *opaque); + +virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog, + int proc, + unsigned serial); + +void virNetClientStreamRef(virNetClientStreamPtr st); + +void virNetClientStreamFree(virNetClientStreamPtr st); + +bool virNetClientStreamRaiseError(virNetClientStreamPtr st); + +int virNetClientStreamSetError(virNetClientStreamPtr st, + virNetMessagePtr msg); + +bool virNetClientStreamMatches(virNetClientStreamPtr st, + virNetMessagePtr msg); + +int virNetClientStreamQueuePacket(virNetClientStreamPtr st, + virNetMessagePtr msg); + +int virNetClientStreamSendPacket(virNetClientStreamPtr st, + virNetClientPtr client, + int status, + const char *data, + size_t nbytes); + +int virNetClientStreamRecvPacket(virNetClientStreamPtr st, + virNetClientPtr client, + char *data, + size_t nbytes, + bool nonblock); + +int virNetClientStreamEventAddCallback(virNetClientStreamPtr st, + int events, + virNetClientStreamEventCallback cb, + void *opaque, + virFreeCallback ff); + +int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st, + int events); +int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st); + + +#endif /* __VIR_NET_CLIENT_STREAM_H__ */