diff --git a/port_fwd.c b/port_fwd.c index eac8d2f..7943a30 100644 --- a/port_fwd.c +++ b/port_fwd.c @@ -86,22 +86,33 @@ void port_fwd_scan_tcp(struct port_fwd *fwd, const struct port_fwd *rev) * port_fwd_scan_tcp() - Scan /proc to update TCP forwarding map * @fwd: Forwarding information to update * @rev: Forwarding information for the reverse direction - * @tcp: Corresponding TCP forwarding information + * @tcp_fwd: Corresponding TCP forwarding information + * @tcp_rev: TCP forwarding information for the reverse direction */ void port_fwd_scan_udp(struct port_fwd *fwd, const struct port_fwd *rev, - const struct port_fwd *tcp) + const struct port_fwd *tcp_fwd, + const struct port_fwd *tcp_rev) { + uint8_t exclude[PORT_BITMAP_SIZE]; + + bitmap_or(exclude, PORT_BITMAP_SIZE, rev->map, tcp_rev->map); + memset(fwd->map, 0, PORT_BITMAP_SIZE); - procfs_scan_listen(fwd->scan4, UDP_LISTEN, fwd->map, rev->map); - procfs_scan_listen(fwd->scan6, UDP_LISTEN, fwd->map, rev->map); + procfs_scan_listen(fwd->scan4, UDP_LISTEN, fwd->map, exclude); + procfs_scan_listen(fwd->scan6, UDP_LISTEN, fwd->map, exclude); /* Also forward UDP ports with the same numbers as bound TCP ports. * This is useful for a handful of protocols (e.g. iperf3) where a TCP * control port is used to set up transfers on a corresponding UDP * port. + * + * This means we need to skip numbers of TCP ports bound on the other + * side, too. Otherwise, we would detect corresponding UDP ports as + * bound and try to forward them from the opposite side, but it's + * already us handling them. */ - procfs_scan_listen(tcp->scan4, TCP_LISTEN, fwd->map, rev->map); - procfs_scan_listen(tcp->scan6, TCP_LISTEN, fwd->map, rev->map); + procfs_scan_listen(tcp_fwd->scan4, TCP_LISTEN, fwd->map, exclude); + procfs_scan_listen(tcp_fwd->scan6, TCP_LISTEN, fwd->map, exclude); } /** @@ -126,7 +137,7 @@ void port_fwd_init(struct ctx *c) c->udp.fwd_in.f.scan4 = open_in_ns(c, "/proc/net/udp", flags); c->udp.fwd_in.f.scan6 = open_in_ns(c, "/proc/net/udp6", flags); port_fwd_scan_udp(&c->udp.fwd_in.f, &c->udp.fwd_out.f, - &c->tcp.fwd_in); + &c->tcp.fwd_in, &c->tcp.fwd_out); } if (c->tcp.fwd_out.mode == FWD_AUTO) { c->tcp.fwd_out.scan4 = open("/proc/net/tcp", flags); @@ -137,6 +148,6 @@ void port_fwd_init(struct ctx *c) c->udp.fwd_out.f.scan4 = open("/proc/net/udp", flags); c->udp.fwd_out.f.scan6 = open("/proc/net/udp6", flags); port_fwd_scan_udp(&c->udp.fwd_out.f, &c->udp.fwd_in.f, - &c->tcp.fwd_out); + &c->tcp.fwd_out, &c->tcp.fwd_in); } } diff --git a/port_fwd.h b/port_fwd.h index 8a823d8..f6bf1b5 100644 --- a/port_fwd.h +++ b/port_fwd.h @@ -37,7 +37,8 @@ struct port_fwd { void port_fwd_scan_tcp(struct port_fwd *fwd, const struct port_fwd *rev); void port_fwd_scan_udp(struct port_fwd *fwd, const struct port_fwd *rev, - const struct port_fwd *tcp); + const struct port_fwd *tcp_fwd, + const struct port_fwd *tcp_rev); void port_fwd_init(struct ctx *c); #endif /* PORT_FWD_H */ diff --git a/udp.c b/udp.c index cc1ea9c..1f8c306 100644 --- a/udp.c +++ b/udp.c @@ -1260,13 +1260,13 @@ void udp_timer(struct ctx *c, const struct timespec *ts) if (c->mode == MODE_PASTA) { if (c->udp.fwd_out.f.mode == FWD_AUTO) { port_fwd_scan_udp(&c->udp.fwd_out.f, &c->udp.fwd_in.f, - &c->tcp.fwd_out); + &c->tcp.fwd_out, &c->tcp.fwd_in); NS_CALL(udp_port_rebind_outbound, c); } if (c->udp.fwd_in.f.mode == FWD_AUTO) { port_fwd_scan_udp(&c->udp.fwd_in.f, &c->udp.fwd_out.f, - &c->tcp.fwd_in); + &c->tcp.fwd_in, &c->tcp.fwd_out); udp_port_rebind(c, false); } } diff --git a/util.c b/util.c index c38ab7e..d465e48 100644 --- a/util.c +++ b/util.c @@ -325,6 +325,27 @@ int bitmap_isset(const uint8_t *map, int bit) return !!(*word & BITMAP_BIT(bit)); } +/** + * bitmap_or() - Logical disjunction (OR) of two bitmaps + * @dst: Pointer to result bitmap + * @size: Size of bitmaps, in bytes + * @a: First operand + * @b: Second operand + */ +void bitmap_or(uint8_t *dst, size_t size, const uint8_t *a, const uint8_t *b) +{ + unsigned long *dw = (unsigned long *)dst; + unsigned long *aw = (unsigned long *)a; + unsigned long *bw = (unsigned long *)b; + size_t i; + + for (i = 0; i < size / sizeof(long); i++, dw++, aw++, bw++) + *dw = *aw | *bw; + + for (i = size / sizeof(long) * sizeof(long); i < size; i++) + dst[i] = a[i] | b[i]; +} + /* * ns_enter() - Enter configured user (unless already joined) and network ns * @c: Execution context diff --git a/util.h b/util.h index 78a8fb2..1f02588 100644 --- a/util.h +++ b/util.h @@ -216,6 +216,7 @@ int timespec_diff_ms(const struct timespec *a, const struct timespec *b); void bitmap_set(uint8_t *map, int bit); void bitmap_clear(uint8_t *map, int bit); int bitmap_isset(const uint8_t *map, int bit); +void bitmap_or(uint8_t *dst, size_t size, const uint8_t *a, const uint8_t *b); char *line_read(char *buf, size_t len, int fd); void ns_enter(const struct ctx *c); bool ns_is_init(void);