libvirt/src/rpc/virnetserverservice.c
Marc Hartmayer d6bc7622f0 rpc: Fix potentially segfaults
We have to allocate first and, if and only if the allocation was
successful, set the count. A segfault occurred in
virNetServerServiceNewPostExecRestart() when VIR_ALLOC_N(svc->socks,
n) failed but svc->nsocks = n had already been set. In that case
virObjectUnref(svc) was called, which could invoke
virNetServerServiceDispose() and thus cause a segmentation fault. For
safety, NULL pointer checks were added in
virNetServerServiceDispose().

Signed-off-by: Marc Hartmayer <mhartmay@linux.vnet.ibm.com>
Reviewed-by: Boris Fiuczynski <fiuczy@linux.vnet.ibm.com>
Reviewed-by: Bjoern Walk <bwalk@linux.vnet.ibm.com>
2017-02-12 15:02:42 -05:00

529 lines
15 KiB
C

/*
* virnetserverservice.c: generic network RPC server service
*
* Copyright (C) 2006-2012, 2014 Red Hat, Inc.
* Copyright (C) 2006 Daniel P. Berrange
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library. If not, see
* <http://www.gnu.org/licenses/>.
*
* Author: Daniel P. Berrange <berrange@redhat.com>
*/
#include <config.h>
#include "virnetserverservice.h"
#include <unistd.h>
#include "viralloc.h"
#include "virerror.h"
#include "virthread.h"
#define VIR_FROM_THIS VIR_FROM_RPC
/*
 * Private state of a network RPC server service: a set of listening
 * sockets that share one auth/readonly/TLS configuration and a single
 * dispatch callback for accepted clients.
 */
struct _virNetServerService {
    /* Base object header; instances are created with virObjectNew() on a
     * class registered with sizeof(virNetServerService).
     * NOTE(review): presumably must remain the first member for virObject
     * casting — confirm before reordering. */
    virObject object;

    /* Number of entries in @socks. Constructors set it only after @socks
     * has been allocated so the destructor never walks a missing array. */
    size_t nsocks;
    /* Listening sockets. Each one has an IO callback registered that holds
     * its own reference on this service. */
    virNetSocketPtr *socks;

    /* Auth scheme supplied at construction; see virNetServerServiceGetAuth(). */
    int auth;
    /* Readonly flag supplied at construction; see virNetServerServiceIsReadonly(). */
    bool readonly;
    /* Value reported by virNetServerServiceGetMaxRequests(). */
    size_t nrequests_client_max;
#if WITH_GNUTLS
    /* TLS context; a reference is taken at construction and dropped in
     * virNetServerServiceDispose(). */
    virNetTLSContextPtr tls;
#endif
    /* Called for each accepted client socket; may be NULL, in which case
     * accepted clients are simply released. */
    virNetServerServiceDispatchFunc dispatchFunc;
    /* Opaque pointer passed back to dispatchFunc. */
    void *dispatchOpaque;
};
static virClassPtr virNetServerServiceClass;
static void virNetServerServiceDispose(void *obj);

/*
 * Register the virNetServerService object class, derived from the base
 * virObject class, with virNetServerServiceDispose() as its destructor.
 *
 * Returns 0 once the class exists, -1 if it could not be created.
 */
static int virNetServerServiceOnceInit(void)
{
    virNetServerServiceClass = virClassNew(virClassForObject(),
                                           "virNetServerService",
                                           sizeof(virNetServerService),
                                           virNetServerServiceDispose);

    return virNetServerServiceClass != NULL ? 0 : -1;
}

/* Presumably expands to virNetServerServiceInitialize(), which every
 * constructor in this file calls to run the init above exactly once. */
VIR_ONCE_GLOBAL_INIT(virNetServerService)
/*
 * IO callback for a listening socket.
 *
 * Accepts one pending connection on @sock and hands it to the service's
 * dispatcher, if one has been set. The service's own reference on the
 * accepted socket is always released before returning; the dispatcher
 * must take its own reference if it wants to keep the client.
 */
static void virNetServerServiceAccept(virNetSocketPtr sock,
                                      int events ATTRIBUTE_UNUSED,
                                      void *opaque)
{
    virNetServerServicePtr service = opaque;
    virNetSocketPtr client = NULL;

    /* A successful accept can still yield no client when the connection
     * already went away; nothing is dispatched in that case. */
    if (virNetSocketAccept(sock, &client) >= 0 &&
        client != NULL &&
        service->dispatchFunc != NULL)
        service->dispatchFunc(service, client, service->dispatchOpaque);

    virObjectUnref(client);
}
/*
 * Create a service either on an inherited file descriptor or on a new
 * UNIX socket.
 *
 * While *@cur_fd - STDERR_FILENO does not exceed @nfds, *@cur_fd is taken
 * as the next inherited descriptor: the service is built on it via
 * virNetServerServiceNewFD() and *@cur_fd is advanced by one. Once the
 * inherited descriptors are used up, a UNIX socket is created at @path
 * (with @mask and @grp) via virNetServerServiceNewUNIX() instead, and
 * *@cur_fd is left untouched.
 *
 * Returns the new service, or NULL on failure.
 */
