Turn virNetTLSContext and virNetTLSSession into virObject instances

Make virNetTLSContext and virNetTLSSession use the virObject
APIs for reference counting

Signed-off-by: Daniel P. Berrange <berrange@redhat.com>
This commit is contained in:
Daniel P. Berrange 2012-07-11 14:35:48 +01:00
parent b57ee0921e
commit e10e1969d5
13 changed files with 66 additions and 120 deletions

1
cfg.mk
View File

@ -158,7 +158,6 @@ useless_free_options = \
--name=virNetSocketFree \
--name=virNetSASLContextFree \
--name=virNetSASLSessionFree \
--name=virNetTLSSessionFree \
--name=virNWFilterDefFree \
--name=virNWFilterEntryFree \
--name=virNWFilterHashTableFree \

View File

@ -541,7 +541,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
false,
config->max_client_requests,
ctxt))) {
virNetTLSContextFree(ctxt);
virObjectUnref(ctxt);
goto error;
}
if (virNetServerAddService(srv, svcTLS,
@ -549,7 +549,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
!config->listen_tcp ? "_libvirt._tcp" : NULL) < 0)
goto error;
virNetTLSContextFree(ctxt);
virObjectUnref(ctxt);
}
}

View File

@ -1625,20 +1625,16 @@ virNetSocketWrite;
# virnettlscontext.h
virNetTLSContextCheckCertificate;
virNetTLSContextFree;
virNetTLSContextNewClient;
virNetTLSContextNewClientPath;
virNetTLSContextNewServer;
virNetTLSContextNewServerPath;
virNetTLSContextRef;
virNetTLSInit;
virNetTLSSessionFree;
virNetTLSSessionGetHandshakeStatus;
virNetTLSSessionGetKeySize;
virNetTLSSessionHandshake;
virNetTLSSessionNew;
virNetTLSSessionRead;
virNetTLSSessionRef;
virNetTLSSessionSetIOCallbacks;
virNetTLSSessionWrite;

View File

@ -61,19 +61,15 @@ provider libvirt {
# file: src/rpc/virnettlscontext.c
# prefix: rpc
probe rpc_tls_context_new(void *ctxt, int refs, const char *cacert, const char *cacrl,
probe rpc_tls_context_new(void *ctxt, const char *cacert, const char *cacrl,
const char *cert, const char *key, int sanityCheckCert, int requireValidCert, int isServer);
probe rpc_tls_context_ref(void *ctxt, int refs);
probe rpc_tls_context_free(void *ctxt, int refs);
probe rpc_tls_context_session_allow(void *ctxt, void *sess, const char *dname);
probe rpc_tls_context_session_deny(void *ctxt, void *sess, const char *dname);
probe rpc_tls_context_session_fail(void *ctxt, void *sess);
probe rpc_tls_session_new(void *sess, void *ctxt, int refs, const char *hostname, int isServer);
probe rpc_tls_session_ref(void *sess, int refs);
probe rpc_tls_session_free(void *sess, int refs);
probe rpc_tls_session_new(void *sess, void *ctxt, const char *hostname, int isServer);
probe rpc_tls_session_handshake_pass(void *sess);
probe rpc_tls_session_handshake_fail(void *sess);

View File

@ -943,7 +943,7 @@ doRemoteClose (virConnectPtr conn, struct private_data *priv)
(xdrproc_t) xdr_void, (char *) NULL) == -1)
ret = -1;
virNetTLSContextFree(priv->tls);
virObjectUnref(priv->tls);
priv->tls = NULL;
virNetClientClose(priv->client);
virNetClientFree(priv->client);

View File

