diff --git a/daemon/remote.c b/daemon/remote.c index fde029da21..b2a420bc00 100644 --- a/daemon/remote.c +++ b/daemon/remote.c @@ -2937,6 +2937,8 @@ remoteDispatchAuthSaslInit(virNetServerPtr server ATTRIBUTE_UNUSED, virNetSASLSessionPtr sasl = NULL; struct daemonClientPrivate *priv = virNetServerClientGetPrivateData(client); + char *localAddr = NULL; + char *remoteAddr = NULL; virMutexLock(&priv->lock); @@ -2947,10 +2949,17 @@ remoteDispatchAuthSaslInit(virNetServerPtr server ATTRIBUTE_UNUSED, goto authfail; } + localAddr = virNetServerClientLocalAddrFormatSASL(client); + remoteAddr = virNetServerClientRemoteAddrFormatSASL(client); + sasl = virNetSASLSessionNewServer(saslCtxt, "libvirt", - virNetServerClientLocalAddrString(client), - virNetServerClientRemoteAddrString(client)); + localAddr, + remoteAddr); + + VIR_FREE(localAddr); + VIR_FREE(remoteAddr); + if (!sasl) goto authfail; diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c index 6bed2c5389..e3cf5fbead 100644 --- a/src/remote/remote_driver.c +++ b/src/remote/remote_driver.c @@ -3684,6 +3684,8 @@ remoteAuthSASL(virConnectPtr conn, struct private_data *priv, sasl_callback_t *saslcb = NULL; int ret = -1; const char *mechlist; + char *localAddr = NULL; + char *remoteAddr = NULL; virNetSASLContextPtr saslCtxt; virNetSASLSessionPtr sasl = NULL; struct remoteAuthInteractState state; @@ -3702,6 +3704,9 @@ remoteAuthSASL(virConnectPtr conn, struct private_data *priv, saslcb = NULL; } + localAddr = virNetClientLocalAddrFormatSASL(priv->client); + remoteAddr = virNetClientRemoteAddrFormatSASL(priv->client); + /* Setup a handle for being a client */ if (!(sasl = virNetSASLSessionNewClient(saslCtxt, "libvirt", @@ -3889,6 +3894,8 @@ remoteAuthSASL(virConnectPtr conn, struct private_data *priv, cleanup: VIR_FREE(serverin); + VIR_FREE(localAddr); + VIR_FREE(remoteAddr); remoteAuthInteractStateClear(&state, true); VIR_FREE(saslcb); diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c index d8ed15b7e2..26b02fa944 100644 --- a/src/rpc/virnetclient.c +++ b/src/rpc/virnetclient.c @@ -954,6 +954,16 @@ const char *virNetClientRemoteAddrString(virNetClientPtr client) return virNetSocketRemoteAddrString(client->sock); } +char *virNetClientLocalAddrFormatSASL(virNetClientPtr client) +{ + return virNetSocketLocalAddrFormatSASL(client->sock); +} + +char *virNetClientRemoteAddrFormatSASL(virNetClientPtr client) +{ + return virNetSocketRemoteAddrFormatSASL(client->sock); +} + #if WITH_GNUTLS int virNetClientGetTLSKeySize(virNetClientPtr client) { diff --git a/src/rpc/virnetclient.h b/src/rpc/virnetclient.h index 38f929ca55..4b786775cb 100644 --- a/src/rpc/virnetclient.h +++ b/src/rpc/virnetclient.h @@ -123,6 +123,8 @@ bool virNetClientIsOpen(virNetClientPtr client); const char *virNetClientLocalAddrString(virNetClientPtr client); const char *virNetClientRemoteAddrString(virNetClientPtr client); +char *virNetClientLocalAddrFormatSASL(virNetClientPtr client); +char *virNetClientRemoteAddrFormatSASL(virNetClientPtr client); # ifdef WITH_GNUTLS int virNetClientGetTLSKeySize(virNetClientPtr client); diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index a9d70e1738..a7b3b15622 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -911,6 +911,19 @@ const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client) return virNetSocketRemoteAddrString(client->sock); } +char *virNetServerClientLocalAddrFormatSASL(virNetServerClientPtr client) +{ + if (!client->sock) + return NULL; + return virNetSocketLocalAddrFormatSASL(client->sock); +} + +char *virNetServerClientRemoteAddrFormatSASL(virNetServerClientPtr client) +{ + if (!client->sock) + return NULL; + return virNetSocketRemoteAddrFormatSASL(client->sock); +} void virNetServerClientDispose(void *obj) { diff --git a/src/rpc/virnetserverclient.h b/src/rpc/virnetserverclient.h index 1318fa2410..f44b7caba0 100644 --- a/src/rpc/virnetserverclient.h +++ b/src/rpc/virnetserverclient.h @@ -139,6 +139,8 @@ int virNetServerClientStartKeepAlive(virNetServerClientPtr client); const char *virNetServerClientLocalAddrString(virNetServerClientPtr client); const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client); +char *virNetServerClientLocalAddrFormatSASL(virNetServerClientPtr client); +char *virNetServerClientRemoteAddrFormatSASL(virNetServerClientPtr client); int virNetServerClientSendMessage(virNetServerClientPtr client, virNetMessagePtr msg); diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index d909b94d0f..a90cc55ada 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -262,11 +262,11 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr, if (localAddr && - !(sock->localAddrStr = virSocketAddrFormatFull(localAddr, true, ";"))) + !(sock->localAddrStr = virSocketAddrFormatFull(localAddr, true, NULL))) goto error; if (remoteAddr && - !(sock->remoteAddrStr = virSocketAddrFormatFull(remoteAddr, true, ";"))) + !(sock->remoteAddrStr = virSocketAddrFormatFull(remoteAddr, true, NULL))) goto error; sock->client = isClient; @@ -1465,6 +1465,19 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) return sock->remoteAddrStr; } +/* These helper functions return a SASL-formatted socket addr string, + * caller is responsible for freeing the string. + */ +char *virNetSocketLocalAddrFormatSASL(virNetSocketPtr sock) +{ + return virSocketAddrFormatFull(&sock->localAddr, true, ";"); +} + +char *virNetSocketRemoteAddrFormatSASL(virNetSocketPtr sock) +{ + return virSocketAddrFormatFull(&sock->remoteAddr, true, ";"); +} + #if WITH_GNUTLS static ssize_t virNetSocketTLSSessionWrite(const char *buf, diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index 5de3d92631..4eb7187b23 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ -150,6 +150,8 @@ bool virNetSocketHasPendingData(virNetSocketPtr sock); const char *virNetSocketLocalAddrString(virNetSocketPtr sock); const char *virNetSocketRemoteAddrString(virNetSocketPtr sock); +char *virNetSocketLocalAddrFormatSASL(virNetSocketPtr sock); +char *virNetSocketRemoteAddrFormatSASL(virNetSocketPtr sock); int virNetSocketListen(virNetSocketPtr sock, int backlog); int virNetSocketAccept(virNetSocketPtr sock, diff --git a/src/util/virsocketaddr.c b/src/util/virsocketaddr.c index 4b456819b1..a0c92ea609 100644 --- a/src/util/virsocketaddr.c +++ b/src/util/virsocketaddr.c @@ -339,9 +339,11 @@ virSocketAddrFormat(const virSocketAddr *addr) * @withService: if true, then service info is appended * @separator: separator between hostname & service. * - * Returns a string representation of the given address - * Returns NULL on any error - * Caller must free the returned string + * Returns a string representation of the given address. If a format conforming + * to URI specification is required, NULL should be passed to separator. + * Set @separator only if non-URI format is required, e.g. passing ';' for + * @separator if the address should be used with SASL. + * Caller must free the returned string. */ char * virSocketAddrFormatFull(const virSocketAddr *addr, @@ -383,8 +385,22 @@ virSocketAddrFormatFull(const virSocketAddr *addr, } if (withService) { - if (virAsprintf(&addrstr, "%s%s%s", host, separator, port) == -1) + char *ipv6_host = NULL; + /* sasl_new_client demands the socket address to be in an odd format: + * a.b.c.d;port or e:f:g:h:i:j:k:l;port, so use square brackets for + * IPv6 only if no separator is passed to the function + */ + if (!separator && VIR_SOCKET_ADDR_FAMILY(addr) == AF_INET6) { + if (virAsprintf(&ipv6_host, "[%s]", host) < 0) + goto error; + } + + if (virAsprintf(&addrstr, "%s%s%s", + ipv6_host ? ipv6_host : host, + separator ? separator : ":", port) == -1) goto error; + + VIR_FREE(ipv6_host); } else { if (VIR_STRDUP(addrstr, host) < 0) goto error; diff --git a/tests/virnetsockettest.c b/tests/virnetsockettest.c index 5786870a82..c2bc4e739c 100644 --- a/tests/virnetsockettest.c +++ b/tests/virnetsockettest.c @@ -249,7 +249,7 @@ static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED) if (virNetSocketNewListenUNIX(path, 0700, -1, getegid(), &lsock) < 0) goto cleanup; - if (STRNEQ(virNetSocketLocalAddrString(lsock), "127.0.0.1;0")) { + if (STRNEQ(virNetSocketLocalAddrString(lsock), "127.0.0.1:0")) { VIR_DEBUG("Unexpected local address"); goto cleanup; } @@ -265,12 +265,12 @@ static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED) if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0) goto cleanup; - if (STRNEQ(virNetSocketLocalAddrString(csock), "127.0.0.1;0")) { + if (STRNEQ(virNetSocketLocalAddrString(csock), "127.0.0.1:0")) { VIR_DEBUG("Unexpected local address"); goto cleanup; } - if (STRNEQ(virNetSocketRemoteAddrString(csock), "127.0.0.1;0")) { + if (STRNEQ(virNetSocketRemoteAddrString(csock), "127.0.0.1:0")) { VIR_DEBUG("Unexpected local address"); goto cleanup; } @@ -282,12 +282,12 @@ static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED) } - if (STRNEQ(virNetSocketLocalAddrString(ssock), "127.0.0.1;0")) { + if (STRNEQ(virNetSocketLocalAddrString(ssock), "127.0.0.1:0")) { VIR_DEBUG("Unexpected local address"); goto cleanup; } - if (STRNEQ(virNetSocketRemoteAddrString(ssock), "127.0.0.1;0")) { + if (STRNEQ(virNetSocketRemoteAddrString(ssock), "127.0.0.1:0")) { VIR_DEBUG("Unexpected local address"); goto cleanup; }