virNetServerServicePtr
virNetServerServiceNewFDOrUNIX(const char *path,
                               mode_t mask,
                               gid_t grp,
                               int auth,
#if WITH_GNUTLS
                               virNetTLSContextPtr tls,
#endif
                               bool readonly,
                               size_t max_queued_clients,
                               size_t nrequests_client_max,
                               unsigned int nfds,
                               unsigned int *cur_fd)
{
    if (*cur_fd - STDERR_FILENO <= nfds) {
        /* Consume the current descriptor, then step past it. */
        int fd = (*cur_fd)++;

        return virNetServerServiceNewFD(fd,
                                        auth,
#if WITH_GNUTLS
                                        tls,
#endif
                                        readonly,
                                        max_queued_clients,
                                        nrequests_client_max);
    }

    /* No inherited descriptor left: fall back to a UNIX socket. */
    return virNetServerServiceNewUNIX(path,
                                      mask,
                                      grp,
                                      auth,
#if WITH_GNUTLS
                                      tls,
#endif
                                      readonly,
                                      max_queued_clients,
                                      nrequests_client_max);
}
virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
const char *service,
int family,
int auth,
#if WITH_GNUTLS
virNetTLSContextPtr tls,
#endif
bool readonly,
size_t max_queued_clients,
size_t nrequests_client_max)
{
virNetServerServicePtr svc;
size_t i;
if (virNetServerServiceInitialize() < 0)
return NULL;
if (!(svc = virObjectNew(virNetServerServiceClass)))
return NULL;
svc->auth = auth;
svc->readonly = readonly;
svc->nrequests_client_max = nrequests_client_max;
#if WITH_GNUTLS
svc->tls = virObjectRef(tls);
#endif
if (virNetSocketNewListenTCP(nodename,
service,
family,
&svc->socks,
&svc->nsocks) < 0)
goto error;
for (i = 0; i < svc->nsocks; i++) {
if (virNetSocketListen(svc->socks[i], max_queued_clients) < 0)
goto error;
/* IO callback is initially disabled, until we're ready
* to deal with incoming clients */
virObjectRef(svc);
if (virNetSocketAddIOCallback(svc->socks[i],
0,
virNetServerServiceAccept,
svc,
virObjectFreeCallback) < 0) {
virObjectUnref(svc);
goto error;
}
}
return svc;
error:
virObjectUnref(svc);
return NULL;
}
virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
mode_t mask,
gid_t grp,
int auth,
#if WITH_GNUTLS
virNetTLSContextPtr tls,
#endif
bool readonly,
size_t max_queued_clients,
size_t nrequests_client_max)
{
virNetServerServicePtr svc;
size_t i;
if (virNetServerServiceInitialize() < 0)
return NULL;
if (!(svc = virObjectNew(virNetServerServiceClass)))
return NULL;
svc->auth = auth;
svc->readonly = readonly;
svc->nrequests_client_max = nrequests_client_max;
#if WITH_GNUTLS
svc->tls = virObjectRef(tls);
#endif
if (VIR_ALLOC_N(svc->socks, 1) < 0)
goto error;
svc->nsocks = 1;
if (virNetSocketNewListenUNIX(path,
mask,
-1,
grp,
&svc->socks[0]) < 0)
goto error;
for (i = 0; i < svc->nsocks; i++) {
if (virNetSocketListen(svc->socks[i], max_queued_clients) < 0)
goto error;
/* IO callback is initially disabled, until we're ready
* to deal with incoming clients */
virObjectRef(svc);
if (virNetSocketAddIOCallback(svc->socks[i],
0,
virNetServerServiceAccept,
svc,
virObjectFreeCallback) < 0) {
virObjectUnref(svc);
goto error;
}
}
return svc;
error:
virObjectUnref(svc);
return NULL;
}
virNetServerServicePtr virNetServerServiceNewFD(int fd,
int auth,
#if WITH_GNUTLS
virNetTLSContextPtr tls,
#endif
bool readonly,
size_t max_queued_clients,
size_t nrequests_client_max)
{
virNetServerServicePtr svc;
size_t i;
if (virNetServerServiceInitialize() < 0)
return NULL;
if (!(svc = virObjectNew(virNetServerServiceClass)))
return NULL;
svc->auth = auth;
svc->readonly = readonly;
svc->nrequests_client_max = nrequests_client_max;
#if WITH_GNUTLS
svc->tls = virObjectRef(tls);
#endif
if (VIR_ALLOC_N(svc->socks, 1) < 0)
goto error;
svc->nsocks = 1;
if (virNetSocketNewListenFD(fd,
&svc->socks[0]) < 0)
goto error;
for (i = 0; i < svc->nsocks; i++) {
if (virNetSocketListen(svc->socks[i], max_queued_clients) < 0)
goto error;
/* IO callback is initially disabled, until we're ready
* to deal with incoming clients */
virObjectRef(svc);
if (virNetSocketAddIOCallback(svc->socks[i],
0,
virNetServerServiceAccept,
svc,
virObjectFreeCallback) < 0) {
virObjectUnref(svc);
goto error;
}
}
return svc;
error:
virObjectUnref(svc);
return NULL;
}
/*
 * Recreate a service from the JSON state document produced by
 * virNetServerServicePreExecRestart().
 *
 * @object must contain "auth", "readonly", "nrequests_client_max" and a
 * "socks" array; every array element is restored with
 * virNetSocketNewPostExecRestart(). Each restored socket gets an IO
 * callback that starts disabled, so no client is accepted until the
 * service is toggled on.
 *
 * Returns the new service, or NULL (with an error reported) on failure.
 */
