libvirt/src/rpc/virnetserverclient.c

943 lines
25 KiB
C
Raw Normal View History

/*
* virnetserverclient.c: generic network RPC server client
*
* Copyright (C) 2006-2011 Red Hat, Inc.
* Copyright (C) 2006 Daniel P. Berrange
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Daniel P. Berrange <berrange@redhat.com>
*/
#include <config.h>
#if HAVE_SASL
# include <sasl/sasl.h>
#endif
#include "virnetserverclient.h"
#include "logging.h"
#include "virterror_internal.h"
#include "memory.h"
#include "threads.h"
/* Value passed by virNetError() below as the first argument of
 * virReportErrorHelper (presumably the error domain - confirm
 * against virterror_internal.h) */
#define VIR_FROM_THIS VIR_FROM_RPC
/* Report an error with the given code, recording the calling file,
 * function and line. Every argument after 'code' is forwarded
 * unchanged to virReportErrorHelper. */
#define virNetError(code, ...) \
virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \
__FUNCTION__, __LINE__, __VA_ARGS__)
/* Allow for filtering of incoming messages to a custom
 * dispatch processing queue, instead of the workers.
 * This allows for certain types of messages to be handled
 * strictly "in order".
 *
 * Filters form a singly linked list hanging off the client
 * (see the 'filters' member of struct _virNetServerClient).
 */
typedef struct _virNetServerClientFilter virNetServerClientFilter;
typedef virNetServerClientFilter *virNetServerClientFilterPtr;
struct _virNetServerClientFilter {
/* Identifier returned by virNetServerClientAddFilter and used
 * by virNetServerClientRemoveFilter to locate this entry */
int id;
/* Callback invoked with (client, msg, opaque) for each fully
 * received message */
virNetServerClientFilterFunc func;
/* Caller supplied pointer handed back to 'func' */
void *opaque;
/* Next filter in the list, or NULL */
virNetServerClientFilterPtr next;
};
/* State tracked for a single client connection of the RPC server */
struct _virNetServerClient
{
/* Reference count: incremented by virNetServerClientRef and
 * decremented by virNetServerClientFree, which releases the
 * object once it drops to zero */
int refs;
/* Set when the connection should be torn down. While set, no
 * I/O events are requested for the socket and
 * virNetServerClientSendMessage refuses new messages */
bool wantClose;
/* Taken by the public accessors and by the socket I/O callback
 * before they touch the fields below */
virMutex lock;
/* Underlying socket; NULL once virNetServerClientClose has run */
virNetSocketPtr sock;
/* Authentication setting supplied to virNetServerClientNew. When
 * non-zero, virNetServerClientNeedAuth reports true until an
 * identity has been set */
int auth;
bool readonly;
/* Heap copy of the string given to virNetServerClientSetIdentity */
char *identity;
/* TLS context (may be NULL for a plain socket) and the session
 * created from it in virNetServerClientInit */
virNetTLSContextPtr tlsCtxt;
virNetTLSSessionPtr tls;
#if HAVE_SASL
/* SASL session waiting to be installed on the socket; it is
 * handed over once the next 'tx' message has been fully sent */
virNetSASLSessionPtr sasl;
#endif
/* Count of messages in the 'tx' queue,
 * and the server worker pool queue
 * ie RPC calls in progress. Does not count
 * async events which are not used for
 * throttling calculations */
size_t nrequests;
size_t nrequests_max;
/* Zero or one messages being received. Zero if
 * nrequests >= nrequests_max and throttling */
virNetMessagePtr rx;
/* Zero or many messages waiting for transmit
 * back to client, including async events */
virNetMessagePtr tx;
/* Filters to capture messages that would otherwise
 * be handed to 'dispatchFunc' */
virNetServerClientFilterPtr filters;
/* ID that will be assigned to the next filter added */
int nextFilterID;
/* Handler (and its opaque argument) for messages that no
 * filter claimed */
virNetServerClientDispatchFunc dispatchFunc;
void *dispatchOpaque;
/* Caller owned data and the optional function used to free it */
void *privateData;
virNetServerClientFreeFunc privateDataFreeFunc;
};
/* Forward declarations for helpers defined further down in this file */
static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque);
static void virNetServerClientUpdateEvent(virNetServerClientPtr client);
/* Acquire the client's mutex. */
static void
virNetServerClientLock(virNetServerClientPtr client)
{
    virMutexLock(&client->lock);
}
/* Release the client's mutex. */
static void
virNetServerClientUnlock(virNetServerClientPtr client)
{
    virMutexUnlock(&client->lock);
}
/*
 * @client: a locked client object
 *
 * Work out which socket events the client currently cares about.
 *
 * Returns 0 when the socket is gone or the client is marked for
 * close. While a TLS handshake is in flight the result follows the
 * direction the handshake is waiting on. Otherwise READABLE is
 * requested when a receive buffer exists and WRITABLE when there
 * is at least one message queued for transmit.
 */
