diff --git a/checksum.c b/checksum.c index 175381d..cf6fc31 100644 --- a/checksum.c +++ b/checksum.c @@ -56,6 +56,12 @@ #include #include +/* Checksums are optional for UDP over IPv4, so we usually just set + * them to 0. Change this to 1 to calculate real UDP over IPv4 + * checksums + */ +#define UDP4_REAL_CHECKSUMS 0 + /** * sum_16b() - Calculate sum of 16-bit words * @buf: Input buffer @@ -109,6 +115,33 @@ uint16_t csum_unaligned(const void *buf, size_t len, uint32_t init) return (uint16_t)~csum_fold(sum_16b(buf, len) + init); } +/** + * csum_udp4() - Calculate and set checksum for a UDP over IPv4 packet + * @udp4hr: UDP header, initialised apart from checksum + * @saddr: IPv4 source address + * @daddr: IPv4 destination address + * @payload: ICMPv4 packet payload + * @len: Length of @payload (not including UDP) + */ +void csum_udp4(struct udphdr *udp4hr, in_addr_t saddr, in_addr_t daddr, + const void *payload, size_t len) +{ + /* UDP checksums are optional, so don't bother */ + udp4hr->check = 0; + + if (UDP4_REAL_CHECKSUMS) { + /* UNTESTED: if we did want real UDPv4 checksums, this + * is roughly what we'd need */ + uint32_t psum = csum_fold(htonl(saddr)) + + csum_fold(htonl(daddr)) + + htons(len + sizeof(*udp4hr)) + + htons(IPPROTO_UDP); + /* Add in partial checksum for the UDP header alone */ + psum += sum_16b(udp4hr, sizeof(*udp4hr)); + udp4hr->check = csum_unaligned(payload, len, psum); + } +} + /** * csum_icmp4() - Calculate and set checksum for an ICMP packet * @icmp4hr: ICMP header, initialised apart from checksum diff --git a/checksum.h b/checksum.h index 2bb2ff9..2a5e915 100644 --- a/checksum.h +++ b/checksum.h @@ -13,6 +13,8 @@ struct icmp6hdr; uint32_t sum_16b(const void *buf, size_t len); uint16_t csum_fold(uint32_t sum); uint16_t csum_unaligned(const void *buf, size_t len, uint32_t init); +void csum_udp4(struct udphdr *udp4hr, in_addr_t saddr, in_addr_t daddr, + const void *payload, size_t len); void csum_icmp4(struct icmphdr *ih, const void *payload, size_t len); void csum_udp6(struct udphdr *udp6hr, const struct in6_addr *saddr, const struct in6_addr *daddr, diff --git a/dhcp.c b/dhcp.c index 7f0cc0b..8dcf645 100644 --- a/dhcp.c +++ b/dhcp.c @@ -364,9 +364,9 @@ int dhcp(const struct ctx *c, const struct pool *p) opt_set_dns_search(c, sizeof(m->o)); uh->len = htons(len = offsetof(struct msg, o) + fill(m) + sizeof(*uh)); - uh->check = 0; uh->source = htons(67); uh->dest = htons(68); + csum_udp4(uh, c->ip4.gw, c->ip4.addr, uh + 1, len - sizeof(*uh)); iph->tot_len = htons(len += sizeof(*iph)); iph->daddr = c->ip4.addr; diff --git a/tap.c b/tap.c index 9c197cb..58fc1de 100644 --- a/tap.c +++ b/tap.c @@ -145,7 +145,7 @@ void tap_ip_send(const struct ctx *c, const struct in6_addr *src, uint8_t proto, } else if (iph->protocol == IPPROTO_UDP) { struct udphdr *uh = (struct udphdr *)(iph + 1); - uh->check = 0; + csum_udp4(uh, iph->saddr, iph->daddr, uh + 1, len - sizeof(*uh)); } else if (iph->protocol == IPPROTO_ICMP) { struct icmphdr *ih = (struct icmphdr *)(iph + 1); csum_icmp4(ih, ih + 1, len - sizeof(*ih));