virNetServerServicePtr virNetServerServiceNewPostExecRestart(virJSONValuePtr object)
{
    virNetServerServicePtr svc;
    virJSONValuePtr socks;
    size_t i;
    ssize_t n;
    unsigned int max;

    if (virNetServerServiceInitialize() < 0)
        return NULL;

    if (!(svc = virObjectNew(virNetServerServiceClass)))
        return NULL;

    if (virJSONValueObjectGetNumberInt(object, "auth", &svc->auth) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing auth field in JSON state document"));
        goto error;
    }
    if (virJSONValueObjectGetBoolean(object, "readonly", &svc->readonly) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing readonly field in JSON state document"));
        goto error;
    }
    if (virJSONValueObjectGetNumberUint(object, "nrequests_client_max",
                                        &max) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing nrequests_client_max field in JSON state document"));
        goto error;
    }
    svc->nrequests_client_max = max;

    if (!(socks = virJSONValueObjectGet(object, "socks"))) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing socks field in JSON state document"));
        goto error;
    }

    if ((n = virJSONValueArraySize(socks)) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("socks field in JSON was not an array"));
        goto error;
    }

    /* Allocate first and only then record the count: if the allocation
     * fails, the destructor must not walk a NULL array of n entries. */
    if (VIR_ALLOC_N(svc->socks, n) < 0)
        goto error;
    svc->nsocks = n;

    for (i = 0; i < svc->nsocks; i++) {
        virJSONValuePtr child = virJSONValueArrayGet(socks, i);
        virNetSocketPtr sock;

        /* On failure no socket was created, so there is nothing to
         * release here; slots not yet filled are left for the destructor. */
        if (!(sock = virNetSocketNewPostExecRestart(child)))
            goto error;

        /* Store it immediately so it is released together with svc. */
        svc->socks[i] = sock;

        /* IO callback is initially disabled, until we're ready
         * to deal with incoming clients. The registration owns one
         * reference on svc, handed to virObjectFreeCallback. */
        virObjectRef(svc);
        if (virNetSocketAddIOCallback(sock,
                                      0,
                                      virNetServerServiceAccept,
                                      svc,
                                      virObjectFreeCallback) < 0) {
            virObjectUnref(svc);
            goto error;
        }
    }

    return svc;

 error:
    virObjectUnref(svc);
    return NULL;
}
/*
 * Serialize @svc into a JSON state document so it can be rebuilt by
 * virNetServerServiceNewPostExecRestart() after re-exec.
 *
 * The document holds "auth", "readonly", "nrequests_client_max" and a
 * "socks" array with one entry per listening socket, produced by
 * virNetSocketPreExecRestart().
 *
 * Returns a new JSON object owned by the caller, or NULL on failure.
 */
virJSONValuePtr virNetServerServicePreExecRestart(virNetServerServicePtr svc)
{
    virJSONValuePtr state = virJSONValueNewObject();
    virJSONValuePtr sockArray;
    size_t idx;

    if (state == NULL)
        return NULL;

    if (virJSONValueObjectAppendNumberInt(state, "auth", svc->auth) < 0 ||
        virJSONValueObjectAppendBoolean(state, "readonly", svc->readonly) < 0 ||
        virJSONValueObjectAppendNumberUint(state, "nrequests_client_max",
                                           svc->nrequests_client_max) < 0)
        goto error;

    sockArray = virJSONValueNewArray();
    if (sockArray == NULL)
        goto error;

    /* Until the append succeeds the array is still ours to free. */
    if (virJSONValueObjectAppend(state, "socks", sockArray) < 0) {
        virJSONValueFree(sockArray);
        goto error;
    }

    for (idx = 0; idx < svc->nsocks; idx++) {
        virJSONValuePtr entry = virNetSocketPreExecRestart(svc->socks[idx]);

        if (entry == NULL)
            goto error;

        /* Same ownership rule for each element. */
        if (virJSONValueArrayAppend(sockArray, entry) < 0) {
            virJSONValueFree(entry);
            goto error;
        }
    }

    return state;

 error:
    virJSONValueFree(state);
    return NULL;
}
/*
 * Return the port of @svc's first listening socket.
 *
 * We're assuming that if there are multiple sockets (e.g. for IPv4 and
 * IPv6), they are all on the same port.
 *
 * A service can end up with no sockets at all, for example when it is
 * restored from a state document whose "socks" array is empty. In that
 * case there is no socks[0] to read, so -1 is returned instead of
 * indexing past the end of (or through a NULL) array.
 */
