diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c index ea119c67af..840e481d30 100644 --- a/src/remote/remote_driver.c +++ b/src/remote/remote_driver.c @@ -268,7 +268,7 @@ void remoteDomainEventQueueFlush(int timer, void *opaque); static char *get_transport_from_scheme (char *scheme); /* GnuTLS functions used by remoteOpen. */ -static int initialize_gnutls(void); +static int initialize_gnutls(char *pkipath, int flags); static gnutls_session_t negotiate_gnutls_on_connection (virConnectPtr conn, struct private_data *priv, int no_verify); #ifdef WITH_LIBVIRTD @@ -430,6 +430,7 @@ doRemoteOpen (virConnectPtr conn, char *port = NULL, *authtype = NULL, *username = NULL; int no_verify = 0, no_tty = 0; char **cmd_argv = NULL; + char *pkipath = NULL; /* Return code from this function, and the private data. */ int retcode = VIR_DRV_OPEN_ERROR; @@ -509,9 +510,14 @@ doRemoteOpen (virConnectPtr conn, priv->debugLog = stdout; else priv->debugLog = stderr; - } else + } else if (STRCASEEQ(var->name, "pkipath")) { + pkipath = strdup(var->value); + if (!pkipath) goto out_of_memory; + var->ignore = 1; + } else { DEBUG("passing through variable '%s' ('%s') to remote end", var->name, var->value); + } } /* Construct the original name. */ @@ -577,7 +583,7 @@ doRemoteOpen (virConnectPtr conn, /* Connect to the remote service. */ switch (transport) { case trans_tls: - if (initialize_gnutls() == -1) goto failed; + if (initialize_gnutls(pkipath, flags) == -1) goto failed; priv->uses_tls = 1; priv->is_secure = 1; @@ -947,6 +953,7 @@ doRemoteOpen (virConnectPtr conn, } VIR_FREE(cmd_argv); } + VIR_FREE(pkipath); return retcode; @@ -1139,11 +1146,17 @@ static void remote_debug_gnutls_log(int level, const char* str) { } static int -initialize_gnutls(void) +initialize_gnutls(char *pkipath, int flags) { static int initialized = 0; int err; char *gnutlsdebug; + char *libvirt_cacert = NULL; + char *libvirt_clientkey = NULL; + char *libvirt_clientcert = NULL; + int ret = -1; + char *userdir = NULL; + char *user_pki_path = NULL; if (initialized) return 0; @@ -1166,43 +1179,124 @@ initialize_gnutls(void) return -1; } + if (pkipath) { + if ((virAsprintf(&libvirt_cacert, "%s/%s", pkipath, + "cacert.pem")) < 0) + goto out_of_memory; - if (check_cert_file("CA certificate", LIBVIRT_CACERT) < 0) - return -1; - if (check_cert_file("client key", LIBVIRT_CLIENTKEY) < 0) - return -1; - if (check_cert_file("client certificate", LIBVIRT_CLIENTCERT) < 0) - return -1; + if ((virAsprintf(&libvirt_clientkey, "%s/%s", pkipath, + "clientkey.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(&libvirt_clientcert, "%s/%s", pkipath, + "clientcert.pem")) < 0) + goto out_of_memory; + } else if (flags & VIR_DRV_OPEN_REMOTE_USER) { + userdir = virGetUserDirectory(getuid()); + + if (!userdir) + goto out_of_memory; + + if (virAsprintf(&user_pki_path, "%s/.pki/libvirt", userdir) < 0) + goto out_of_memory; + + if ((virAsprintf(&libvirt_cacert, "%s/%s", user_pki_path, + "cacert.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(&libvirt_clientkey, "%s/%s", user_pki_path, + "clientkey.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(&libvirt_clientcert, "%s/%s", user_pki_path, + "clientcert.pem")) < 0) + goto out_of_memory; + + /* Use default location as long as one of CA certificate, + * client key, and client certificate can not be found in + * $HOME/.pki/libvirt, we don't want to make user confused + * with one file is here, the other is there. + */ + if (!virFileExists(libvirt_cacert) || + !virFileExists(libvirt_clientkey) || + !virFileExists(libvirt_clientcert)) { + VIR_FREE(libvirt_cacert); + VIR_FREE(libvirt_clientkey); + VIR_FREE(libvirt_clientcert); + + libvirt_cacert = strdup(LIBVIRT_CACERT); + if (!libvirt_cacert) goto out_of_memory; + + libvirt_clientkey = strdup(LIBVIRT_CLIENTKEY); + if (!libvirt_clientkey) goto out_of_memory; + + libvirt_clientcert = strdup(LIBVIRT_CLIENTCERT); + if (!libvirt_clientcert) goto out_of_memory; + } + } else { + libvirt_cacert = strdup(LIBVIRT_CACERT); + if (!libvirt_cacert) goto out_of_memory; + + libvirt_clientkey = strdup(LIBVIRT_CLIENTKEY); + if (!libvirt_clientkey) goto out_of_memory; + + libvirt_clientcert = strdup(LIBVIRT_CLIENTCERT); + if (!libvirt_clientcert) goto out_of_memory; + } + + if (check_cert_file("CA certificate", libvirt_cacert) < 0) + goto error; + if (check_cert_file("client key", libvirt_clientkey) < 0) + goto error; + if (check_cert_file("client certificate", libvirt_clientcert) < 0) + goto error; /* Set the trusted CA cert. */ - DEBUG("loading CA file %s", LIBVIRT_CACERT); + DEBUG("loading CA file %s", libvirt_cacert); err = - gnutls_certificate_set_x509_trust_file (x509_cred, LIBVIRT_CACERT, + gnutls_certificate_set_x509_trust_file (x509_cred, libvirt_cacert, GNUTLS_X509_FMT_PEM); if (err < 0) { remoteError(VIR_ERR_GNUTLS_ERROR, _("unable to load CA certificate: %s"), gnutls_strerror (err)); - return -1; + goto error; } /* Set the client certificate and private key. */ DEBUG("loading client cert and key from files %s and %s", - LIBVIRT_CLIENTCERT, LIBVIRT_CLIENTKEY); + libvirt_clientcert, libvirt_clientkey); err = gnutls_certificate_set_x509_key_file (x509_cred, - LIBVIRT_CLIENTCERT, - LIBVIRT_CLIENTKEY, + libvirt_clientcert, + libvirt_clientkey, GNUTLS_X509_FMT_PEM); if (err < 0) { remoteError(VIR_ERR_GNUTLS_ERROR, _("unable to load private key/certificate: %s"), gnutls_strerror (err)); - return -1; + goto error; } initialized = 1; - return 0; + ret = 0; + +cleanup: + VIR_FREE(libvirt_cacert); + VIR_FREE(libvirt_clientkey); + VIR_FREE(libvirt_clientcert); + VIR_FREE(userdir); + VIR_FREE(user_pki_path); + return ret; + +error: + ret = -1; + goto cleanup; + +out_of_memory: + ret = -1; + virReportOOMError(); + goto cleanup; } static int verify_certificate (virConnectPtr conn, struct private_data *priv, gnutls_session_t session);