static int
virNetServerClientCalculateHandleMode(virNetServerClientPtr client)
{
    int events = 0;
    bool handshaking = false;

    VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p",
              client->tls,
              client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1,
              client->rx,
              client->tx);

    if (client->wantClose || !client->sock)
        return 0;

    if (client->tls) {
        switch (virNetTLSSessionGetHandshakeStatus(client->tls)) {
        case VIR_NET_TLS_HANDSHAKE_RECVING:
            events |= VIR_EVENT_HANDLE_READABLE;
            handshaking = true;
            break;

        case VIR_NET_TLS_HANDSHAKE_SENDING:
            events |= VIR_EVENT_HANDLE_WRITABLE;
            handshaking = true;
            break;

        case VIR_NET_TLS_HANDSHAKE_COMPLETE:
        default:
            /* Handshake finished: treated like a plain socket below */
            break;
        }
    }

    if (!handshaking) {
        /* A pending receive buffer means we want more input */
        if (client->rx)
            events |= VIR_EVENT_HANDLE_READABLE;

        /* Queued replies/events mean we must watch for writability */
        if (client->tx)
            events |= VIR_EVENT_HANDLE_WRITABLE;
    }

    VIR_DEBUG("mode=%d", events);
    return events;
}
/*
 * @client: a locked client object
 *
 * Install virNetServerClientDispatchEvent as the I/O callback on the
 * client's socket, using the currently desired event mask and the
 * client itself as the callback's opaque data.
 *
 * Returns 0 on success, -1 if the callback could not be added.
 */
static int
virNetServerClientRegisterEvent(virNetServerClientPtr client)
{
    int events = virNetServerClientCalculateHandleMode(client);
    int rc;

    VIR_DEBUG("Registering client event callback %d", events);

    rc = virNetSocketAddIOCallback(client->sock,
                                   events,
                                   virNetServerClientDispatchEvent,
                                   client);
    return rc < 0 ? -1 : 0;
}
/*
 * @client: a locked client object
 *
 * Refresh the event mask of the socket's I/O callback so it matches
 * the client's current state. Does nothing once the socket is gone.
 */
static void
virNetServerClientUpdateEvent(virNetServerClientPtr client)
{
    if (client->sock) {
        int events = virNetServerClientCalculateHandleMode(client);

        virNetSocketUpdateIOCallback(client->sock, events);
    }
}
/*
 * Register a message filter on the client.
 *
 * @client: the client to attach the filter to
 * @func: callback run against each fully received message
 * @opaque: pointer passed through to @func
 *
 * The new filter is placed at the head of the client's filter list.
 *
 * Returns the ID of the new filter (for use with
 * virNetServerClientRemoveFilter), or -1 if allocation failed.
 */
int
virNetServerClientAddFilter(virNetServerClientPtr client,
                            virNetServerClientFilterFunc func,
                            void *opaque)
{
    virNetServerClientFilterPtr node;
    int id = -1;

    virNetServerClientLock(client);

    if (VIR_ALLOC(node) < 0) {
        virReportOOMError();
    } else {
        node->id = client->nextFilterID++;
        node->func = func;
        node->opaque = opaque;

        /* Push onto the front of the list */
        node->next = client->filters;
        client->filters = node;

        id = node->id;
    }

    virNetServerClientUnlock(client);
    return id;
}
/*
 * Remove a previously registered message filter.
 *
 * @client: the client owning the filter
 * @filterID: ID returned by virNetServerClientAddFilter
 *
 * The matching entry is unlinked from the client's filter list and
 * freed. All other entries are left in place. Unknown IDs are
 * silently ignored.
 */
