diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c index 12f7d47388..87398df575 100644 --- a/src/remote/remote_driver.c +++ b/src/remote/remote_driver.c @@ -24,7 +24,6 @@ #include #include -#include #include "virnetclient.h" #include "virnetclientprogram.h" @@ -162,7 +161,26 @@ static void make_nonnull_domain_snapshot(remote_nonnull_domain_snapshot *snapsho /*----------------------------------------------------------------------*/ /* Helper functions for remoteOpen. */ -static char *get_transport_from_scheme(char *scheme); +static int remoteSplitURIScheme(virURIPtr uri, + char **driver, + char **transport) +{ + char *p = strchr(uri->scheme, '+'); + + *driver = *transport = NULL; + + if (VIR_STRNDUP(*driver, uri->scheme, p ? p - uri->scheme : -1) < 0) + return -1; + + if (p && + VIR_STRDUP(*transport, p + 1) < 0) { + VIR_FREE(*driver); + return -1; + } + + return 0; +} + static int remoteStateInitialize(bool privileged ATTRIBUTE_UNUSED, @@ -715,11 +733,12 @@ remoteConnectSupportsFeatureUnlocked(virConnectPtr conn, static int doRemoteOpen(virConnectPtr conn, struct private_data *priv, + const char *driver_str, + const char *transport_str, virConnectAuthPtr auth ATTRIBUTE_UNUSED, virConfPtr conf, unsigned int flags) { - char *transport_str = NULL; enum { trans_tls, trans_unix, @@ -738,8 +757,6 @@ doRemoteOpen(virConnectPtr conn, * URIs we don't care about */ if (conn->uri) { - transport_str = get_transport_from_scheme(conn->uri->scheme); - if (!transport_str) { if (conn->uri->server) transport = trans_tls; @@ -873,26 +890,16 @@ doRemoteOpen(virConnectPtr conn, goto failed; } else { virURI tmpuri = { - .scheme = conn->uri->scheme, + .scheme = (char *)driver_str, .query = virURIFormatParams(conn->uri), .path = conn->uri->path, .fragment = conn->uri->fragment, }; - /* Evil, blank out transport scheme temporarily */ - if (transport_str) { - assert(transport_str[-1] == '+'); - transport_str[-1] = '\0'; - } - name = virURIFormat(&tmpuri); VIR_FREE(tmpuri.query); - /* Restore transport scheme */ - if (transport_str) - transport_str[-1] = '+'; - if (!name) goto failed; } @@ -1297,14 +1304,23 @@ remoteConnectOpen(virConnectPtr conn, unsigned int flags) { struct private_data *priv; - int ret, rflags = 0; + int ret = VIR_DRV_OPEN_ERROR; + int rflags = 0; const char *autostart = virGetEnvBlockSUID("LIBVIRT_AUTOSTART"); + char *driver = NULL; + char *transport = NULL; - if (inside_daemon && (!conn->uri || !conn->uri->server)) - return VIR_DRV_OPEN_DECLINED; + if (conn->uri && + remoteSplitURIScheme(conn->uri, &driver, &transport) < 0) + goto cleanup; + + if (inside_daemon && (!conn->uri || !conn->uri->server)) { + ret = VIR_DRV_OPEN_DECLINED; + goto cleanup; + } if (!(priv = remoteAllocPrivateData())) - return VIR_DRV_OPEN_ERROR; + goto cleanup; if (flags & VIR_CONNECT_RO) rflags |= VIR_DRV_OPEN_REMOTE_RO; @@ -1319,8 +1335,7 @@ remoteConnectOpen(virConnectPtr conn, !conn->uri->server && conn->uri->path && conn->uri->scheme && - ((strchr(conn->uri->scheme, '+') == 0)|| - (strstr(conn->uri->scheme, "+unix") != NULL)) && + (transport == NULL || STREQ(transport, "unix")) && (STREQ(conn->uri->path, "/session") || STRPREFIX(conn->uri->scheme, "test+")) && geteuid() > 0) { @@ -1348,7 +1363,7 @@ remoteConnectOpen(virConnectPtr conn, } } - ret = doRemoteOpen(conn, priv, auth, conf, rflags); + ret = doRemoteOpen(conn, priv, driver, transport, auth, conf, rflags); if (ret != VIR_DRV_OPEN_SUCCESS) { conn->privateData = NULL; remoteDriverUnlock(priv); @@ -1357,18 +1372,14 @@ remoteConnectOpen(virConnectPtr conn, conn->privateData = priv; remoteDriverUnlock(priv); } + + cleanup: + VIR_FREE(driver); + VIR_FREE(transport); return ret; } -/* In a string "driver+transport" return a pointer to "transport". */ -static char * -get_transport_from_scheme(char *scheme) -{ - char *p = strchr(scheme, '+'); - return p ? p + 1 : NULL; -} - /*----------------------------------------------------------------------*/