/*
* Copyright (C) 2011, 2014 Red Hat, Inc.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library. If not, see
* .
*/
#include
#include
#ifdef HAVE_IFADDRS_H
# include
#endif
#include
#include "testutils.h"
#include "virutil.h"
#include "virerror.h"
#include "viralloc.h"
#include "virlog.h"
#include "virfile.h"
#include "virstring.h"
#include "rpc/virnetsocket.h"
#define VIR_FROM_THIS VIR_FROM_RPC
VIR_LOG_INIT("tests.netsockettest");
#if HAVE_IFADDRS_H
# define BASE_PORT 5672
static int
checkProtocols(bool *hasIPv4, bool *hasIPv6,
int *freePort)
{
struct sockaddr_in in4;
struct sockaddr_in6 in6;
int s4 = -1, s6 = -1;
size_t i;
int ret = -1;
*freePort = 0;
if (virNetSocketCheckProtocols(hasIPv4, hasIPv6) < 0)
return -1;
for (i = 0; i < 50; i++) {
int only = 1;
if ((s4 = socket(AF_INET, SOCK_STREAM, 0)) < 0)
goto cleanup;
if (*hasIPv6) {
if ((s6 = socket(AF_INET6, SOCK_STREAM, 0)) < 0)
goto cleanup;
if (setsockopt(s6, IPPROTO_IPV6, IPV6_V6ONLY, &only, sizeof(only)) < 0)
goto cleanup;
}
memset(&in4, 0, sizeof(in4));
memset(&in6, 0, sizeof(in6));
in4.sin_family = AF_INET;
in4.sin_port = htons(BASE_PORT + i);
in4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
in6.sin6_family = AF_INET6;
in6.sin6_port = htons(BASE_PORT + i);
in6.sin6_addr = in6addr_loopback;
if (bind(s4, (struct sockaddr *)&in4, sizeof(in4)) < 0) {
if (errno == EADDRINUSE) {
VIR_FORCE_CLOSE(s4);
VIR_FORCE_CLOSE(s6);
continue;
}
goto cleanup;
}
if (*hasIPv6) {
if (bind(s6, (struct sockaddr *)&in6, sizeof(in6)) < 0) {
if (errno == EADDRINUSE) {
VIR_FORCE_CLOSE(s4);
VIR_FORCE_CLOSE(s6);
continue;
}
goto cleanup;
}
}
*freePort = BASE_PORT + i;
break;
}
VIR_DEBUG("Choose port %d", *freePort);
ret = 0;
cleanup:
VIR_FORCE_CLOSE(s4);
VIR_FORCE_CLOSE(s6);
return ret;
}
struct testClientData {
const char *path;
const char *cnode;
const char *portstr;
};
static void
testSocketClient(void *opaque)
{
struct testClientData *data = opaque;
char c;
virNetSocketPtr csock = NULL;
if (data->path) {
if (virNetSocketNewConnectUNIX(data->path, false,
NULL, &csock) < 0)
return;
} else {
if (virNetSocketNewConnectTCP(data->cnode, data->portstr,
AF_UNSPEC,
&csock) < 0)
return;
}
virNetSocketSetBlocking(csock, true);
if (virNetSocketRead(csock, &c, 1) != 1) {
VIR_DEBUG("Cannot read from server");
goto done;
}
if (virNetSocketWrite(csock, &c, 1) != 1) {
VIR_DEBUG("Cannot write to server");
goto done;
}
done:
virObjectUnref(csock);
}
static void
testSocketIncoming(virNetSocketPtr sock,
int events ATTRIBUTE_UNUSED,
void *opaque)
{
virNetSocketPtr *retsock = opaque;
VIR_DEBUG("Incoming sock=%p events=%d", sock, events);
*retsock = sock;
}
struct testSocketData {
const char *lnode;
int port;
const char *cnode;
};
static int
testSocketAccept(const void *opaque)
{
virNetSocketPtr *lsock = NULL; /* Listen socket */
size_t nlsock = 0, i;
virNetSocketPtr ssock = NULL; /* Server socket */
virNetSocketPtr rsock = NULL; /* Incoming client socket */
const struct testSocketData *data = opaque;
int ret = -1;
char portstr[100];
char *tmpdir = NULL;
char *path = NULL;
char template[] = "/tmp/libvirt_XXXXXX";
virThread th;
struct testClientData cdata = { 0 };
bool goodsock = false;
char a = 'a';
char b = '\0';
if (!data) {
virNetSocketPtr usock;
tmpdir = mkdtemp(template);
if (tmpdir == NULL) {
VIR_WARN("Failed to create temporary directory");
goto cleanup;
}
if (virAsprintf(&path, "%s/test.sock", tmpdir) < 0)
goto cleanup;
if (virNetSocketNewListenUNIX(path, 0700, -1, getegid(), &usock) < 0)
goto cleanup;
if (VIR_ALLOC_N(lsock, 1) < 0) {
virObjectUnref(usock);
goto cleanup;
}
lsock[0] = usock;
nlsock = 1;
cdata.path = path;
} else {
snprintf(portstr, sizeof(portstr), "%d", data->port);
if (virNetSocketNewListenTCP(data->lnode, portstr,
AF_UNSPEC,
&lsock, &nlsock) < 0)
goto cleanup;
cdata.cnode = data->cnode;
cdata.portstr = portstr;
}
for (i = 0; i < nlsock; i++) {
if (virNetSocketListen(lsock[i], 0) < 0)
goto cleanup;
if (virNetSocketAddIOCallback(lsock[i],
VIR_EVENT_HANDLE_READABLE,
testSocketIncoming,
&rsock,
NULL) < 0) {
goto cleanup;
}
}
if (virThreadCreate(&th, true,
testSocketClient,
&cdata) < 0)
goto cleanup;
while (rsock == NULL) {
if (virEventRunDefaultImpl() < 0)
break;
}
for (i = 0; i < nlsock; i++) {
if (lsock[i] == rsock) {
goodsock = true;
break;
}
}
if (!goodsock) {
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
"Unexpected server socket seen");
goto join;
}
if (virNetSocketAccept(rsock, &ssock) < 0)
goto join;
if (!ssock) {
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
"Client went away unexpectedly");
goto join;
}
virNetSocketSetBlocking(ssock, true);
if (virNetSocketWrite(ssock, &a, 1) < 0 ||
virNetSocketRead(ssock, &b, 1) < 0) {
goto join;
}
if (a != b) {
virReportError(VIR_ERR_INTERNAL_ERROR,
"Bad data received '%x' != '%x'", a, b);
goto join;
}
virObjectUnref(ssock);
ssock = NULL;
ret = 0;
join:
virThreadJoin(&th);
cleanup:
virObjectUnref(ssock);
for (i = 0; i < nlsock; i++) {
virNetSocketRemoveIOCallback(lsock[i]);
virNetSocketClose(lsock[i]);
virObjectUnref(lsock[i]);
}
VIR_FREE(lsock);
VIR_FREE(path);
if (tmpdir)
rmdir(tmpdir);
return ret;
}
#endif
#ifndef WIN32
static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED)
{
virNetSocketPtr lsock = NULL; /* Listen socket */
virNetSocketPtr ssock = NULL; /* Server socket */
virNetSocketPtr csock = NULL; /* Client socket */
int ret = -1;
char *path = NULL;
char *tmpdir;
char template[] = "/tmp/libvirt_XXXXXX";
tmpdir = mkdtemp(template);
if (tmpdir == NULL) {
VIR_WARN("Failed to create temporary directory");
goto cleanup;
}
if (virAsprintf(&path, "%s/test.sock", tmpdir) < 0)
goto cleanup;
if (virNetSocketNewListenUNIX(path, 0700, -1, getegid(), &lsock) < 0)
goto cleanup;
if (STRNEQ(virNetSocketLocalAddrStringSASL(lsock), "127.0.0.1;0")) {
VIR_DEBUG("Unexpected local address");
goto cleanup;
}
if (virNetSocketRemoteAddrStringSASL(lsock) != NULL) {
VIR_DEBUG("Unexpected remote address");
goto cleanup;
}
if (virNetSocketListen(lsock, 0) < 0)
goto cleanup;
if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
goto cleanup;
if (STRNEQ(virNetSocketLocalAddrStringSASL(csock), "127.0.0.1;0")) {
VIR_DEBUG("Unexpected local address");
goto cleanup;
}
if (STRNEQ(virNetSocketRemoteAddrStringSASL(csock), "127.0.0.1;0")) {
VIR_DEBUG("Unexpected remote address");
goto cleanup;
}
if (STRNEQ(virNetSocketRemoteAddrStringURI(csock), "127.0.0.1:0")) {
VIR_DEBUG("Unexpected remote address");
goto cleanup;
}
if (virNetSocketAccept(lsock, &ssock) < 0) {
VIR_DEBUG("Unexpected client socket missing");
goto cleanup;
}
if (STRNEQ(virNetSocketLocalAddrStringSASL(ssock), "127.0.0.1;0")) {
VIR_DEBUG("Unexpected local address");
goto cleanup;
}
if (STRNEQ(virNetSocketRemoteAddrStringSASL(ssock), "127.0.0.1;0")) {
VIR_DEBUG("Unexpected remote address");
goto cleanup;
}
if (STRNEQ(virNetSocketRemoteAddrStringURI(ssock), "127.0.0.1:0")) {
VIR_DEBUG("Unexpected remote address");
goto cleanup;
}
ret = 0;
cleanup:
VIR_FREE(path);
virObjectUnref(lsock);
virObjectUnref(ssock);
virObjectUnref(csock);
if (tmpdir)
rmdir(tmpdir);
return ret;
}
static int testSocketCommandNormal(const void *data ATTRIBUTE_UNUSED)
{
virNetSocketPtr csock = NULL; /* Client socket */
char buf[100];
size_t i;
int ret = -1;
virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/zero", NULL);
virCommandAddEnvPassCommon(cmd);
if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
goto cleanup;
virNetSocketSetBlocking(csock, true);
if (virNetSocketRead(csock, buf, sizeof(buf)) < 0)
goto cleanup;
for (i = 0; i < sizeof(buf); i++)
if (buf[i] != '\0')
goto cleanup;
ret = 0;
cleanup:
virObjectUnref(csock);
return ret;
}
static int testSocketCommandFail(const void *data ATTRIBUTE_UNUSED)
{
virNetSocketPtr csock = NULL; /* Client socket */
char buf[100];
int ret = -1;
virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/does-not-exist", NULL);
virCommandAddEnvPassCommon(cmd);
if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
goto cleanup;
virNetSocketSetBlocking(csock, true);
if (virNetSocketRead(csock, buf, sizeof(buf)) == 0)
goto cleanup;
ret = 0;
cleanup:
virObjectUnref(csock);
return ret;
}
struct testSSHData {
const char *nodename;
const char *service;
const char *binary;
const char *username;
bool noTTY;
bool noVerify;
const char *netcat;
const char *keyfile;
const char *path;
const char *expectOut;
bool failConnect;
bool dieEarly;
};
static int testSocketSSH(const void *opaque)
{
const struct testSSHData *data = opaque;
virNetSocketPtr csock = NULL; /* Client socket */
int ret = -1;
char buf[1024];
if (virNetSocketNewConnectSSH(data->nodename,
data->service,
data->binary,
data->username,
data->noTTY,
data->noVerify,
data->netcat,
data->keyfile,
data->path,
&csock) < 0)
goto cleanup;
virNetSocketSetBlocking(csock, true);
if (data->failConnect) {
if (virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
VIR_DEBUG("Expected connect failure, but got some socket data");
goto cleanup;
}
} else {
ssize_t rv;
if ((rv = virNetSocketRead(csock, buf, sizeof(buf)-1)) < 0) {
VIR_DEBUG("Didn't get any socket data");
goto cleanup;
}
buf[rv] = '\0';
if (STRNEQ(buf, data->expectOut)) {
virTestDifference(stderr, data->expectOut, buf);
goto cleanup;
}
if (data->dieEarly &&
virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
VIR_DEBUG("Got too much socket data");
goto cleanup;
}
}
ret = 0;
cleanup:
virObjectUnref(csock);
return ret;
}
#endif
static int
mymain(void)
{
int ret = 0;
#ifdef HAVE_IFADDRS_H
bool hasIPv4, hasIPv6;
int freePort;
#endif
signal(SIGPIPE, SIG_IGN);
virEventRegisterDefaultImpl();
#ifdef HAVE_IFADDRS_H
if (checkProtocols(&hasIPv4, &hasIPv6, &freePort) < 0) {
fprintf(stderr, "Cannot identify IPv4/6 availability\n");
return EXIT_FAILURE;
}
if (hasIPv4) {
struct testSocketData tcpData = { "127.0.0.1", freePort, "127.0.0.1" };
if (virTestRun("Socket TCP/IPv4 Accept", testSocketAccept, &tcpData) < 0)
ret = -1;
}
if (hasIPv6) {
struct testSocketData tcpData = { "::1", freePort, "::1" };
if (virTestRun("Socket TCP/IPv6 Accept", testSocketAccept, &tcpData) < 0)
ret = -1;
}
if (hasIPv6 && hasIPv4) {
struct testSocketData tcpData = { NULL, freePort, "127.0.0.1" };
if (virTestRun("Socket TCP/IPv4+IPv6 Accept", testSocketAccept, &tcpData) < 0)
ret = -1;
tcpData.cnode = "::1";
if (virTestRun("Socket TCP/IPv4+IPv6 Accept", testSocketAccept, &tcpData) < 0)
ret = -1;
}
#endif
#ifndef WIN32
if (virTestRun("Socket UNIX Accept", testSocketAccept, NULL) < 0)
ret = -1;
if (virTestRun("Socket UNIX Addrs", testSocketUNIXAddrs, NULL) < 0)
ret = -1;
if (virTestRun("Socket External Command /dev/zero", testSocketCommandNormal, NULL) < 0)
ret = -1;
if (virTestRun("Socket External Command /dev/does-not-exist", testSocketCommandFail, NULL) < 0)
ret = -1;
struct testSSHData sshData1 = {
.nodename = "somehost",
.path = "/tmp/socket",
.expectOut = "-T -e none -- somehost sh -c '"
"if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
"ARG=-q0;"
"else "
"ARG=;"
"fi;"
"'nc' $ARG -U /tmp/socket'\n",
};
if (virTestRun("SSH test 1", testSocketSSH, &sshData1) < 0)
ret = -1;
struct testSSHData sshData2 = {
.nodename = "somehost",
.service = "9000",
.username = "fred",
.netcat = "netcat",
.noTTY = true,
.noVerify = false,
.path = "/tmp/socket",
.expectOut = "-p 9000 -l fred -T -e none -o BatchMode=yes -- somehost sh -c '"
"if 'netcat' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
"ARG=-q0;"
"else "
"ARG=;"
"fi;"
"'netcat' $ARG -U /tmp/socket'\n",
};
if (virTestRun("SSH test 2", testSocketSSH, &sshData2) < 0)
ret = -1;
struct testSSHData sshData3 = {
.nodename = "somehost",
.service = "9000",
.username = "fred",
.netcat = "netcat",
.noTTY = false,
.noVerify = true,
.path = "/tmp/socket",
.expectOut = "-p 9000 -l fred -T -e none -o StrictHostKeyChecking=no -- somehost sh -c '"
"if 'netcat' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
"ARG=-q0;"
"else "
"ARG=;"
"fi;"
"'netcat' $ARG -U /tmp/socket'\n",
};
if (virTestRun("SSH test 3", testSocketSSH, &sshData3) < 0)
ret = -1;
struct testSSHData sshData4 = {
.nodename = "nosuchhost",
.path = "/tmp/socket",
.failConnect = true,
};
if (virTestRun("SSH test 4", testSocketSSH, &sshData4) < 0)
ret = -1;
struct testSSHData sshData5 = {
.nodename = "crashyhost",
.path = "/tmp/socket",
.expectOut = "-T -e none -- crashyhost sh -c "
"'if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
"ARG=-q0;"
"else "
"ARG=;"
"fi;"
"'nc' $ARG -U /tmp/socket'\n",
.dieEarly = true,
};
if (virTestRun("SSH test 5", testSocketSSH, &sshData5) < 0)
ret = -1;
struct testSSHData sshData6 = {
.nodename = "example.com",
.path = "/tmp/socket",
.keyfile = "/root/.ssh/example_key",
.noVerify = true,
.expectOut = "-i /root/.ssh/example_key -T -e none -o StrictHostKeyChecking=no -- example.com sh -c '"
"if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
"ARG=-q0;"
"else "
"ARG=;"
"fi;"
"'nc' $ARG -U /tmp/socket'\n",
};
if (virTestRun("SSH test 6", testSocketSSH, &sshData6) < 0)
ret = -1;
struct testSSHData sshData7 = {
.nodename = "somehost",
.netcat = "/tmp/fo o/nc",
.path = "/tmp/socket",
.expectOut = "-T -e none -- somehost sh -c '"
"if \'''\\''/tmp/fo o/nc'\\'''' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
"ARG=-q0;"
"else "
"ARG=;"
"fi;"
"'''\\''/tmp/fo o/nc'\\'''' $ARG -U /tmp/socket'\n",
};
if (virTestRun("SSH test 7", testSocketSSH, &sshData7) < 0)
ret = -1;
#endif
return ret == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
VIR_TEST_MAIN(mymain)