void
virNetServerClientRemoveFilter(virNetServerClientPtr client,
                               int filterID)
{
    virNetServerClientFilterPtr *link;

    virNetServerClientLock(client);

    /* Walk the list through the address of each 'next' link so the
     * entry can be unlinked wherever it sits. The previous code never
     * advanced its 'prev' pointer, so removing any entry other than
     * the head dropped (and leaked) every filter in front of it. */
    for (link = &client->filters; *link; link = &(*link)->next) {
        if ((*link)->id == filterID) {
            virNetServerClientFilterPtr victim = *link;

            *link = victim->next;
            VIR_FREE(victim);
            break;
        }
    }

    virNetServerClientUnlock(client);
}
/*
 * Check the client's access once the TLS handshake has finished.
 *
 * Verifies the client certificate against the TLS context and, on
 * success, queues a one byte confirmation message for transmit.
 *
 * Returns 0 on success, -1 if the certificate check fails, if there
 * is already data pending in the tx queue, or if the confirmation
 * message cannot be allocated.
 */
static int
virNetServerClientCheckAccess(virNetServerClientPtr client)
{
    virNetMessagePtr ack = NULL;

    /* Verify client certificate. */
    if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0)
        return -1;

    if (client->tx != NULL) {
        VIR_DEBUG("client had unexpected data pending tx after access check");
        return -1;
    }

    ack = virNetMessageNew();
    if (ack == NULL)
        return -1;

    /* Checks have succeeded. Write a '\1' byte back to the client to
     * indicate this (otherwise the socket is abruptly closed).
     * (NB. The '\1' byte is sent in an encrypted record).
     */
    ack->buffer[0] = '\1';
    ack->bufferOffset = 0;
    ack->bufferLength = 1;

    client->tx = ack;
    return 0;
}
/*
 * Create a client object wrapping an accepted socket.
 *
 * @sock: the connected socket; owned by the client on success,
 *        still owned by the caller on failure
 * @auth: authentication setting for this client
 * @readonly: whether the client is restricted to read-only access
 * @nrequests_max: limit used for request throttling
 * @tls: TLS context to use, or NULL for a plain socket; an extra
 *       reference is taken on it
 *
 * The new object starts with one reference and one receive buffer
 * prepared for reading a message length word.
 *
 * Returns the new client, or NULL on failure.
 */
virNetServerClientPtr
virNetServerClientNew(virNetSocketPtr sock,
                      int auth,
                      bool readonly,
                      size_t nrequests_max,
                      virNetTLSContextPtr tls)
{
    virNetServerClientPtr client;

    VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth, tls);

    if (VIR_ALLOC(client) < 0) {
        virReportOOMError();
        return NULL;
    }

    if (virMutexInit(&client->lock) < 0) {
        /* The mutex was never initialised, so virNetServerClientFree(),
         * which locks and then destroys it, must not be used here.
         * Nothing else has been acquired yet, so dropping the
         * allocation is all the cleanup required. */
        VIR_FREE(client);
        return NULL;
    }

    client->refs = 1;
    client->sock = sock;
    client->auth = auth;
    client->readonly = readonly;
    client->tlsCtxt = tls;
    client->nrequests_max = nrequests_max;
    if (tls)
        virNetTLSContextRef(tls);

    /* Prepare one for packet receive */
    if (!(client->rx = virNetMessageNew()))
        goto error;
    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
    client->nrequests = 1;

    VIR_DEBUG("client=%p refs=%d", client, client->refs);

    return client;

error:
    /* XXX ref counting is better than this */
    client->sock = NULL; /* Caller owns 'sock' upon failure */
    virNetServerClientFree(client);
    return NULL;
}
/*
 * Take an additional reference on the client. Each reference must
 * eventually be dropped with virNetServerClientFree.
 */
void
virNetServerClientRef(virNetServerClientPtr client)
{
    virNetServerClientLock(client);
    client->refs += 1;
    VIR_DEBUG("client=%p refs=%d", client, client->refs);
    virNetServerClientUnlock(client);
}
/* Return the authentication setting the client was created with. */
int
virNetServerClientGetAuth(virNetServerClientPtr client)
{
    int value;

    virNetServerClientLock(client);
    value = client->auth;
    virNetServerClientUnlock(client);

    return value;
}
/* Return the read-only flag the client was created with. */
bool
virNetServerClientGetReadonly(virNetServerClientPtr client)
{
    bool value;

    virNetServerClientLock(client);
    value = client->readonly;
    virNetServerClientUnlock(client);

    return value;
}
/* Report whether the client currently has a TLS session attached. */
bool
virNetServerClientHasTLSSession(virNetServerClientPtr client)
{
    bool present;

    virNetServerClientLock(client);
    present = (client->tls != NULL);
    virNetServerClientUnlock(client);

    return present;
}
/*
 * Return the key size reported by the client's TLS session, or 0
 * when the client has no TLS session.
 */
