linux/tools/testing/selftests/bpf/progs/test_l4lb_noinline.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2// Copyright (c) 2017 Facebook
   3#include <stddef.h>
   4#include <stdbool.h>
   5#include <string.h>
   6#include <linux/pkt_cls.h>
   7#include <linux/bpf.h>
   8#include <linux/in.h>
   9#include <linux/if_ether.h>
  10#include <linux/ip.h>
  11#include <linux/ipv6.h>
  12#include <linux/icmp.h>
  13#include <linux/icmpv6.h>
  14#include <linux/tcp.h>
  15#include <linux/udp.h>
  16#include <bpf/bpf_helpers.h>
  17#include "test_iptunnel_common.h"
  18#include <bpf/bpf_endian.h>
  19
  20static __always_inline __u32 rol32(__u32 word, unsigned int shift)
  21{
  22        return (word << shift) | (word >> ((-shift) & 31));
  23}
  24
  25/* copy paste of jhash from kernel sources to make sure llvm
  26 * can compile it into valid sequence of bpf instructions
  27 */
  28#define __jhash_mix(a, b, c)                    \
  29{                                               \
  30        a -= c;  a ^= rol32(c, 4);  c += b;     \
  31        b -= a;  b ^= rol32(a, 6);  a += c;     \
  32        c -= b;  c ^= rol32(b, 8);  b += a;     \
  33        a -= c;  a ^= rol32(c, 16); c += b;     \
  34        b -= a;  b ^= rol32(a, 19); a += c;     \
  35        c -= b;  c ^= rol32(b, 4);  b += a;     \
  36}
  37
  38#define __jhash_final(a, b, c)                  \
  39{                                               \
  40        c ^= b; c -= rol32(b, 14);              \
  41        a ^= c; a -= rol32(c, 11);              \
  42        b ^= a; b -= rol32(a, 25);              \
  43        c ^= b; c -= rol32(b, 16);              \
  44        a ^= c; a -= rol32(c, 4);               \
  45        b ^= a; b -= rol32(a, 14);              \
  46        c ^= b; c -= rol32(b, 24);              \
  47}
  48
  49#define JHASH_INITVAL           0xdeadbeef
  50
  51typedef unsigned int u32;
  52
  53static __noinline u32 jhash(const void *key, u32 length, u32 initval)
  54{
  55        u32 a, b, c;
  56        const unsigned char *k = key;
  57
  58        a = b = c = JHASH_INITVAL + length + initval;
  59
  60        while (length > 12) {
  61                a += *(u32 *)(k);
  62                b += *(u32 *)(k + 4);
  63                c += *(u32 *)(k + 8);
  64                __jhash_mix(a, b, c);
  65                length -= 12;
  66                k += 12;
  67        }
  68        switch (length) {
  69        case 12: c += (u32)k[11]<<24;
  70        case 11: c += (u32)k[10]<<16;
  71        case 10: c += (u32)k[9]<<8;
  72        case 9:  c += k[8];
  73        case 8:  b += (u32)k[7]<<24;
  74        case 7:  b += (u32)k[6]<<16;
  75        case 6:  b += (u32)k[5]<<8;
  76        case 5:  b += k[4];
  77        case 4:  a += (u32)k[3]<<24;
  78        case 3:  a += (u32)k[2]<<16;
  79        case 2:  a += (u32)k[1]<<8;
  80        case 1:  a += k[0];
  81                 __jhash_final(a, b, c);
  82        case 0: /* Nothing left to add */
  83                break;
  84        }
  85
  86        return c;
  87}
  88
  89static __noinline u32 __jhash_nwords(u32 a, u32 b, u32 c, u32 initval)
  90{
  91        a += initval;
  92        b += initval;
  93        c += initval;
  94        __jhash_final(a, b, c);
  95        return c;
  96}
  97
  98static __noinline u32 jhash_2words(u32 a, u32 b, u32 initval)
  99{
 100        return __jhash_nwords(a, b, 0, initval + JHASH_INITVAL + (2 << 2));
 101}
 102
/*
 * Mask tested against the raw (network byte order) iph->frag_off in
 * process_packet(). 65343 == 0xFF3F, which looks like htons(0x3FFF)
 * (MF flag + fragment offset) as seen on a little-endian host.
 * NOTE(review): confirm; on a big-endian target this mask would not match.
 */
