virnetsshsession: Pass in username via virNetSSHSessionNew rather than auth functions

We only ever allow one username so there's no point passing it to each
authentication registration function. Additionally the only caller
(virNetClientNewLibSSH2) always passes a username so all the checks were
pointless.

Signed-off-by: Peter Krempa <pkrempa@redhat.com>
Reviewed-by: Jonathon Jongsma <jjongsma@redhat.com>
This commit is contained in:
Peter Krempa 2022-12-08 15:18:54 +01:00
parent 513d84daf6
commit 7fb0c7418e
3 changed files with 29 additions and 79 deletions

View File

@ -909,7 +909,7 @@ virNetSocketNewConnectLibSSH2(const char *host,
} }
/* create ssh session context */ /* create ssh session context */
if (!(sess = virNetSSHSessionNew())) if (!(sess = virNetSSHSessionNew(username)))
goto error; goto error;
/* set ssh session parameters */ /* set ssh session parameters */
@ -946,17 +946,13 @@ virNetSocketNewConnectLibSSH2(const char *host,
const char *authMethod = *authMethodNext; const char *authMethod = *authMethodNext;
if (STRCASEEQ(authMethod, "keyboard-interactive")) { if (STRCASEEQ(authMethod, "keyboard-interactive")) {
ret = virNetSSHSessionAuthAddKeyboardAuth(sess, username, -1); ret = virNetSSHSessionAuthAddKeyboardAuth(sess, -1);
} else if (STRCASEEQ(authMethod, "password")) { } else if (STRCASEEQ(authMethod, "password")) {
ret = virNetSSHSessionAuthAddPasswordAuth(sess, ret = virNetSSHSessionAuthAddPasswordAuth(sess, uri);
uri,
username);
} else if (STRCASEEQ(authMethod, "privkey")) { } else if (STRCASEEQ(authMethod, "privkey")) {
ret = virNetSSHSessionAuthAddPrivKeyAuth(sess, ret = virNetSSHSessionAuthAddPrivKeyAuth(sess, privkey);
username,
privkey);
} else if (STRCASEEQ(authMethod, "agent")) { } else if (STRCASEEQ(authMethod, "agent")) {
ret = virNetSSHSessionAuthAddAgentAuth(sess, username); ret = virNetSSHSessionAuthAddAgentAuth(sess);
} else { } else {
virReportError(VIR_ERR_INVALID_ARG, virReportError(VIR_ERR_INVALID_ARG,
_("Invalid authentication method: '%s'"), _("Invalid authentication method: '%s'"),

View File

@ -70,7 +70,6 @@ typedef struct _virNetSSHAuthMethod virNetSSHAuthMethod;
struct _virNetSSHAuthMethod { struct _virNetSSHAuthMethod {
virNetSSHAuthMethods method; virNetSSHAuthMethods method;
char *username;
char *filename; char *filename;
int tries; int tries;
@ -93,6 +92,7 @@ struct _virNetSSHSession {
int port; int port;
/* authentication stuff */ /* authentication stuff */
char *username;
virConnectAuthPtr cred; virConnectAuthPtr cred;
char *authPath; char *authPath;
virNetSSHAuthCallbackError authCbErr; virNetSSHAuthCallbackError authCbErr;
@ -115,7 +115,6 @@ virNetSSHSessionAuthMethodsClear(virNetSSHSession *sess)
size_t i; size_t i;
for (i = 0; i < sess->nauths; i++) { for (i = 0; i < sess->nauths; i++) {
VIR_FREE(sess->auths[i]->username);
VIR_FREE(sess->auths[i]->filename); VIR_FREE(sess->auths[i]->filename);
VIR_FREE(sess->auths[i]); VIR_FREE(sess->auths[i]);
} }
@ -151,6 +150,7 @@ virNetSSHSessionDispose(void *obj)
g_free(sess->hostname); g_free(sess->hostname);
g_free(sess->knownHostsFile); g_free(sess->knownHostsFile);
g_free(sess->authPath); g_free(sess->authPath);
g_free(sess->username);
} }
static virClass *virNetSSHSessionClass; static virClass *virNetSSHSessionClass;
@ -488,8 +488,7 @@ virNetSSHCheckHostKey(virNetSSHSession *sess)
* -1 on error * -1 on error
*/ */
static int static int
virNetSSHAuthenticateAgent(virNetSSHSession *sess, virNetSSHAuthenticateAgent(virNetSSHSession *sess)
virNetSSHAuthMethod *priv)
{ {
struct libssh2_agent_publickey *agent_identity = NULL; struct libssh2_agent_publickey *agent_identity = NULL;
bool no_identity = true; bool no_identity = true;
@ -515,7 +514,7 @@ virNetSSHAuthenticateAgent(virNetSSHSession *sess,
agent_identity))) { agent_identity))) {
no_identity = false; no_identity = false;
if (!(ret = libssh2_agent_userauth(sess->agent, if (!(ret = libssh2_agent_userauth(sess->agent,
priv->username, sess->username,
agent_identity))) agent_identity)))
return 0; /* key accepted */ return 0; /* key accepted */
@ -575,7 +574,7 @@ virNetSSHAuthenticatePrivkey(virNetSSHSession *sess,
/* try open the key with no password */ /* try open the key with no password */
if ((ret = libssh2_userauth_publickey_fromfile(sess->session, if ((ret = libssh2_userauth_publickey_fromfile(sess->session,
priv->username, sess->username,
NULL, NULL,
priv->filename, priv->filename,
NULL)) == 0) NULL)) == 0)
@ -634,7 +633,7 @@ virNetSSHAuthenticatePrivkey(virNetSSHSession *sess,
VIR_FREE(tmp); VIR_FREE(tmp);
ret = libssh2_userauth_publickey_fromfile(sess->session, ret = libssh2_userauth_publickey_fromfile(sess->session,
priv->username, sess->username,
NULL, NULL,
priv->filename, priv->filename,
retr_passphrase.result); retr_passphrase.result);
@ -668,8 +667,7 @@ virNetSSHAuthenticatePrivkey(virNetSSHSession *sess,
* -1 on error * -1 on error
*/ */
static int static int
virNetSSHAuthenticatePassword(virNetSSHSession *sess, virNetSSHAuthenticatePassword(virNetSSHSession *sess)
virNetSSHAuthMethod *priv)
{ {
char *password = NULL; char *password = NULL;
char *errmsg; char *errmsg;
@ -690,13 +688,13 @@ virNetSSHAuthenticatePassword(virNetSSHSession *sess,
* connection if maximum number of bad auth tries is exceeded */ * connection if maximum number of bad auth tries is exceeded */
while (true) { while (true) {
if (!(password = virAuthGetPasswordPath(sess->authPath, sess->cred, if (!(password = virAuthGetPasswordPath(sess->authPath, sess->cred,
"ssh", priv->username, "ssh", sess->username,
sess->hostname))) sess->hostname)))
goto cleanup; goto cleanup;
/* tunnelled password authentication */ /* tunnelled password authentication */
if ((rc = libssh2_userauth_password(sess->session, if ((rc = libssh2_userauth_password(sess->session,
priv->username, sess->username,
password)) == 0) { password)) == 0) {
ret = 0; ret = 0;
goto cleanup; goto cleanup;
@ -751,7 +749,7 @@ virNetSSHAuthenticateKeyboardInteractive(virNetSSHSession *sess,
* connection if maximum number of bad auth tries is exceeded */ * connection if maximum number of bad auth tries is exceeded */
while (priv->tries < 0 || priv->tries-- > 0) { while (priv->tries < 0 || priv->tries-- > 0) {
ret = libssh2_userauth_keyboard_interactive(sess->session, ret = libssh2_userauth_keyboard_interactive(sess->session,
priv->username, sess->username,
virNetSSHKbIntCb); virNetSSHKbIntCb);
/* check for errors while calling the callback */ /* check for errors while calling the callback */
@ -817,9 +815,8 @@ virNetSSHAuthenticate(virNetSSHSession *sess)
} }
/* obtain list of supported auth methods */ /* obtain list of supported auth methods */
auth_list = libssh2_userauth_list(sess->session, auth_list = libssh2_userauth_list(sess->session, sess->username,
sess->auths[0]->username, strlen(sess->username));
strlen(sess->auths[0]->username));
if (!auth_list) { if (!auth_list) {
/* unlikely event, authentication succeeded with NONE as method */ /* unlikely event, authentication succeeded with NONE as method */
if (libssh2_userauth_authenticated(sess->session) == 1) if (libssh2_userauth_authenticated(sess->session) == 1)
@ -845,7 +842,7 @@ virNetSSHAuthenticate(virNetSSHSession *sess)
break; break;
case VIR_NET_SSH_AUTH_AGENT: case VIR_NET_SSH_AUTH_AGENT:
if (strstr(auth_list, "publickey")) if (strstr(auth_list, "publickey"))
ret = virNetSSHAuthenticateAgent(sess, auth); ret = virNetSSHAuthenticateAgent(sess);
break; break;
case VIR_NET_SSH_AUTH_PRIVKEY: case VIR_NET_SSH_AUTH_PRIVKEY:
if (strstr(auth_list, "publickey")) if (strstr(auth_list, "publickey"))
@ -853,7 +850,7 @@ virNetSSHAuthenticate(virNetSSHSession *sess)
break; break;
case VIR_NET_SSH_AUTH_PASSWORD: case VIR_NET_SSH_AUTH_PASSWORD:
if (strstr(auth_list, "password")) if (strstr(auth_list, "password"))
ret = virNetSSHAuthenticatePassword(sess, auth); ret = virNetSSHAuthenticatePassword(sess);
break; break;
} }
@ -969,11 +966,9 @@ virNetSSHSessionAuthReset(virNetSSHSession *sess)
int int
virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess, virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess,
virURI *uri, virURI *uri)
const char *username)
{ {
virNetSSHAuthMethod *auth; virNetSSHAuthMethod *auth;
char *user = NULL;
if (uri) { if (uri) {
VIR_FREE(sess->authPath); VIR_FREE(sess->authPath);
@ -982,75 +977,50 @@ virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess,
goto error; goto error;
} }
if (!username) {
if (!(user = virAuthGetUsernamePath(sess->authPath, sess->cred,
"ssh", NULL, sess->hostname)))
goto error;
} else {
user = g_strdup(username);
}
virObjectLock(sess); virObjectLock(sess);
if (!(auth = virNetSSHSessionAuthMethodNew(sess))) if (!(auth = virNetSSHSessionAuthMethodNew(sess)))
goto error; goto error;
auth->username = user;
auth->method = VIR_NET_SSH_AUTH_PASSWORD; auth->method = VIR_NET_SSH_AUTH_PASSWORD;
virObjectUnlock(sess); virObjectUnlock(sess);
return 0; return 0;
error: error:
VIR_FREE(user);
virObjectUnlock(sess); virObjectUnlock(sess);
return -1; return -1;
} }
int int
virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess, virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess)
const char *username)
{ {
virNetSSHAuthMethod *auth; virNetSSHAuthMethod *auth;
char *user = NULL;
if (!username) {
virReportError(VIR_ERR_SSH, "%s",
_("Username must be provided "
"for ssh agent authentication"));
return -1;
}
virObjectLock(sess); virObjectLock(sess);
user = g_strdup(username);
if (!(auth = virNetSSHSessionAuthMethodNew(sess))) if (!(auth = virNetSSHSessionAuthMethodNew(sess)))
goto error; goto error;
auth->username = user;
auth->method = VIR_NET_SSH_AUTH_AGENT; auth->method = VIR_NET_SSH_AUTH_AGENT;
virObjectUnlock(sess); virObjectUnlock(sess);
return 0; return 0;
error: error:
VIR_FREE(user);
virObjectUnlock(sess); virObjectUnlock(sess);
return -1; return -1;
} }
int int
virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess, virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,
const char *username,
const char *keyfile) const char *keyfile)
{ {
virNetSSHAuthMethod *auth; virNetSSHAuthMethod *auth;
if (!username || !keyfile) { if (!keyfile) {
virReportError(VIR_ERR_SSH, "%s", virReportError(VIR_ERR_SSH, "%s",
_("Username and key file path must be provided " _("Key file path must be provided for private key authentication"));
"for private key authentication"));
return -1; return -1;
} }
@ -1061,7 +1031,6 @@ virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,
return -1; return -1;
} }
auth->username = g_strdup(username);
auth->filename = g_strdup(keyfile); auth->filename = g_strdup(keyfile);
auth->method = VIR_NET_SSH_AUTH_PRIVKEY; auth->method = VIR_NET_SSH_AUTH_PRIVKEY;
@ -1071,27 +1040,15 @@ virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,
int int
virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess, virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess,
const char *username,
int tries) int tries)
{ {
virNetSSHAuthMethod *auth; virNetSSHAuthMethod *auth;
char *user = NULL;
if (!username) {
virReportError(VIR_ERR_SSH, "%s",
_("Username must be provided "
"for ssh agent authentication"));
return -1;
}
virObjectLock(sess); virObjectLock(sess);
user = g_strdup(username);
if (!(auth = virNetSSHSessionAuthMethodNew(sess))) if (!(auth = virNetSSHSessionAuthMethodNew(sess)))
goto error; goto error;
auth->username = user;
auth->tries = tries; auth->tries = tries;
auth->method = VIR_NET_SSH_AUTH_KEYBOARD_INTERACTIVE; auth->method = VIR_NET_SSH_AUTH_KEYBOARD_INTERACTIVE;
@ -1099,7 +1056,6 @@ virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess,
return 0; return 0;
error: error:
VIR_FREE(user);
virObjectUnlock(sess); virObjectUnlock(sess);
return -1; return -1;
@ -1172,7 +1128,7 @@ virNetSSHSessionSetHostKeyVerification(virNetSSHSession *sess,
} }
/* allocate and initialize a ssh session object */ /* allocate and initialize a ssh session object */
virNetSSHSession *virNetSSHSessionNew(void) virNetSSHSession *virNetSSHSessionNew(const char *username)
{ {
virNetSSHSession *sess = NULL; virNetSSHSession *sess = NULL;
@ -1182,6 +1138,8 @@ virNetSSHSession *virNetSSHSessionNew(void)
if (!(sess = virObjectLockableNew(virNetSSHSessionClass))) if (!(sess = virObjectLockableNew(virNetSSHSessionClass)))
goto error; goto error;
sess->username = g_strdup(username);
/* initialize session data, use the internal data for callbacks /* initialize session data, use the internal data for callbacks
* and stick to default memory management functions */ * and stick to default memory management functions */
if (!(sess->session = libssh2_session_init_ex(NULL, if (!(sess->session = libssh2_session_init_ex(NULL,

View File

@ -25,7 +25,7 @@
typedef struct _virNetSSHSession virNetSSHSession; typedef struct _virNetSSHSession virNetSSHSession;
virNetSSHSession *virNetSSHSessionNew(void); virNetSSHSession *virNetSSHSessionNew(const char *username);
void virNetSSHSessionFree(virNetSSHSession *sess); void virNetSSHSessionFree(virNetSSHSession *sess);
typedef enum { typedef enum {
@ -48,18 +48,14 @@ int virNetSSHSessionAuthSetCallback(virNetSSHSession *sess,
virConnectAuthPtr auth); virConnectAuthPtr auth);
int virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess, int virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess,
virURI *uri, virURI *uri);
const char *username);
int virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess, int virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess);
const char *username);
int virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess, int virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,
const char *username,
const char *keyfile); const char *keyfile);
int virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess, int virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess,
const char *username,
int tries); int tries);
int virNetSSHSessionSetHostKeyVerification(virNetSSHSession *sess, int virNetSSHSessionSetHostKeyVerification(virNetSSHSession *sess,