int
virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
{
    int keysize;

    virNetServerClientLock(client);
    keysize = client->tls ? virNetTLSSessionGetKeySize(client->tls) : 0;
    virNetServerClientUnlock(client);

    return keysize;
}
/*
 * Return the file descriptor of the client's socket.
 *
 * Returns -1 when the client has already been closed:
 * virNetServerClientClose() sets the socket pointer to NULL while the
 * client object stays alive, so the pointer must be checked rather
 * than handed to the socket layer unconditionally.
 */
int
virNetServerClientGetFD(virNetServerClientPtr client)
{
    int fd = -1;

    virNetServerClientLock(client);
    if (client->sock)
        fd = virNetSocketGetFD(client->sock);
    virNetServerClientUnlock(client);

    return fd;
}
/*
 * Query the identity of the local peer of the client's socket.
 *
 * @client: the client to query
 * @uid: filled in by virNetSocketGetLocalIdentity
 * @pid: filled in by virNetSocketGetLocalIdentity
 *
 * Returns the result of virNetSocketGetLocalIdentity, or -1 without
 * touching @uid / @pid when the client has already been closed and
 * so no longer has a socket.
 */
int
virNetServerClientGetLocalIdentity(virNetServerClientPtr client,
                                   uid_t *uid, pid_t *pid)
{
    int ret = -1;

    virNetServerClientLock(client);
    /* client->sock is NULL after virNetServerClientClose() */
    if (client->sock)
        ret = virNetSocketGetLocalIdentity(client->sock, uid, pid);
    virNetServerClientUnlock(client);

    return ret;
}
/*
 * Report whether the client's connection counts as secure.
 *
 * The connection is considered secure when it has a TLS session, a
 * pending SASL session (when built with SASL), or when the socket is
 * local. A client that has already been closed has no socket, so the
 * locality check is skipped for it instead of passing NULL to the
 * socket layer.
 */
bool
virNetServerClientIsSecure(virNetServerClientPtr client)
{
    bool secure = false;

    virNetServerClientLock(client);
    if (client->tls)
        secure = true;
#if HAVE_SASL
    if (client->sasl)
        secure = true;
#endif
    /* client->sock is NULL after virNetServerClientClose() */
    if (client->sock && virNetSocketIsLocal(client->sock))
        secure = true;
    virNetServerClientUnlock(client);

    return secure;
}
#if HAVE_SASL
/*
 * Remember a SASL session for the client, taking a reference on it.
 *
 * The session is deliberately not installed on the socket here,
 * because the auth confirmation still has to go out in the clear.
 * The switch to SASL mode happens only after the next 'tx'
 * operation has completed.
 */
void
virNetServerClientSetSASLSession(virNetServerClientPtr client,
                                 virNetSASLSessionPtr sasl)
{
    virNetServerClientLock(client);
    virNetSASLSessionRef(sasl);
    client->sasl = sasl;
    virNetServerClientUnlock(client);
}
#endif
/*
 * Record the authenticated identity of the client.
 *
 * @client: the client to update
 * @identity: identity string; a private copy is stored
 *
 * Any previously stored identity is released, so a pointer obtained
 * earlier from virNetServerClientGetIdentity() must not be used after
 * this call succeeds. If the copy cannot be allocated the previously
 * stored identity is left untouched.
 *
 * Returns 0 on success, -1 on allocation failure.
 */
int
virNetServerClientSetIdentity(virNetServerClientPtr client,
                              const char *identity)
{
    char *copy;

    /* Duplicate first so a failure cannot clobber the old value */
    if (!(copy = strdup(identity))) {
        virReportOOMError();
        return -1;
    }

    virNetServerClientLock(client);
    /* Previously the old string was simply overwritten and leaked */
    VIR_FREE(client->identity);
    client->identity = copy;
    virNetServerClientUnlock(client);

    return 0;
}
/*
 * Return the identity stored for the client, or NULL if none has
 * been set. The string remains owned by the client.
 */
const char *
virNetServerClientGetIdentity(virNetServerClientPtr client)
{
    const char *identity;

    virNetServerClientLock(client);
    identity = client->identity;
    /* Must release the lock here: the previous code called Lock a
     * second time, leaving the client mutex held (or deadlocking on
     * itself) for good. */
    virNetServerClientUnlock(client);

    return identity;
}
/*
 * Attach caller private data to the client.
 *
 * @client: the client to update
 * @opaque: the new private data pointer
 * @ff: function used to release @opaque later, may be NULL
 *
 * If private data with a free function is already attached, that
 * free function is run on the old data before it is replaced.
 */