#define PCKT_FRAGMENTED 65343
/* Size of an IPv4 header without options (ihl == 5). */
#define IPV4_HDR_LEN_NO_OPT 20
/* ICMP header (8 bytes) + embedded IPv4 header without options (20 bytes). */
#define IPV4_PLUS_ICMP_HDR 28
/* ICMPv6 header (8 bytes) + embedded IPv6 fixed header (40 bytes). */
#define IPV6_PLUS_ICMP_HDR 48
/* Number of ch_rings slots reserved for each VIP. */
#define RING_SIZE 2
#define MAX_VIPS 12
#define MAX_REALS 5
#define CTL_MAP_SIZE 16
#define CH_RINGS_SIZE (MAX_VIPS * RING_SIZE)
/*
 * The F_* flags share bit values but apply to different fields:
 *   F_IPV6             -> real_definition.flags
 *   F_HASH_NO_SRC_PORT -> vip_meta.flags
 *   F_ICMP, F_SYN_SET  -> packet_description.flags
 */
#define F_IPV6 (1 << 0)
#define F_HASH_NO_SRC_PORT (1 << 0)
#define F_ICMP (1 << 0)
#define F_SYN_SET (1 << 1)
 116
/*
 * Flow description filled in while parsing a packet; used as input to the
 * flow hash and to build the VIP lookup key. When the packet is an ICMP or
 * ICMPv6 error (F_ICMP set), the parsers store the embedded packet's
 * addresses and ports swapped (its destination becomes src, and so on).
 */
struct packet_description {
	/* Source address: IPv4 in src, IPv6 in srcv6 (network byte order). */
	union {
		__be32 src;
		__be32 srcv6[4];
	};
	/* Destination address: IPv4 in dst, IPv6 in dstv6. */
	union {
		__be32 dst;
		__be32 dstv6[4];
	};
	/*
	 * L4 ports as copied from the TCP/UDP header: port16[0] is the source
	 * and port16[1] the destination (swapped when F_ICMP is set). ports is
	 * the same 32 bits viewed as one word, which is what gets hashed.
	 */
	union {
		__u32 ports;
		__u16 port16[2];
	};
	/* L4 protocol number (IPPROTO_*). */
	__u8 proto;
	/* F_ICMP and/or F_SYN_SET. */
	__u8 flags;
};
 133
/*
 * Value type of ctl_array. The union allows a raw 64-bit value, an
 * interface index or a MAC address; in this file only ifindex is read.
 */
struct ctl_value {
	union {
		__u64 value;
		__u32 ifindex;
		__u8 mac[6];
	};
};
 141
/* Per-VIP metadata stored in vip_map. */
struct vip_meta {
	/* F_HASH_NO_SRC_PORT: drop the source port from the flow hash. */
	__u32 flags;
	/* Index of this VIP's section in ch_rings and of its entry in stats. */
	__u32 vip_num;
};
 146
/* A real (backend) server, stored in the reals map. */
struct real_definition {
	/* Tunnel destination: IPv4 in dst, IPv6 in dstv6. */
	union {
		__be32 dst;
		__be32 dstv6[4];
	};
	/* F_IPV6 set: use dstv6 and an IPv6 tunnel key. */
	__u8 flags;
};
 154
/*
 * Per-VIP counters (per CPU, see the stats map). bytes accumulates the
 * IPv4 tot_len or IPv6 payload_len of each balanced packet.
 */
struct vip_stats {
	__u64 bytes;
	__u64 pkts;
};
 159
/* Ethernet header as laid out on the wire. */
struct eth_hdr {
	unsigned char eth_dest[ETH_ALEN];
	unsigned char eth_source[ETH_ALEN];
	/* EtherType in network byte order (compared against bpf_htons()). */
	unsigned short eth_proto;
};
 165
/* VIP key (destination address, port, protocol) -> struct vip_meta. */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, MAX_VIPS);
	__type(key, struct vip);
	__type(value, struct vip_meta);
} vip_map SEC(".maps");

/*
 * Consistent-hash rings: RING_SIZE consecutive slots per VIP, starting at
 * RING_SIZE * vip_num. Each value is a key into the reals map.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, CH_RINGS_SIZE);
	__type(key, __u32);
	__type(value, __u32);
} ch_rings SEC(".maps");

/* Real (backend) definitions, indexed by the values found in ch_rings. */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, MAX_REALS);
	__type(key, __u32);
	__type(value, struct real_definition);
} reals SEC(".maps");

/* Per-CPU packet/byte counters, indexed by vip_meta.vip_num. */
struct {
	__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
	__uint(max_entries, MAX_VIPS);
	__type(key, __u32);
	__type(value, struct vip_stats);
} stats SEC(".maps");

