util: refactor virNetlinkCommand to fix several bugs / style problems

Inspired by a simpler patch from "Wangrui (K) <moon.wangrui@huawei.com>".

A submitted patch pointed out that virNetlinkCommand() was doing an
improper typecast of the return value from nl_recv() (int to
unsigned), causing it to miss error returns, and that even after
remedying that problem, virNetlinkCommand() was calling VIR_FREE() on
the pointer returned from nl_recv() (*resp) even if nl_recv() had
returned an error, and that in this case the pointer was verifiably
invalid, as it was pointing to memory that had been allocated by
libnl, but then freed prior to returning the error.

While reviewing this patch, I noticed several other problems with this
seemingly simple function (at least one of them as serious as the
problem being reported/fixed by the aforementioned patch), and decided
they all deserved to be fixed. Here is the list:

1) The return value from nl_recv() must be assigned to an int (rather
   than unsigned int) in order to detect failure.

2) When nl_recv() returns an error or 0, the contents of *resp is
   invalid, and should be simply set to 0, *not* VIR_FREE()'d.

3) When nl_recv() returns 0, errno is not set, so the logged error
   message should not reference errno (it *is* an error though).

4) The first error return from virNetlinkCommand returns -EINVAL,
   incorrectly implying that the caller can expect the return value to
   be of the "-errno" variety, which is not true in any other case.

5) The 2nd error return returns directly with garbage in *resp. While
   the caller should never use *resp in this case, it's still good
   practice to set it to NULL.

6) For the next 5 (!!) error conditions, *resp will contain garbage,
   and virNetlinkCommand() will goto it's cleanup code which will
   VIR_FREE(*resp), almost surely leading to a segfault.

In addition to fixing these 6 problems, this patch also makes the
following two changes to make the function conform more closely to the
style of other libvirt code:

1) Change the handling of return code from "named rc and defaulted to
0, but changed to -1 on error" to the more common "named ret and
defaulted to -1, but changed to 0 on success".

2) Rename the "error" label to "cleanup", since the code that follows
is executed in success cases as well as failure.
This commit is contained in:
Laine Stump 2014-05-13 14:34:43 +03:00
parent 88b5acb67f
commit 5d85b8a8f4

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2010-2013 Red Hat, Inc. * Copyright (C) 2010-2014 Red Hat, Inc.
* Copyright (C) 2010-2012 IBM Corporation * Copyright (C) 2010-2012 IBM Corporation
* *
* This library is free software; you can redistribute it and/or * This library is free software; you can redistribute it and/or
@ -180,7 +180,7 @@ int virNetlinkCommand(struct nl_msg *nl_msg,
uint32_t src_pid, uint32_t dst_pid, uint32_t src_pid, uint32_t dst_pid,
unsigned int protocol, unsigned int groups) unsigned int protocol, unsigned int groups)
{ {
int rc = 0; int ret = -1;
struct sockaddr_nl nladdr = { struct sockaddr_nl nladdr = {
.nl_family = AF_NETLINK, .nl_family = AF_NETLINK,
.nl_pid = dst_pid, .nl_pid = dst_pid,
@ -192,41 +192,39 @@ int virNetlinkCommand(struct nl_msg *nl_msg,
int n; int n;
struct nlmsghdr *nlmsg = nlmsg_hdr(nl_msg); struct nlmsghdr *nlmsg = nlmsg_hdr(nl_msg);
virNetlinkHandle *nlhandle = NULL; virNetlinkHandle *nlhandle = NULL;
int len = 0;
if (protocol >= MAX_LINKS) { if (protocol >= MAX_LINKS) {
virReportSystemError(EINVAL, virReportSystemError(EINVAL,
_("invalid protocol argument: %d"), protocol); _("invalid protocol argument: %d"), protocol);
return -EINVAL; goto cleanup;
} }
nlhandle = virNetlinkAlloc(); nlhandle = virNetlinkAlloc();
if (!nlhandle) { if (!nlhandle) {
virReportSystemError(errno, virReportSystemError(errno,
"%s", _("cannot allocate nlhandle for netlink")); "%s", _("cannot allocate nlhandle for netlink"));
return -1; goto cleanup;
} }
if (nl_connect(nlhandle, protocol) < 0) { if (nl_connect(nlhandle, protocol) < 0) {
virReportSystemError(errno, virReportSystemError(errno,
_("cannot connect to netlink socket with protocol %d"), _("cannot connect to netlink socket with protocol %d"),
protocol); protocol);
rc = -1; goto cleanup;
goto error;
} }
fd = nl_socket_get_fd(nlhandle); fd = nl_socket_get_fd(nlhandle);
if (fd < 0) { if (fd < 0) {
virReportSystemError(errno, virReportSystemError(errno,
"%s", _("cannot get netlink socket fd")); "%s", _("cannot get netlink socket fd"));
rc = -1; goto cleanup;
goto error;
} }
if (groups && nl_socket_add_membership(nlhandle, groups) < 0) { if (groups && nl_socket_add_membership(nlhandle, groups) < 0) {
virReportSystemError(errno, virReportSystemError(errno,
"%s", _("cannot add netlink membership")); "%s", _("cannot add netlink membership"));
rc = -1; goto cleanup;
goto error;
} }
nlmsg_set_dst(nl_msg, &nladdr); nlmsg_set_dst(nl_msg, &nladdr);
@ -237,8 +235,7 @@ int virNetlinkCommand(struct nl_msg *nl_msg,
if (nbytes < 0) { if (nbytes < 0) {
virReportSystemError(errno, virReportSystemError(errno,
"%s", _("cannot send to netlink socket")); "%s", _("cannot send to netlink socket"));
rc = -1; goto cleanup;
goto error;
} }
memset(fds, 0, sizeof(fds)); memset(fds, 0, sizeof(fds));
@ -253,26 +250,30 @@ int virNetlinkCommand(struct nl_msg *nl_msg,
if (n == 0) if (n == 0)
virReportSystemError(ETIMEDOUT, "%s", virReportSystemError(ETIMEDOUT, "%s",
_("no valid netlink response was received")); _("no valid netlink response was received"));
rc = -1; goto cleanup;
goto error;
} }
*respbuflen = nl_recv(nlhandle, &nladdr, len = nl_recv(nlhandle, &nladdr, (unsigned char **)resp, NULL);
(unsigned char **)resp, NULL); if (len == 0) {
if (*respbuflen <= 0) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
virReportSystemError(errno, _("nl_recv failed - returned 0 bytes"));
"%s", _("nl_recv failed")); goto cleanup;
rc = -1;
} }
error: if (len < 0) {
if (rc == -1) { virReportSystemError(errno, "%s", _("nl_recv failed"));
VIR_FREE(*resp); goto cleanup;
}
ret = 0;
*respbuflen = len;
cleanup:
if (ret < 0) {
*resp = NULL; *resp = NULL;
*respbuflen = 0; *respbuflen = 0;
} }
virNetlinkFree(nlhandle); virNetlinkFree(nlhandle);
return rc; return ret;
} }
static void static void