@ -495,7 +495,7 @@ void virNetClientFree(virNetClientPtr client)
if (client->sock)
virNetSocketRemoveIOCallback(client->sock);
virNetSocketFree(client->sock);
virNetTLSSessionFree(client->tls);
virObjectUnref(client->tls);
#if HAVE_SASL
virNetSASLSessionFree(client->sasl);
#endif
@ -532,7 +532,7 @@ virNetClientCloseLocked(virNetClientPtr client)
virNetSocketFree(client->sock);
client->sock = NULL;
virNetTLSSessionFree(client->tls);
virObjectUnref(client->tls);
client->tls = NULL;
#if HAVE_SASL
virNetSASLSessionFree(client->sasl);
@ -712,7 +712,7 @@ int virNetClientSetTLSSession(virNetClientPtr client,
return 0;
error:
virNetTLSSessionFree(client->tls);
virObjectUnref(client->tls);
client->tls = NULL;
virNetClientUnlock(client);
return -1;

View File

@ -642,8 +642,7 @@ no_memory:
int virNetServerSetTLSContext(virNetServerPtr srv,
virNetTLSContextPtr tls)
{
srv->tls = tls;
virNetTLSContextRef(tls);
srv->tls = virObjectRef(tls);
return 0;
}

View File

@ -346,7 +346,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
client->sock = sock;
client->auth = auth;
client->readonly = readonly;
client->tlsCtxt = tls;
client->tlsCtxt = virObjectRef(tls);
client->nrequests_max = nrequests_max;
client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
@ -354,9 +354,6 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
if (client->sockTimer < 0)
goto error;
if (tls)
virNetTLSContextRef(tls);
/* Prepare one for packet receive */
if (!(client->rx = virNetMessageNew(true)))
goto error;
@ -598,8 +595,8 @@ void virNetServerClientFree(virNetServerClientPtr client)
#endif
if (client->sockTimer > 0)
virEventRemoveTimeout(client->sockTimer);
virNetTLSSessionFree(client->tls);
virNetTLSContextFree(client->tlsCtxt);
virObjectUnref(client->tls);
virObjectUnref(client->tlsCtxt);
virNetSocketFree(client->sock);
virNetServerClientUnlock(client);
virMutexDestroy(&client->lock);
@ -654,7 +651,7 @@ void virNetServerClientClose(virNetServerClientPtr client)
virNetSocketRemoveIOCallback(client->sock);
if (client->tls) {
virNetTLSSessionFree(client->tls);
virObjectUnref(client->tls);
client->tls = NULL;
}
client->wantClose = true;

View File

@ -116,9 +116,7 @@ virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
svc->auth = auth;
svc->readonly = readonly;
svc->nrequests_client_max = nrequests_client_max;
svc->tls = tls;
if (tls)
virNetTLSContextRef(tls);
svc->tls = virObjectRef(tls);
if (virNetSocketNewListenTCP(nodename,
service,
@ -172,9 +170,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
svc->auth = auth;
svc->readonly = readonly;
svc->nrequests_client_max = nrequests_client_max;
svc->tls = tls;
if (tls)
virNetTLSContextRef(tls);
svc->tls = virObjectRef(tls);
svc->nsocks = 1;
if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0)
@ -265,7 +261,7 @@ void virNetServerServiceFree(virNetServerServicePtr svc)
virNetSocketFree(svc->socks[i]);
VIR_FREE(svc->socks);
virNetTLSContextFree(svc->tls);
virObjectUnref(svc->tls);
VIR_FREE(svc);
}

View File

@ -748,7 +748,7 @@ void virNetSocketFree(virNetSocketPtr sock)
/* Make sure it can't send any more I/O during shutdown */
if (sock->tlsSession)
virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
virNetTLSSessionFree(sock->tlsSession);
virObjectUnref(sock->tlsSession);
#if HAVE_SASL
virNetSASLSessionFree(sock->saslSession);
#endif
@ -909,13 +909,12 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess)
{
virMutexLock(&sock->lock);
virNetTLSSessionFree(sock->tlsSession);
sock->tlsSession = sess;
virObjectUnref(sock->tlsSession);
sock->tlsSession = virObjectRef(sess);
virNetTLSSessionSetIOCallbacks(sess,
virNetSocketTLSSessionWrite,
virNetSocketTLSSessionRead,
sock);
virNetTLSSessionRef(sess);
virMutexUnlock(&sock->lock);
}