/*
 * Control values. process_packet() reads the egress ifindex from index 1
 * (IPv4 reals) and index 2 (IPv6 reals).
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, CTL_MAP_SIZE);
	__type(key, __u32);
	__type(value, struct ctl_value);
} ctl_array SEC(".maps");
 200
 201static __noinline __u32 get_packet_hash(struct packet_description *pckt, bool ipv6)
 202{
 203        if (ipv6)
 204                return jhash_2words(jhash(pckt->srcv6, 16, MAX_VIPS),
 205                                    pckt->ports, CH_RINGS_SIZE);
 206        else
 207                return jhash_2words(pckt->src, pckt->ports, CH_RINGS_SIZE);
 208}
 209
 210static __noinline bool get_packet_dst(struct real_definition **real,
 211                                      struct packet_description *pckt,
 212                                      struct vip_meta *vip_info,
 213                                      bool is_ipv6)
 214{
 215        __u32 hash = get_packet_hash(pckt, is_ipv6);
 216        __u32 key = RING_SIZE * vip_info->vip_num + hash % RING_SIZE;
 217        __u32 *real_pos;
 218
 219        if (hash != 0x358459b7 /* jhash of ipv4 packet */  &&
 220            hash != 0x2f4bc6bb /* jhash of ipv6 packet */)
 221                return 0;
 222
 223        real_pos = bpf_map_lookup_elem(&ch_rings, &key);
 224        if (!real_pos)
 225                return false;
 226        key = *real_pos;
 227        *real = bpf_map_lookup_elem(&reals, &key);
 228        if (!(*real))
 229                return false;
 230        return true;
 231}
 232
 233static __noinline int parse_icmpv6(void *data, void *data_end, __u64 off,
 234                                   struct packet_description *pckt)
 235{
 236        struct icmp6hdr *icmp_hdr;
 237        struct ipv6hdr *ip6h;
 238
 239        icmp_hdr = data + off;
 240        if (icmp_hdr + 1 > data_end)
 241                return TC_ACT_SHOT;
 242        if (icmp_hdr->icmp6_type != ICMPV6_PKT_TOOBIG)
 243                return TC_ACT_OK;
 244        off += sizeof(struct icmp6hdr);
 245        ip6h = data + off;
 246        if (ip6h + 1 > data_end)
 247                return TC_ACT_SHOT;
 248        pckt->proto = ip6h->nexthdr;
 249        pckt->flags |= F_ICMP;
 250        memcpy(pckt->srcv6, ip6h->daddr.s6_addr32, 16);
 251        memcpy(pckt->dstv6, ip6h->saddr.s6_addr32, 16);
 252        return TC_ACT_UNSPEC;
 253}
 254
 255static __noinline int parse_icmp(void *data, void *data_end, __u64 off,
 256                                 struct packet_description *pckt)
 257{
 258        struct icmphdr *icmp_hdr;
 259        struct iphdr *iph;
 260
 261        icmp_hdr = data + off;
 262        if (icmp_hdr + 1 > data_end)
 263                return TC_ACT_SHOT;
 264        if (icmp_hdr->type != ICMP_DEST_UNREACH ||
 265            icmp_hdr->code != ICMP_FRAG_NEEDED)
 266                return TC_ACT_OK;
 267        off += sizeof(struct icmphdr);
 268        iph = data + off;
 269        if (iph + 1 > data_end)
 270                return TC_ACT_SHOT;
 271        if (iph->ihl != 5)
 272                return TC_ACT_SHOT;
 273        pckt->proto = iph->protocol;
 274        pckt->flags |= F_ICMP;
 275        pckt->src = iph->daddr;
 276        pckt->dst = iph->saddr;
 277        return TC_ACT_UNSPEC;
 278}
 279
 280static __noinline bool parse_udp(void *data, __u64 off, void *data_end,
 281                                 struct packet_description *pckt)
 282{
 283        struct udphdr *udp;
 284        udp = data + off;
 285
 286        if (udp + 1 > data_end)
 287                return false;
 288
 289        if (!(pckt->flags & F_ICMP)) {
 290                pckt->port16[0] = udp->source;
 291                pckt->port16[1] = udp->dest;
 292        } else {
 293                pckt->port16[0] = udp->dest;
 294                pckt->port16[1] = udp->source;
 295        }
 296        return true;
 297}
 298
 299static __noinline bool parse_tcp(void *data, __u64 off, void *data_end,
 300                                 struct packet_description *pckt)
 301{
 302        struct tcphdr *tcp;
 303
 304        tcp = data + off;
 305        if (tcp + 1 > data_end)
 306                return false;
 307
 308        if (tcp->syn)
 309                pckt->flags |= F_SYN_SET;
 310
 311        if (!(pckt->flags & F_ICMP)) {
 312                pckt->port16[0] = tcp->source;
 313                pckt->port16[1] = tcp->dest;
 314        } else {
 315                pckt->port16[0] = tcp->dest;
 316                pckt->port16[1] = tcp->source;
 317        }
 318        return true;
 319}
 320
