/* * virnetserverprogram.c: generic network RPC server program * * Copyright (C) 2006-2012 Red Hat, Inc. * Copyright (C) 2006 Daniel P. Berrange * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library. If not, see * . */ #include #include "virnetserverprogram.h" #include "virnetserverclient.h" #include "viralloc.h" #include "virerror.h" #include "virlog.h" #include "virfile.h" #include "virthread.h" #define VIR_FROM_THIS VIR_FROM_RPC VIR_LOG_INIT("rpc.netserverprogram"); struct _virNetServerProgram { virObject parent; unsigned program; unsigned version; virNetServerProgramProc *procs; size_t nprocs; }; static virClass *virNetServerProgramClass; static void virNetServerProgramDispose(void *obj); static int virNetServerProgramOnceInit(void) { if (!VIR_CLASS_NEW(virNetServerProgram, virClassForObject())) return -1; return 0; } VIR_ONCE_GLOBAL_INIT(virNetServerProgram); virNetServerProgram *virNetServerProgramNew(unsigned program, unsigned version, virNetServerProgramProc *procs, size_t nprocs) { virNetServerProgram *prog; if (virNetServerProgramInitialize() < 0) return NULL; if (!(prog = virObjectNew(virNetServerProgramClass))) return NULL; prog->program = program; prog->version = version; prog->procs = procs; prog->nprocs = nprocs; VIR_DEBUG("prog=%p", prog); return prog; } int virNetServerProgramGetID(virNetServerProgram *prog) { return prog->program; } int virNetServerProgramGetVersion(virNetServerProgram *prog) { return prog->version; } int virNetServerProgramMatches(virNetServerProgram *prog, virNetMessage *msg) { if (prog->program == msg->header.prog && prog->version == msg->header.vers) return 1; return 0; } static virNetServerProgramProc *virNetServerProgramGetProc(virNetServerProgram *prog, int procedure) { virNetServerProgramProc *proc; if (procedure < 0) return NULL; if (procedure >= prog->nprocs) return NULL; proc = &prog->procs[procedure]; if (!proc->func) return NULL; return proc; } unsigned int virNetServerProgramGetPriority(virNetServerProgram *prog, int procedure) { virNetServerProgramProc *proc = virNetServerProgramGetProc(prog, procedure); if (!proc) return 0; return proc->priority; } static int virNetServerProgramSendError(unsigned program, unsigned version, virNetServerClient *client, virNetMessage *msg, struct virNetMessageError *rerr, int procedure, int type, unsigned int serial) { VIR_DEBUG("prog=%d ver=%d proc=%d type=%d serial=%u msg=%p rerr=%p", program, version, procedure, type, serial, msg, rerr); virNetMessageSaveError(rerr); /* Return header. */ msg->header.prog = program; msg->header.vers = version; msg->header.proc = procedure; msg->header.type = type; msg->header.serial = serial; msg->header.status = VIR_NET_ERROR; if (virNetMessageEncodeHeader(msg) < 0) goto error; if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetMessageError, rerr) < 0) goto error; xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr); /* Put reply on end of tx queue to send out */ if (virNetServerClientSendMessage(client, msg) < 0) return -1; return 0; error: VIR_WARN("Failed to serialize remote error '%p'", rerr); xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr); return -1; } /* * @client: the client to send the error to * @req: the message this error is in reply to * * Send an error message to the client * * Returns 0 if the error was sent, -1 upon fatal error */ int virNetServerProgramSendReplyError(virNetServerProgram *prog, virNetServerClient *client, virNetMessage *msg, struct virNetMessageError *rerr, struct virNetMessageHeader *req) { /* * For data streams, errors are sent back as data streams * For method calls, errors are sent back as method replies */ return virNetServerProgramSendError(prog->program, prog->version, client, msg, rerr, req->proc, req->type == VIR_NET_STREAM ? VIR_NET_STREAM : VIR_NET_REPLY, req->serial); } int virNetServerProgramSendStreamError(virNetServerProgram *prog, virNetServerClient *client, virNetMessage *msg, struct virNetMessageError *rerr, int procedure, unsigned int serial) { return virNetServerProgramSendError(prog->program, prog->version, client, msg, rerr, procedure, VIR_NET_STREAM, serial); } int virNetServerProgramUnknownError(virNetServerClient *client, virNetMessage *msg, struct virNetMessageHeader *req) { virNetMessageError rerr; virReportError(VIR_ERR_RPC, _("Cannot find program %d version %d"), req->prog, req->vers); memset(&rerr, 0, sizeof(rerr)); return virNetServerProgramSendError(req->prog, req->vers, client, msg, &rerr, req->proc, VIR_NET_REPLY, req->serial); } static int virNetServerProgramDispatchCall(virNetServerProgram *prog, virNetServer *server, virNetServerClient *client, virNetMessage *msg); /* * @server: the unlocked server object * @client: the unlocked client object * @msg: the complete incoming message packet, with header already decoded * * This function is intended to be called from worker threads * when an incoming message is ready to be dispatched for * execution. * * Upon successful return the '@msg' instance will be released * by this function (or more often, reused to send a reply). * Upon failure, the '@msg' must be freed by the caller. * * Returns 0 if the message was dispatched, -1 upon fatal error */ int virNetServerProgramDispatch(virNetServerProgram *prog, virNetServer *server, virNetServerClient *client, virNetMessage *msg) { int ret = -1; virNetMessageError rerr; memset(&rerr, 0, sizeof(rerr)); VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%u proc=%d", msg->header.prog, msg->header.vers, msg->header.type, msg->header.status, msg->header.serial, msg->header.proc); /* Check version, etc. */ if (msg->header.prog != prog->program) { virReportError(VIR_ERR_RPC, _("program mismatch (actual %x, expected %x)"), msg->header.prog, prog->program); goto error; } if (msg->header.vers != prog->version) { virReportError(VIR_ERR_RPC, _("version mismatch (actual %x, expected %x)"), msg->header.vers, prog->version); goto error; } switch (msg->header.type) { case VIR_NET_CALL: case VIR_NET_CALL_WITH_FDS: ret = virNetServerProgramDispatchCall(prog, server, client, msg); break; case VIR_NET_STREAM: /* Since stream data is non-acked, async, we may continue to receive * stream packets after we closed down a stream. Just drop & ignore * these. */ VIR_INFO("Ignoring unexpected stream data serial=%u proc=%d status=%d", msg->header.serial, msg->header.proc, msg->header.status); /* Send a dummy reply to free up 'msg' & unblock client rx */ virNetMessageClear(msg); msg->header.type = VIR_NET_REPLY; if (virNetServerClientSendMessage(client, msg) < 0) return -1; ret = 0; break; case VIR_NET_REPLY: case VIR_NET_REPLY_WITH_FDS: case VIR_NET_MESSAGE: case VIR_NET_STREAM_HOLE: default: virReportError(VIR_ERR_RPC, _("Unexpected message type %u"), msg->header.type); goto error; } return ret; error: if (msg->header.type == VIR_NET_CALL || msg->header.type == VIR_NET_CALL_WITH_FDS) { ret = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header); } else { /* Send a dummy reply to free up 'msg' & unblock client rx */ virNetMessageClear(msg); msg->header.type = VIR_NET_REPLY; if (virNetServerClientSendMessage(client, msg) < 0) return -1; ret = 0; } return ret; } /* * @server: the unlocked server object * @client: the unlocked client object * @msg: the complete incoming method call, with header already decoded * * This method is used to dispatch a message representing an * incoming method call from a client. It decodes the payload * to obtain method call arguments, invokes the method and * then sends a reply packet with the return values * * Returns 0 if the reply was sent, or -1 upon fatal error */ static int virNetServerProgramDispatchCall(virNetServerProgram *prog, virNetServer *server, virNetServerClient *client, virNetMessage *msg) { g_autofree char *arg = NULL; g_autofree char *ret = NULL; int rv = -1; virNetServerProgramProc *dispatcher; virNetMessageError rerr; size_t i; g_autoptr(virIdentity) identity = NULL; memset(&rerr, 0, sizeof(rerr)); if (msg->header.status != VIR_NET_OK) { virReportError(VIR_ERR_RPC, _("Unexpected message status %u"), msg->header.status); goto error; } dispatcher = virNetServerProgramGetProc(prog, msg->header.proc); if (!dispatcher) { virReportError(VIR_ERR_RPC, _("unknown procedure: %d"), msg->header.proc); goto error; } /* If the client is not authenticated, don't allow any RPC ops * which are except for authentication ones */ if (dispatcher->needAuth && !virNetServerClientIsAuthenticated(client)) { /* Explicitly *NOT* calling remoteDispatchAuthError() because we want back-compatibility with libvirt clients which don't support the VIR_ERR_AUTH_FAILED error code */ virReportError(VIR_ERR_RPC, "%s", _("authentication required")); goto error; } arg = g_new0(char, dispatcher->arg_len); ret = g_new0(char, dispatcher->ret_len); if (virNetMessageDecodePayload(msg, dispatcher->arg_filter, arg) < 0) goto error; if (!(identity = virNetServerClientGetIdentity(client))) goto error; if (virIdentitySetCurrent(identity) < 0) goto error; /* * When the RPC handler is called: * * - Server object is unlocked * - Client object is unlocked * * Without locking, it is safe to use: * * 'args and 'ret' */ rv = (dispatcher->func)(server, client, msg, &rerr, arg, ret); if (virIdentitySetCurrent(NULL) < 0) goto error; /* * If rv == 1, this indicates the dispatch func has * populated 'msg' with a list of FDs to return to * the caller. * * Otherwise we must clear out the FDs we got from * the client originally. * */ if (rv != 1) { for (i = 0; i < msg->nfds; i++) VIR_FORCE_CLOSE(msg->fds[i]); VIR_FREE(msg->fds); msg->nfds = 0; } xdr_free(dispatcher->arg_filter, arg); if (rv < 0) goto error; /* Return header. We're re-using same message object, so * only need to tweak type/status fields */ /*msg->header.prog = msg->header.prog;*/ /*msg->header.vers = msg->header.vers;*/ /*msg->header.proc = msg->header.proc;*/ msg->header.type = msg->nfds ? VIR_NET_REPLY_WITH_FDS : VIR_NET_REPLY; /*msg->header.serial = msg->header.serial;*/ msg->header.status = VIR_NET_OK; if (virNetMessageEncodeHeader(msg) < 0) { xdr_free(dispatcher->ret_filter, ret); goto error; } if (msg->nfds && virNetMessageEncodeNumFDs(msg) < 0) { xdr_free(dispatcher->ret_filter, ret); goto error; } if (virNetMessageEncodePayload(msg, dispatcher->ret_filter, ret) < 0) { xdr_free(dispatcher->ret_filter, ret); goto error; } xdr_free(dispatcher->ret_filter, ret); /* Put reply on end of tx queue to send out */ return virNetServerClientSendMessage(client, msg); error: /* Bad stuff (de-)serializing message, but we have an * RPC error message we can send back to the client */ rv = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header); return rv; } int virNetServerProgramSendStreamData(virNetServerProgram *prog, virNetServerClient *client, virNetMessage *msg, int procedure, unsigned int serial, const char *data, size_t len) { VIR_DEBUG("client=%p msg=%p data=%p len=%zu", client, msg, data, len); /* Return header. We're reusing same message object, so * only need to tweak type/status fields */ msg->header.prog = prog->program; msg->header.vers = prog->version; msg->header.proc = procedure; msg->header.type = VIR_NET_STREAM; msg->header.serial = serial; /* * NB * data != NULL + len > 0 => VIR_NET_CONTINUE (Sending back data) * data != NULL + len == 0 => VIR_NET_CONTINUE (Sending read EOF) * data == NULL => VIR_NET_OK (Sending finish handshake confirmation) */ msg->header.status = data ? VIR_NET_CONTINUE : VIR_NET_OK; if (virNetMessageEncodeHeader(msg) < 0) return -1; if (data && len) { if (virNetMessageEncodePayloadRaw(msg, data, len) < 0) return -1; } else { if (virNetMessageEncodePayloadEmpty(msg) < 0) return -1; } VIR_DEBUG("Total %zu", msg->bufferLength); return virNetServerClientSendMessage(client, msg); } int virNetServerProgramSendStreamHole(virNetServerProgram *prog, virNetServerClient *client, virNetMessage *msg, int procedure, unsigned int serial, long long length, unsigned int flags) { virNetStreamHole data; VIR_DEBUG("client=%p msg=%p length=%lld", client, msg, length); memset(&data, 0, sizeof(data)); data.length = length; data.flags = flags; msg->header.prog = prog->program; msg->header.vers = prog->version; msg->header.proc = procedure; msg->header.type = VIR_NET_STREAM_HOLE; msg->header.serial = serial; msg->header.status = VIR_NET_CONTINUE; if (virNetMessageEncodeHeader(msg) < 0) return -1; if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetStreamHole, &data) < 0) return -1; return virNetServerClientSendMessage(client, msg); } void virNetServerProgramDispose(void *obj G_GNUC_UNUSED) { }