void
virNetServerClientSetPrivateData(virNetServerClientPtr client,
                                 void *opaque,
                                 virNetServerClientFreeFunc ff)
{
    virNetServerClientLock(client);

    if (client->privateData != NULL &&
        client->privateDataFreeFunc != NULL) {
        client->privateDataFreeFunc(client->privateData);
    }

    client->privateDataFreeFunc = ff;
    client->privateData = opaque;

    virNetServerClientUnlock(client);
}
/* Return the private data pointer attached to the client, if any. */
void *
virNetServerClientGetPrivateData(virNetServerClientPtr client)
{
    void *priv;

    virNetServerClientLock(client);
    priv = client->privateData;
    virNetServerClientUnlock(client);

    return priv;
}
/*
 * Install the handler that receives every complete incoming message
 * not claimed by a filter.
 *
 * @client: the client to update
 * @func: the dispatch handler
 * @opaque: pointer passed through to @func
 */
void
virNetServerClientSetDispatcher(virNetServerClientPtr client,
                                virNetServerClientDispatchFunc func,
                                void *opaque)
{
    virNetServerClientLock(client);
    client->dispatchOpaque = opaque;
    client->dispatchFunc = func;
    virNetServerClientUnlock(client);
}
/*
 * Return the local address string of the client's socket, or NULL
 * when the client has already been closed (virNetServerClientClose()
 * sets the socket pointer to NULL while the client object lives on).
 */
const char *
virNetServerClientLocalAddrString(virNetServerClientPtr client)
{
    if (!client->sock)
        return NULL;

    return virNetSocketLocalAddrString(client->sock);
}
/*
 * Return the remote address string of the client's socket, or NULL
 * when the client has already been closed (virNetServerClientClose()
 * sets the socket pointer to NULL while the client object lives on).
 */
const char *
virNetServerClientRemoteAddrString(virNetServerClientPtr client)
{
    if (!client->sock)
        return NULL;

    return virNetSocketRemoteAddrString(client->sock);
}
/*
 * Drop one reference on the client and release it when that was the
 * last one. Passing NULL is a no-op.
 *
 * On final release this runs the private data free function (if
 * set), frees the identity string, any messages still sitting in the
 * rx/tx queues, any filters still registered, the SASL/TLS state and
 * the socket, then destroys the mutex and the object itself.
 */
void
virNetServerClientFree(virNetServerClientPtr client)
{
    virNetServerClientFilterPtr filter;

    if (!client)
        return;

    virNetServerClientLock(client);
    VIR_DEBUG("client=%p refs=%d", client, client->refs);

    client->refs--;
    if (client->refs > 0) {
        virNetServerClientUnlock(client);
        return;
    }

    if (client->privateData &&
        client->privateDataFreeFunc)
        client->privateDataFreeFunc(client->privateData);

    VIR_FREE(client->identity);

    /* The queues are normally emptied by virNetServerClientClose(),
     * but nothing forces callers to close before the final free, so
     * release whatever is still queued rather than leaking it. */
    while (client->rx)
        virNetMessageFree(virNetMessageQueueServe(&client->rx));
    while (client->tx)
        virNetMessageFree(virNetMessageQueueServe(&client->tx));

    /* Filter entries are allocated by virNetServerClientAddFilter()
     * and otherwise only released by an explicit RemoveFilter call. */
    while ((filter = client->filters) != NULL) {
        client->filters = filter->next;
        VIR_FREE(filter);
    }

#if HAVE_SASL
    virNetSASLSessionFree(client->sasl);
#endif
    virNetTLSSessionFree(client->tls);
    virNetTLSContextFree(client->tlsCtxt);
    virNetSocketFree(client->sock);
    virNetServerClientUnlock(client);
    virMutexDestroy(&client->lock);
    VIR_FREE(client);
}
/*
 * Disconnect the client's network socket and drop the resources tied
 * to it: the I/O callback, the TLS session, the socket itself and all
 * messages waiting in the rx and tx queues.
 *
 * The client object is not freed here; that happens later, at a
 * point where it is guaranteed to be no longer in use. Calling this
 * on an already closed client does nothing.
 */
