diff --git a/port_fwd.c b/port_fwd.c index f400766..40e633d 100644 --- a/port_fwd.c +++ b/port_fwd.c @@ -68,45 +68,55 @@ static void procfs_scan_listen(int fd, unsigned int lstate, } /** - * get_bound_ports() - Get maps of ports with bound sockets + * get_bound_ports_tcp() - Get maps of TCP ports with bound sockets * @c: Execution context * @ns: If set, set bitmaps for ports to tap/ns -- to init otherwise - * @proto: Protocol number (IPPROTO_TCP or IPPROTO_UDP) */ -void get_bound_ports(struct ctx *c, int ns, uint8_t proto) +void get_bound_ports_tcp(struct ctx *c, int ns) { - uint8_t *udp_map, *udp_excl, *tcp_map, *tcp_excl; + uint8_t *map, *excl; if (ns) { - udp_map = c->udp.fwd_in.f.map; - udp_excl = c->udp.fwd_out.f.map; - tcp_map = c->tcp.fwd_in.map; - tcp_excl = c->tcp.fwd_out.map; + map = c->tcp.fwd_in.map; + excl = c->tcp.fwd_out.map; } else { - udp_map = c->udp.fwd_out.f.map; - udp_excl = c->udp.fwd_in.f.map; - tcp_map = c->tcp.fwd_out.map; - tcp_excl = c->tcp.fwd_in.map; + map = c->tcp.fwd_out.map; + excl = c->tcp.fwd_in.map; } - if (proto == IPPROTO_UDP) { - memset(udp_map, 0, PORT_BITMAP_SIZE); - procfs_scan_listen(c->proc_net_udp[V4][ns], - UDP_LISTEN, udp_map, udp_excl); - procfs_scan_listen(c->proc_net_udp[V6][ns], - UDP_LISTEN, udp_map, udp_excl); + memset(map, 0, PORT_BITMAP_SIZE); + procfs_scan_listen(c->proc_net_tcp[V4][ns], TCP_LISTEN, map, excl); + procfs_scan_listen(c->proc_net_tcp[V6][ns], TCP_LISTEN, map, excl); +} - procfs_scan_listen(c->proc_net_tcp[V4][ns], - TCP_LISTEN, udp_map, udp_excl); - procfs_scan_listen(c->proc_net_tcp[V6][ns], - TCP_LISTEN, udp_map, udp_excl); - } else if (proto == IPPROTO_TCP) { - memset(tcp_map, 0, PORT_BITMAP_SIZE); - procfs_scan_listen(c->proc_net_tcp[V4][ns], - TCP_LISTEN, tcp_map, tcp_excl); - procfs_scan_listen(c->proc_net_tcp[V6][ns], - TCP_LISTEN, tcp_map, tcp_excl); +/** + * get_bound_ports_udp() - Get maps of UDP ports with bound sockets + * @c: Execution context + * @ns: If set, set bitmaps for ports to tap/ns -- to init otherwise + */ +void get_bound_ports_udp(struct ctx *c, int ns) +{ + uint8_t *map, *excl; + + if (ns) { + map = c->udp.fwd_in.f.map; + excl = c->udp.fwd_out.f.map; + } else { + map = c->udp.fwd_out.f.map; + excl = c->udp.fwd_in.f.map; } + + memset(map, 0, PORT_BITMAP_SIZE); + procfs_scan_listen(c->proc_net_udp[V4][ns], UDP_LISTEN, map, excl); + procfs_scan_listen(c->proc_net_udp[V6][ns], UDP_LISTEN, map, excl); + + /* Also forward UDP ports with the same numbers as bound TCP ports. + * This is useful for a handful of protocols (e.g. iperf3) where a TCP + * control port is used to set up transfers on a corresponding UDP + * port. + */ + procfs_scan_listen(c->proc_net_tcp[V4][ns], TCP_LISTEN, map, excl); + procfs_scan_listen(c->proc_net_tcp[V6][ns], TCP_LISTEN, map, excl); } /** @@ -125,21 +135,21 @@ void port_fwd_init(struct ctx *c) if (c->tcp.fwd_in.mode == FWD_AUTO) { c->proc_net_tcp[V4][1] = open_in_ns(c, "/proc/net/tcp", flags); c->proc_net_tcp[V6][1] = open_in_ns(c, "/proc/net/tcp6", flags); - get_bound_ports(c, 1, IPPROTO_TCP); + get_bound_ports_tcp(c, 1); } if (c->udp.fwd_in.f.mode == FWD_AUTO) { c->proc_net_udp[V4][1] = open_in_ns(c, "/proc/net/udp", flags); c->proc_net_udp[V6][1] = open_in_ns(c, "/proc/net/udp6", flags); - get_bound_ports(c, 1, IPPROTO_UDP); + get_bound_ports_udp(c, 1); } if (c->tcp.fwd_out.mode == FWD_AUTO) { c->proc_net_tcp[V4][0] = open("/proc/net/tcp", flags); c->proc_net_tcp[V6][0] = open("/proc/net/tcp6", flags); - get_bound_ports(c, 0, IPPROTO_TCP); + get_bound_ports_tcp(c, 0); } if (c->udp.fwd_out.f.mode == FWD_AUTO) { c->proc_net_udp[V4][0] = open("/proc/net/udp", flags); c->proc_net_udp[V6][0] = open("/proc/net/udp6", flags); - get_bound_ports(c, 0, IPPROTO_UDP); + get_bound_ports_udp(c, 0); } } diff --git a/port_fwd.h b/port_fwd.h index ad8ed1f..2f8f526 100644 --- a/port_fwd.h +++ b/port_fwd.h @@ -31,7 +31,8 @@ struct port_fwd { in_port_t delta[NUM_PORTS]; }; -void get_bound_ports(struct ctx *c, int ns, uint8_t proto); +void get_bound_ports_tcp(struct ctx *c, int ns); +void get_bound_ports_udp(struct ctx *c, int ns); void port_fwd_init(struct ctx *c); #endif /* PORT_FWD_H */ diff --git a/tcp.c b/tcp.c index 6fe9cdd..5b41897 100644 --- a/tcp.c +++ b/tcp.c @@ -3287,13 +3287,13 @@ void tcp_timer(struct ctx *c, const struct timespec *ts) struct tcp_port_rebind_arg rebind_arg = { c, 0 }; if (c->tcp.fwd_out.mode == FWD_AUTO) { - get_bound_ports(c, 0, IPPROTO_TCP); + get_bound_ports_tcp(c, 0); rebind_arg.bind_in_ns = 1; NS_CALL(tcp_port_rebind, &rebind_arg); } if (c->tcp.fwd_in.mode == FWD_AUTO) { - get_bound_ports(c, 1, IPPROTO_TCP); + get_bound_ports_tcp(c, 1); rebind_arg.bind_in_ns = 0; tcp_port_rebind(&rebind_arg); }