/*
 * Load-balance one IPv4 or IPv6 packet.
 *
 * Parses the L3 header at @off, plus the embedded packet for ICMP
 * fragmentation-needed and ICMPv6 packet-too-big errors, and builds a
 * struct packet_description. It then looks the destination up in vip_map,
 * retrying with dport 0 if there is no exact match, and picks a real with
 * get_packet_dst(). Finally it bumps the per-VIP counters in stats, sets
 * the skb tunnel key towards the real and redirects to the ifindex stored
 * in ctl_array.
 *
 * @data, @data_end: packet bounds as passed by the caller.
 * @off:             offset of the L3 header from @data.
 * @is_ipv6:         true to parse IPv6, false to parse IPv4.
 * @skb:             program context.
 *
 * Returns one of:
 *   - TC_ACT_SHOT on any parse or lookup failure;
 *   - the action from parse_icmp()/parse_icmpv6() when it is non-negative;
 *   - otherwise the result of bpf_redirect().
 */
static __noinline int process_packet(void *data, __u64 off, void *data_end,
				     bool is_ipv6, struct __sk_buff *skb)
{
	void *pkt_start = (void *)(long)skb->data;
	struct packet_description pckt = {};
	/*
	 * NOTE(review): eth comes from a fresh skb->data load and is written at
	 * the end of this function with no bounds check against data_end here.
	 * Only the caller checked the Ethernet header length, using its own
	 * pointer. Confirm the verifier accepts this on the targeted kernels.
	 */
	struct eth_hdr *eth = pkt_start;
	struct bpf_tunnel_key tkey = {};
	struct vip_stats *data_stats;
	struct real_definition *dst;
	struct vip_meta *vip_info;
	struct ctl_value *cval;
	/* ctl_array indices holding the egress ifindex for IPv4 / IPv6 reals. */
	__u32 v4_intf_pos = 1;
	__u32 v6_intf_pos = 2;
	struct ipv6hdr *ip6h;
	struct vip vip = {};
	struct iphdr *iph;
	int tun_flag = 0;
	__u16 pkt_bytes;
	__u64 iph_len;
	__u32 ifindex;
	__u8 protocol;
	__u32 vip_num;
	int action;

	tkey.tunnel_ttl = 64;
	if (is_ipv6) {
		ip6h = data + off;
		if (ip6h + 1 > data_end)
			return TC_ACT_SHOT;

		iph_len = sizeof(struct ipv6hdr);
		protocol = ip6h->nexthdr;
		pckt.proto = protocol;
		/* IPv6 payload_len excludes the fixed header, unlike IPv4 tot_len. */
		pkt_bytes = bpf_ntohs(ip6h->payload_len);
		off += iph_len;
		if (protocol == IPPROTO_FRAGMENT) {
			/* IPv6 fragments are not handled. */
			return TC_ACT_SHOT;
		} else if (protocol == IPPROTO_ICMPV6) {
			action = parse_icmpv6(data, data_end, off, &pckt);
			/* Non-negative: final verdict. Negative: keep parsing. */
			if (action >= 0)
				return action;
			/* Skip the ICMPv6 header and the embedded IPv6 header. */
			off += IPV6_PLUS_ICMP_HDR;
		} else {
			memcpy(pckt.srcv6, ip6h->saddr.s6_addr32, 16);
			memcpy(pckt.dstv6, ip6h->daddr.s6_addr32, 16);
		}
	} else {
		iph = data + off;
		if (iph + 1 > data_end)
			return TC_ACT_SHOT;
		/* IPv4 headers with options are not handled. */
		if (iph->ihl != 5)
			return TC_ACT_SHOT;

		protocol = iph->protocol;
		pckt.proto = protocol;
		pkt_bytes = bpf_ntohs(iph->tot_len);
		off += IPV4_HDR_LEN_NO_OPT;

		/* Drop fragments; see PCKT_FRAGMENTED for the mask. */
		if (iph->frag_off & PCKT_FRAGMENTED)
			return TC_ACT_SHOT;
		if (protocol == IPPROTO_ICMP) {
			action = parse_icmp(data, data_end, off, &pckt);
			/* Non-negative: final verdict. Negative: keep parsing. */
			if (action >= 0)
				return action;
			/* Skip the ICMP header and the embedded IPv4 header. */
			off += IPV4_PLUS_ICMP_HDR;
		} else {
			pckt.src = iph->saddr;
			pckt.dst = iph->daddr;
		}
	}
	/* The ICMP parsers may have replaced proto with the embedded packet's. */
	protocol = pckt.proto;

	/* Only TCP and UDP flows are balanced. */
	if (protocol == IPPROTO_TCP) {
		if (!parse_tcp(data, off, data_end, &pckt))
			return TC_ACT_SHOT;
	} else if (protocol == IPPROTO_UDP) {
		if (!parse_udp(data, off, data_end, &pckt))
			return TC_ACT_SHOT;
	} else {
		return TC_ACT_SHOT;
	}

	/* Build the VIP lookup key from the flow's destination. */
	if (is_ipv6)
		memcpy(vip.daddr.v6, pckt.dstv6, 16);
	else
		vip.daddr.v4 = pckt.dst;

	vip.dport = pckt.port16[1];
	vip.protocol = pckt.proto;
	vip_info = bpf_map_lookup_elem(&vip_map, &vip);
	if (!vip_info) {
		/* No exact match: retry with dport 0 and hash with dport 0 too. */
		vip.dport = 0;
		vip_info = bpf_map_lookup_elem(&vip_map, &vip);
		if (!vip_info)
			return TC_ACT_SHOT;
		pckt.port16[1] = 0;
	}

	/* Leave the source port out of the flow hash for VIPs that ask so. */
	if (vip_info->flags & F_HASH_NO_SRC_PORT)
		pckt.port16[0] = 0;

	if (!get_packet_dst(&dst, &pckt, vip_info, is_ipv6))
		return TC_ACT_SHOT;

	/* Tunnel towards the real using its address family. */
	if (dst->flags & F_IPV6) {
		cval = bpf_map_lookup_elem(&ctl_array, &v6_intf_pos);
		if (!cval)
			return TC_ACT_SHOT;
		ifindex = cval->ifindex;
		memcpy(tkey.remote_ipv6, dst->dstv6, 16);
		tun_flag = BPF_F_TUNINFO_IPV6;
	} else {
		cval = bpf_map_lookup_elem(&ctl_array, &v4_intf_pos);
		if (!cval)
			return TC_ACT_SHOT;
		ifindex = cval->ifindex;
		tkey.remote_ipv4 = dst->dst;
	}
	vip_num = vip_info->vip_num;
	data_stats = bpf_map_lookup_elem(&stats, &vip_num);
	if (!data_stats)
		return TC_ACT_SHOT;
	data_stats->pkts++;
	data_stats->bytes += pkt_bytes;
	/* NOTE(review): the return value of bpf_skb_set_tunnel_key() is ignored. */
	bpf_skb_set_tunnel_key(skb, &tkey, sizeof(tkey), tun_flag);
	/*
	 * Copy tkey.remote_ipv4 into the first 4 bytes of the destination MAC.
	 * Presumably this lets the test observe which real was chosen;
	 * confirm against the userspace side of the test.
	 */
	*(u32 *)eth->eth_dest = tkey.remote_ipv4;
	return bpf_redirect(ifindex, 0);
}
 449
 450SEC("l4lb-demo")
 451int balancer_ingress(struct __sk_buff *ctx)
 452{
 453        void *data_end = (void *)(long)ctx->data_end;
 454        void *data = (void *)(long)ctx->data;
 455        struct eth_hdr *eth = data;
 456        __u32 eth_proto;
 457        __u32 nh_off;
 458
 459        nh_off = sizeof(struct eth_hdr);
 460        if (data + nh_off > data_end)
 461                return TC_ACT_SHOT;
 462        eth_proto = eth->eth_proto;
 463        if (eth_proto == bpf_htons(ETH_P_IP))
 464                return process_packet(data, nh_off, data_end, false, ctx);
 465        else if (eth_proto == bpf_htons(ETH_P_IPV6))
 466                return process_packet(data, nh_off, data_end, true, ctx);
 467        else
 468                return TC_ACT_SHOT;
 469}
 470char _license[] SEC("license") = "GPL";
 471