void
virNetServerClientClose(virNetServerClientPtr client)
{
    virNetServerClientLock(client);
    VIR_DEBUG("client=%p refs=%d", client, client->refs);

    if (client->sock) {
        /* Do now, even though we don't close the socket
         * until end, to ensure we don't get invoked
         * again due to tls shutdown */
        virNetSocketRemoveIOCallback(client->sock);

        if (client->tls) {
            virNetTLSSessionFree(client->tls);
            client->tls = NULL;
        }

        virNetSocketFree(client->sock);
        client->sock = NULL;

        while (client->rx) {
            virNetMessagePtr pending = virNetMessageQueueServe(&client->rx);
            virNetMessageFree(pending);
        }

        while (client->tx) {
            virNetMessagePtr pending = virNetMessageQueueServe(&client->tx);
            virNetMessageFree(pending);
        }
    }

    virNetServerClientUnlock(client);
}
/* Report whether the client's socket has already been closed. */
bool
virNetServerClientIsClosed(virNetServerClientPtr client)
{
    bool gone;

    virNetServerClientLock(client);
    gone = !client->sock;
    virNetServerClientUnlock(client);

    return gone;
}
/* Flag the client as wanting to be closed. */
void
virNetServerClientMarkClose(virNetServerClientPtr client)
{
    virNetServerClientLock(client);
    client->wantClose = true;
    virNetServerClientUnlock(client);
}
/* Report whether the client has been flagged as wanting to be closed. */
bool
virNetServerClientWantClose(virNetServerClientPtr client)
{
    bool flagged;

    virNetServerClientLock(client);
    flagged = client->wantClose;
    virNetServerClientUnlock(client);

    return flagged;
}
/*
 * Start servicing a freshly created client.
 *
 * For a plain socket this just registers the I/O callback. With a
 * TLS context, a TLS session is created and attached to the socket
 * and the handshake is started; if it completes immediately the
 * client certificate is checked as well. In every successful case
 * the I/O callback ends up registered.
 *
 * Returns 0 on success. On any failure the client is marked as
 * wanting close and -1 is returned.
 */
int
virNetServerClientInit(virNetServerClientPtr client)
{
    virNetServerClientLock(client);

    if (client->tlsCtxt) {
        int hs;

        client->tls = virNetTLSSessionNew(client->tlsCtxt, NULL);
        if (!client->tls)
            goto error;

        virNetSocketSetTLSSession(client->sock, client->tls);

        /* Begin the TLS handshake. */
        hs = virNetTLSSessionHandshake(client->tls);
        if (hs < 0)
            goto error;

        /* Unlikely to be done already, but if so the next step is
         * to check the certificate. A positive result means more
         * handshake data has to be exchanged first. */
        if (hs == 0 &&
            virNetServerClientCheckAccess(client) < 0)
            goto error;
    }

    /* Plain socket, finished handshake or handshake in progress:
     * all of them need the event callback from here on */
    if (virNetServerClientRegisterEvent(client) < 0)
        goto error;

    virNetServerClientUnlock(client);
    return 0;

error:
    client->wantClose = true;
    virNetServerClientUnlock(client);
    return -1;
}
/*
 * Read data into the current rx buffer using wire decoding
 * (plain or TLS), advancing the buffer offset by the amount read.
 *
 * If the buffer has no room left (length <= offset) an error is
 * reported and the client is marked as wanting close.
 *
 * Returns:
 *  -1 on error or EOF
 *   0 on EAGAIN
 *   n number of bytes
 */