int virNetServerServiceGetPort(virNetServerServicePtr svc)
{
    if (svc->nsocks == 0 || !svc->socks)
        return -1;

    return virNetSocketGetPort(svc->socks[0]);
}
/* Auth scheme the service was created with. */
int
virNetServerServiceGetAuth(virNetServerServicePtr svc)
{
    int auth = svc->auth;
    return auth;
}

/* Whether clients of this service are restricted to read-only access. */
bool
virNetServerServiceIsReadonly(virNetServerServicePtr svc)
{
    bool readonly = svc->readonly;
    return readonly;
}

/* The nrequests_client_max limit the service was configured with. */
size_t
virNetServerServiceGetMaxRequests(virNetServerServicePtr svc)
{
    size_t limit = svc->nrequests_client_max;
    return limit;
}

#if WITH_GNUTLS
/* TLS context of the service (borrowed; no reference is taken). */
virNetTLSContextPtr
virNetServerServiceGetTLSContext(virNetServerServicePtr svc)
{
    virNetTLSContextPtr ctx = svc->tls;
    return ctx;
}
#endif

/*
 * Install @func as the handler for clients accepted on any of the
 * service's sockets; @opaque is passed back to it unchanged.
 */
void
virNetServerServiceSetDispatcher(virNetServerServicePtr svc,
                                 virNetServerServiceDispatchFunc func,
                                 void *opaque)
{
    svc->dispatchOpaque = opaque;
    svc->dispatchFunc = func;
}
/*
 * Class destructor for virNetServerService, registered in
 * virNetServerServiceOnceInit().
 *
 * Drops the service's reference on every listening socket, frees the
 * socket array and, with GnuTLS, drops the TLS context reference.
 *
 * Constructors can fail part-way, so this must cope with partially
 * built objects:
 * - a socket array that was never allocated (checked explicitly, so a
 *   count set without storage cannot be walked);
 * - slots that were never filled in (virObjectUnref() is called on
 *   possibly-NULL pointers elsewhere in this file).
 */
static void virNetServerServiceDispose(void *obj)
{
    virNetServerServicePtr svc = obj;
    size_t i;

    if (svc->socks) {
        for (i = 0; i < svc->nsocks; i++)
            virObjectUnref(svc->socks[i]);
    }
    VIR_FREE(svc->socks);
    svc->nsocks = 0;

#if WITH_GNUTLS
    virObjectUnref(svc->tls);
#endif
}
/*
 * Enable (@enabled == true) or disable accepting new clients on every
 * listening socket of @svc by updating the event mask of each socket's
 * IO callback: readable when enabled, no events when disabled.
 */
void virNetServerServiceToggle(virNetServerServicePtr svc,
                               bool enabled)
{
    int events = enabled ? VIR_EVENT_HANDLE_READABLE : 0;
    size_t idx;

    for (idx = 0; idx < svc->nsocks; idx++)
        virNetSocketUpdateIOCallback(svc->socks[idx], events);
}
/*
 * Stop watching and close every listening socket of @svc. @svc may be
 * NULL, in which case this is a no-op.
 *
 * Reference ownership: each socket's IO callback was registered with
 * its own reference on @svc and virObjectFreeCallback as the free
 * function. That reference therefore belongs to the registration and is
 * released through virObjectFreeCallback once the callback is removed.
 * NOTE(review): this relies on virNetSocketRemoveIOCallback() (via the
 * event loop) invoking the registered free function.
 *
 * The previous code additionally called virObjectUnref(svc) once per
 * socket. That released each reference twice, and calling this function
 * a second time dropped another nsocks references. Either case can free
 * @svc while it is still in use.
 */
void virNetServerServiceClose(virNetServerServicePtr svc)
{
    size_t i;

    if (!svc)
        return;

    for (i = 0; i < svc->nsocks; i++) {
        virNetSocketRemoveIOCallback(svc->socks[i]);
        virNetSocketClose(svc->socks[i]);
    }
}