View File

@ -50,8 +50,9 @@
#define VIR_FROM_THIS VIR_FROM_RPC
struct _virNetTLSContext {
virObject object;
virMutex lock;
int refs;
gnutls_certificate_credentials_t x509cred;
gnutls_dh_params_t dhParams;
@ -62,9 +63,9 @@ struct _virNetTLSContext {
};
struct _virNetTLSSession {
virMutex lock;
virObject object;
int refs;
virMutex lock;
bool handshakeComplete;
@ -76,6 +77,29 @@ struct _virNetTLSSession {
void *opaque;
};
static virClassPtr virNetTLSContextClass;
static virClassPtr virNetTLSSessionClass;
static void virNetTLSContextDispose(void *obj);
static void virNetTLSSessionDispose(void *obj);
static int virNetTLSContextOnceInit(void)
{
if (!(virNetTLSContextClass = virClassNew("virNetTLSContext",
sizeof(virNetTLSContext),
virNetTLSContextDispose)))
return -1;
if (!(virNetTLSSessionClass = virClassNew("virNetTLSSession",
sizeof(virNetTLSSession),
virNetTLSSessionDispose)))
return -1;
return 0;
}
VIR_ONCE_GLOBAL_INIT(virNetTLSContext)
static int
virNetTLSContextCheckCertFile(const char *type, const char *file, bool allowMissing)
@ -647,10 +671,11 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
char *gnutlsdebug;
int err;
if (VIR_ALLOC(ctxt) < 0) {
virReportOOMError();
if (virNetTLSContextInitialize() < 0)
return NULL;
if (!(ctxt = virObjectNew(virNetTLSContextClass)))
return NULL;
}
if (virMutexInit(&ctxt->lock) < 0) {
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@ -659,8 +684,6 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
return NULL;
}
ctxt->refs = 1;
if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) {
int val;
if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0)
@ -716,8 +739,8 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
ctxt->isServer = isServer;
PROBE(RPC_TLS_CONTEXT_NEW,
"ctxt=%p refs=%d cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
ctxt, ctxt->refs, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
"ctxt=%p cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
ctxt, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
return ctxt;
@ -927,17 +950,6 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
}
void virNetTLSContextRef(virNetTLSContextPtr ctxt)
{
virMutexLock(&ctxt->lock);
ctxt->refs++;
PROBE(RPC_TLS_CONTEXT_REF,
"ctxt=%p refs=%d",
ctxt, ctxt->refs);
virMutexUnlock(&ctxt->lock);
}
static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt,
virNetTLSSessionPtr sess)
{
@ -1106,30 +1118,16 @@ cleanup:
return ret;
}
void virNetTLSContextFree(virNetTLSContextPtr ctxt)
void virNetTLSContextDispose(void *obj)
{
if (!ctxt)
return;
virMutexLock(&ctxt->lock);
PROBE(RPC_TLS_CONTEXT_FREE,
"ctxt=%p refs=%d",
ctxt, ctxt->refs);
ctxt->refs--;
if (ctxt->refs > 0) {
virMutexUnlock(&ctxt->lock);
return;
}
virNetTLSContextPtr ctxt = obj;
gnutls_dh_params_deinit(ctxt->dhParams);
gnutls_certificate_free_credentials(ctxt->x509cred);
virMutexUnlock(&ctxt->lock);
virMutexDestroy(&ctxt->lock);
VIR_FREE(ctxt);
}
static ssize_t
virNetTLSSessionPush(void *opaque, const void *buf, size_t len)
{
@ -1167,10 +1165,8 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
VIR_DEBUG("ctxt=%p hostname=%s isServer=%d",
ctxt, NULLSTR(hostname), ctxt->isServer);
if (VIR_ALLOC(sess) < 0) {
virReportOOMError();
if (!(sess = virObjectNew(virNetTLSSessionClass)))
return NULL;
}
if (virMutexInit(&sess->lock) < 0) {
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@ -1179,7 +1175,6 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
return NULL;
}
sess->refs = 1;
if (hostname &&
!(sess->hostname = strdup(hostname))) {
virReportOOMError();
@ -1230,27 +1225,17 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
sess->isServer = ctxt->isServer;
PROBE(RPC_TLS_SESSION_NEW,
"sess=%p refs=%d ctxt=%p hostname=%s isServer=%d",
sess, sess->refs, ctxt, hostname, sess->isServer);
"sess=%p ctxt=%p hostname=%s isServer=%d",
sess, ctxt, hostname, sess->isServer);
return sess;
error:
virNetTLSSessionFree(sess);
virObjectUnref(sess);
return NULL;
}
void virNetTLSSessionRef(virNetTLSSessionPtr sess)
{
virMutexLock(&sess->lock);
sess->refs++;
PROBE(RPC_TLS_SESSION_REF,
"sess=%p refs=%d",
sess, sess->refs);
virMutexUnlock(&sess->lock);
}
void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
virNetTLSSessionWriteFunc writeFunc,
virNetTLSSessionReadFunc readFunc,
@ -1393,26 +1378,13 @@ cleanup:
}
void virNetTLSSessionFree(virNetTLSSessionPtr sess)
void virNetTLSSessionDispose(void *obj)
{
if (!sess)
return;
virMutexLock(&sess->lock);
PROBE(RPC_TLS_SESSION_FREE,
"sess=%p refs=%d",
sess, sess->refs);
sess->refs--;
if (sess->refs > 0) {
virMutexUnlock(&sess->lock);
return;
}
virNetTLSSessionPtr sess = obj;
VIR_FREE(sess->hostname);
gnutls_deinit(sess->session);
virMutexUnlock(&sess->lock);
virMutexDestroy(&sess->lock);
VIR_FREE(sess);
}
/*

View File

@ -22,6 +22,7 @@
# define __VIR_NET_TLS_CONTEXT_H__
# include "internal.h"
# include "virobject.h"
typedef struct _virNetTLSContext virNetTLSContext;
typedef virNetTLSContext *virNetTLSContextPtr;
@ -58,13 +59,9 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
bool sanityCheckCert,
bool requireValidCert);
void virNetTLSContextRef(virNetTLSContextPtr ctxt);
int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt,
virNetTLSSessionPtr sess);
void virNetTLSContextFree(virNetTLSContextPtr ctxt);
typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len,
void *opaque);
@ -79,8 +76,6 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
virNetTLSSessionReadFunc readFunc,
void *opaque);
void virNetTLSSessionRef(virNetTLSSessionPtr sess);
ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
const char *buf, size_t len);
ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
@ -99,7 +94,4 @@ virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess);
int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess);
void virNetTLSSessionFree(virNetTLSSessionPtr sess);
#endif

View File

@ -496,7 +496,7 @@ static int testTLSContextInit(const void *opaque)
ret = 0;
cleanup:
virNetTLSContextFree(ctxt);
virObjectUnref(ctxt);
gnutls_x509_crt_deinit(data->careq.crt);
gnutls_x509_crt_deinit(data->certreq.crt);
data->careq.crt = data->certreq.crt = NULL;
@ -710,10 +710,10 @@ static int testTLSSessionInit(const void *opaque)
ret = 0;
cleanup:
virNetTLSContextFree(serverCtxt);
virNetTLSContextFree(clientCtxt);
virNetTLSSessionFree(serverSess);
virNetTLSSessionFree(clientSess);
virObjectUnref(serverCtxt);
virObjectUnref(clientCtxt);
virObjectUnref(serverSess);
virObjectUnref(clientSess);
gnutls_x509_crt_deinit(data->careq.crt);
if (data->othercareq.filename)
gnutls_x509_crt_deinit(data->othercareq.crt);