1
0
mirror of https://gitlab.com/libvirt/libvirt.git synced 2025-03-07 17:28:15 +00:00

Integrate TLS/SASL directly into the socket APIs

This extends the basic virNetSocket APIs to allow them to have
a handle to the TLS/SASL session objects, once established.
This ensures that any data reads/writes are automagically
passed through the TLS/SASL encryption layers if required.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up
  SASL/TLS encryption
This commit is contained in:
Daniel P. Berrange 2010-12-10 12:22:03 +00:00
parent bb1c9296f5
commit f5fa167e8d
2 changed files with 224 additions and 5 deletions

View File

@ -59,6 +59,19 @@ struct _virNetSocket {
virSocketAddr remoteAddr;
char *localAddrStr;
char *remoteAddrStr;
virNetTLSSessionPtr tlsSession;
#if HAVE_SASL
virNetSASLSessionPtr saslSession;
const char *saslDecoded;
size_t saslDecodedLength;
size_t saslDecodedOffset;
const char *saslEncoded;
size_t saslEncodedLength;
size_t saslEncodedOffset;
#endif
};
@ -417,7 +430,7 @@ error:
}
#if HAVE_SYS_UN_H
#ifdef HAVE_SYS_UN_H
int virNetSocketNewConnectUNIX(const char *path,
bool spawnDaemon,
const char *binary,
@ -624,6 +637,14 @@ void virNetSocketFree(virNetSocketPtr sock)
unlink(sock->localAddr.data.un.sun_path);
#endif
/* Make sure it can't send any more I/O during shutdown */
if (sock->tlsSession)
virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
virNetTLSSessionFree(sock->tlsSession);
#if HAVE_SASL
virNetSASLSessionFree(sock->saslSession);
#endif
VIR_FORCE_CLOSE(sock->fd);
VIR_FORCE_CLOSE(sock->errfd);
@ -709,17 +730,77 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
return sock->remoteAddrStr;
}
ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
static ssize_t virNetSocketTLSSessionWrite(const char *buf,
size_t len,
void *opaque)
{
virNetSocketPtr sock = opaque;
return write(sock->fd, buf, len);
}
static ssize_t virNetSocketTLSSessionRead(char *buf,
size_t len,
void *opaque)
{
virNetSocketPtr sock = opaque;
return read(sock->fd, buf, len);
}
void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess)
{
virNetTLSSessionFree(sock->tlsSession);
sock->tlsSession = sess;
virNetTLSSessionSetIOCallbacks(sess,
virNetSocketTLSSessionWrite,
virNetSocketTLSSessionRead,
sock);
virNetTLSSessionRef(sess);
}
#if HAVE_SASL
void virNetSocketSetSASLSession(virNetSocketPtr sock,
virNetSASLSessionPtr sess)
{
virNetSASLSessionFree(sock->saslSession);
sock->saslSession = sess;
virNetSASLSessionRef(sess);
}
#endif
bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
{
#if HAVE_SASL
if (sock->saslDecoded)
return true;
#endif
return false;
}
static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
{
char *errout = NULL;
ssize_t ret;
reread:
ret = read(sock->fd, buf, len);
if (sock->tlsSession &&
virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
VIR_NET_TLS_HANDSHAKE_COMPLETE) {
ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
} else {
ret = read(sock->fd, buf, len);
}
if ((ret < 0) && (errno == EINTR))
goto reread;
if ((ret < 0) && (errno == EAGAIN))
return 0;
if (ret <= 0 &&
sock->errfd != -1 &&
virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
@ -751,11 +832,17 @@ reread:
return ret;
}
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
{
ssize_t ret;
rewrite:
ret = write(sock->fd, buf, len);
if (sock->tlsSession &&
virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
VIR_NET_TLS_HANDSHAKE_COMPLETE) {
ret = virNetTLSSessionWrite(sock->tlsSession, buf, len);
} else {
ret = write(sock->fd, buf, len);
}
if (ret < 0) {
if (errno == EINTR)
@ -777,6 +864,127 @@ rewrite:
}
#if HAVE_SASL
static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
{
ssize_t got;
/* Need to read some more data off the wire */
if (sock->saslDecoded == NULL) {
ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession);
char *encoded;
if (VIR_ALLOC_N(encoded, encodedLen) < 0) {
virReportOOMError();
return -1;
}
encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
if (encodedLen <= 0) {
VIR_FREE(encoded);
return encodedLen;
}
if (virNetSASLSessionDecode(sock->saslSession,
encoded, encodedLen,
&sock->saslDecoded, &sock->saslDecodedLength) < 0) {
VIR_FREE(encoded);
return -1;
}
VIR_FREE(encoded);
sock->saslDecodedOffset = 0;
}
/* Some buffered decoded data to return now */
got = sock->saslDecodedLength - sock->saslDecodedOffset;
if (len > got)
len = got;
memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
sock->saslDecodedOffset += len;
if (sock->saslDecodedOffset == sock->saslDecodedLength) {
sock->saslDecoded = NULL;
sock->saslDecodedOffset = sock->saslDecodedLength = 0;
}
return len;
}
static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len)
{
int ret;
size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession);
/* SASL doesn't necessarily let us send the whole
buffer at once */
if (tosend > len)
tosend = len;
/* Not got any pending encoded data, so we need to encode raw stuff */
if (sock->saslEncoded == NULL) {
if (virNetSASLSessionEncode(sock->saslSession,
buf, tosend,
&sock->saslEncoded,
&sock->saslEncodedLength) < 0)
return -1;
sock->saslEncodedOffset = 0;
}
/* Send some of the encoded stuff out on the wire */
ret = virNetSocketWriteWire(sock,
sock->saslEncoded + sock->saslEncodedOffset,
sock->saslEncodedLength - sock->saslEncodedOffset);
if (ret <= 0)
return ret; /* -1 error, 0 == egain */
/* Note how much we sent */
sock->saslEncodedOffset += ret;
/* Sent all encoded, so update raw buffer to indicate completion */
if (sock->saslEncodedOffset == sock->saslEncodedLength) {
sock->saslEncoded = NULL;
sock->saslEncodedOffset = sock->saslEncodedLength = 0;
/* Mark as complete, so caller detects completion */
return tosend;
} else {
/* Still have stuff pending in saslEncoded buffer.
* Pretend to caller that we didn't send any yet.
* The caller will then retry with same buffer
* shortly, which lets us finish saslEncoded.
*/
return 0;
}
}
#endif
ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
{
#if HAVE_SASL
if (sock->saslSession)
return virNetSocketReadSASL(sock, buf, len);
else
#endif
return virNetSocketReadWire(sock, buf, len);
}
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
{
#if HAVE_SASL
if (sock->saslSession)
return virNetSocketWriteSASL(sock, buf, len);
else
#endif
return virNetSocketWriteWire(sock, buf, len);
}
int virNetSocketListen(virNetSocketPtr sock)
{
if (listen(sock->fd, 30) < 0) {

View File

@ -26,6 +26,10 @@
# include "network.h"
# include "command.h"
# include "virnettlscontext.h"
# ifdef HAVE_SASL
# include "virnetsaslcontext.h"
# endif
typedef struct _virNetSocket virNetSocket;
typedef virNetSocket *virNetSocketPtr;
@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock,
ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess);
# ifdef HAVE_SASL
void virNetSocketSetSASLSession(virNetSocketPtr sock,
virNetSASLSessionPtr sess);
# endif
bool virNetSocketHasCachedData(virNetSocketPtr sock);
void virNetSocketFree(virNetSocketPtr sock);
const char *virNetSocketLocalAddrString(virNetSocketPtr sock);