static ssize_t
virNetServerClientRead(virNetServerClientPtr client)
{
    ssize_t got;

    if (client->rx->bufferOffset >= client->rx->bufferLength) {
        virNetError(VIR_ERR_RPC,
                    _("unexpected zero/negative length request %lld"),
                    (long long int)(client->rx->bufferLength - client->rx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    got = virNetSocketRead(client->sock,
                           client->rx->buffer + client->rx->bufferOffset,
                           client->rx->bufferLength - client->rx->bufferOffset);

    if (got > 0)
        client->rx->bufferOffset += got;

    return got;
}
/*
 * @client: a locked client object
 *
 * Read data until we get a complete message to process, then route
 * it: first through the registered filters, and if none claims it,
 * to the client's dispatch function. Afterwards a fresh receive
 * buffer is set up unless the request limit has been reached.
 *
 * Any failure marks the client as wanting close.
 */
static void
virNetServerClientDispatchRead(virNetServerClientPtr client)
{
    virNetMessagePtr msg;
    virNetServerClientFilterPtr filter;

    for (;;) {
        if (virNetServerClientRead(client) < 0) {
            client->wantClose = true;
            return; /* Error */
        }

        if (client->rx->bufferOffset < client->rx->bufferLength)
            return; /* Still not read enough */

        /* Anything other than the length word means the payload
         * has now been read in full */
        if (client->rx->bufferLength != VIR_NET_MESSAGE_LEN_MAX)
            break;

        /* Done with the length word header. A bad length leaves the
         * buffer unusable, so the connection has to be dropped
         * instead of being left registered with a broken buffer. */
        if (virNetMessageDecodeLength(client->rx) < 0) {
            client->wantClose = true;
            return;
        }

        virNetServerClientUpdateEvent(client);

        /* Try and read payload immediately instead of going back
           into poll() because chances are the data is already
           waiting for us */
    }

    /* Grab the completed message */
    msg = virNetMessageQueueServe(&client->rx);

    /* Decode the header so we can use it for routing decisions */
    if (virNetMessageDecodeHeader(msg) < 0) {
        virNetMessageFree(msg);
        client->wantClose = true;
        return;
    }

    /* Maybe send off for queue against a filter. A negative result
     * is a fatal error, a positive one means the filter took
     * ownership of the message, zero means "not mine". */
    filter = client->filters;
    while (filter) {
        int ret = filter->func(client, msg, filter->opaque);

        if (ret < 0) {
            virNetMessageFree(msg);
            msg = NULL;
            client->wantClose = true;
            break;
        }
        if (ret > 0) {
            msg = NULL;
            break;
        }

        filter = filter->next;
    }

    /* Send off to for normal dispatch to workers */
    if (msg) {
        if (!client->dispatchFunc ||
            client->dispatchFunc(client, msg, client->dispatchOpaque) < 0) {
            virNetMessageFree(msg);
            client->wantClose = true;
            return;
        }
    }

    /* Possibly need to create another receive buffer */
    if (client->nrequests < client->nrequests_max) {
        if (!(client->rx = virNetMessageNew())) {
            /* Must not touch client->rx here: it is NULL */
            client->wantClose = true;
        } else {
            client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
            client->nrequests++;
        }
    }

    virNetServerClientUpdateEvent(client);
}
/*
 * Send the unsent part of the message at the head of client->tx
 * using no encoding, advancing its buffer offset by the amount sent.
 *
 * An offset beyond the buffer length is reported as an error and
 * marks the client as wanting close. A message with nothing left to
 * send yields 1 without touching the socket.
 *
 * Returns:
 *  -1 on error or EOF
 *   0 on EAGAIN
 *   n number of bytes
 */
static ssize_t
virNetServerClientWrite(virNetServerClientPtr client)
{
    ssize_t sent;

    if (client->tx->bufferOffset > client->tx->bufferLength) {
        virNetError(VIR_ERR_RPC,
                    _("unexpected zero/negative length request %lld"),
                    (long long int)(client->tx->bufferLength - client->tx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    if (client->tx->bufferOffset == client->tx->bufferLength)
        return 1;

    sent = virNetSocketWrite(client->sock,
                             client->tx->buffer + client->tx->bufferOffset,
                             client->tx->bufferLength - client->tx->bufferOffset);

    /* -1 error, 0 = egain: leave the offset alone */
    if (sent > 0)
        client->tx->bufferOffset += sent;

    return sent;
}
/*
 * @client: a locked client object
 *
 * Process all queued client->tx messages until
 * we would block on I/O.
 *
 * Each fully sent message is removed from the queue. A finished
 * reply also lowers the in-flight request count and, when receiving
 * was throttled, is recycled as the next receive buffer. A write
 * error marks the client as wanting close.
 */
static void
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
    while (client->tx) {
        ssize_t ret;

        ret = virNetServerClientWrite(client);
        if (ret < 0) {
            client->wantClose = true;
            return;
        }
        if (ret == 0)
            return; /* Would block on write EAGAIN */

        if (client->tx->bufferOffset == client->tx->bufferLength) {
            virNetMessagePtr msg;
#if HAVE_SASL
            /* Completed this 'tx' operation, so now read for all
             * future rx/tx to be under a SASL SSF layer
             */
            if (client->sasl) {
                virNetSocketSetSASLSession(client->sock, client->sasl);
                virNetSASLSessionFree(client->sasl);
                client->sasl = NULL;
            }
#endif

            /* Get finished msg from head of tx queue */
            msg = virNetMessageQueueServe(&client->tx);

            if (msg->header.type == VIR_NET_REPLY) {
                client->nrequests--;
                /* See if the recv queue is currently throttled */
                if (!client->rx &&
                    client->nrequests < client->nrequests_max) {
                    /* Ready to recv more messages. The recycled
                     * message was just sent in full, so its offset
                     * equals its old length; it has to go back to
                     * zero or the next read sees "length <= offset"
                     * and drops the connection. */
                    client->rx = msg;
                    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                    client->rx->bufferOffset = 0;
                    msg = NULL;
                    client->nrequests++;
                }
            }

            virNetMessageFree(msg);

            virNetServerClientUpdateEvent(client);
        }
    }
}
/*
 * @client: a locked client object
 *
 * Drive the TLS handshake one step further. Once it finishes the
 * client certificate is checked. A fatal handshake error or a failed
 * access check marks the client as wanting close; otherwise the
 * socket event mask is refreshed.
 */
static void
virNetServerClientDispatchHandshake(virNetServerClientPtr client)
{
    /* Continue the handshake. */
    int hs = virNetTLSSessionHandshake(client->tls);

    if (hs < 0) {
        /* Fatal error in handshake */
        client->wantClose = true;
        return;
    }

    /* Finished. Next step is to check the certificate. */
    if (hs == 0 &&
        virNetServerClientCheckAccess(client) < 0) {
        client->wantClose = true;
        return;
    }

    /* Either done, or carrying on waiting for more handshake.
       Update the events just in case handshake data flow
       direction has changed */
    virNetServerClientUpdateEvent(client);
}
/*
 * I/O callback installed on the client's socket.
 *
 * @sock: the socket the events were raised on
 * @events: bitmask of VIR_EVENT_HANDLE_* flags
 * @opaque: the virNetServerClientPtr registered with the callback
 *
 * If @sock is no longer the client's socket the callback is removed
 * from it and nothing else happens. Otherwise readable/writable
 * events are fed to the TLS handshake while it is incomplete, or to
 * the write and read dispatchers afterwards. Error and hangup events
 * mark the client as wanting close.
 */
static void
virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque)
{
    virNetServerClientPtr client = opaque;
    const int ioMask = VIR_EVENT_HANDLE_WRITABLE | VIR_EVENT_HANDLE_READABLE;
    const int failMask = VIR_EVENT_HANDLE_ERROR | VIR_EVENT_HANDLE_HANGUP;

    virNetServerClientLock(client);

    if (sock != client->sock) {
        virNetSocketRemoveIOCallback(sock);
        virNetServerClientUnlock(client);
        return;
    }

    if (events & ioMask) {
        if (client->tls &&
            virNetTLSSessionGetHandshakeStatus(client->tls) !=
            VIR_NET_TLS_HANDSHAKE_COMPLETE) {
            virNetServerClientDispatchHandshake(client);
        } else {
            if (events & VIR_EVENT_HANDLE_WRITABLE)
                virNetServerClientDispatchWrite(client);
            if (events & VIR_EVENT_HANDLE_READABLE)
                virNetServerClientDispatchRead(client);
        }
    }

    /* NB, will get HANGUP + READABLE at same time upon
     * disconnect */
    if (events & failMask)
        client->wantClose = true;

    virNetServerClientUnlock(client);
}
/*
 * Queue a message for transmission to the client.
 *
 * @client: the client to send to
 * @msg: the message to append to the tx queue
 *
 * Returns 0 if the message was queued, or -1 (leaving @msg
 * untouched) when the client has no socket any more or is marked
 * as wanting close.
 */
int
virNetServerClientSendMessage(virNetServerClientPtr client,
                              virNetMessagePtr msg)
{
    int rc;

    VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu",
              msg, msg->header.proc,
              msg->bufferLength, msg->bufferOffset);

    virNetServerClientLock(client);

    if (!client->sock || client->wantClose) {
        rc = -1;
    } else {
        virNetMessageQueuePush(&client->tx, msg);
        virNetServerClientUpdateEvent(client);
        rc = 0;
    }

    virNetServerClientUnlock(client);
    return rc;
}
/*
 * Report whether the client still has to authenticate: true when an
 * authentication setting is configured but no identity has been
 * recorded yet.
 */
bool
virNetServerClientNeedAuth(virNetServerClientPtr client)
{
    bool pending;

    virNetServerClientLock(client);
    pending = (client->auth && !client->identity);
    virNetServerClientUnlock(client);

    return pending;
}