linux/tools/testing/selftests/bpf/progs/xdpwall.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2021 Facebook */
   3#include <stdbool.h>
   4#include <stdint.h>
   5#include <linux/stddef.h>
   6#include <linux/if_ether.h>
   7#include <linux/in.h>
   8#include <linux/in6.h>
   9#include <linux/ip.h>
  10#include <linux/ipv6.h>
  11#include <linux/tcp.h>
  12#include <linux/udp.h>
  13#include <linux/bpf.h>
  14#include <linux/types.h>
  15#include <bpf/bpf_endian.h>
  16#include <bpf/bpf_helpers.h>
  17
  18enum pkt_parse_err {
  19        NO_ERR,
  20        BAD_IP6_HDR,
  21        BAD_IP4GUE_HDR,
  22        BAD_IP6GUE_HDR,
  23};
  24
  25enum pkt_flag {
  26        TUNNEL = 0x1,
  27        TCP_SYN = 0x2,
  28        QUIC_INITIAL_FLAG = 0x4,
  29        TCP_ACK = 0x8,
  30        TCP_RST = 0x10
  31};
  32
/*
 * Key of v4_lpm_val_map (a BPF_MAP_TYPE_LPM_TRIE). prefixlen comes first and
 * the address follows, matching the prefixlen-then-data key layout used by
 * LPM trie maps. Do not reorder these members.
 */
struct v4_lpm_key {
	__u32 prefixlen;	/* number of significant leading bits of src */
	__u32 src;		/* IPv4 address, filled from iphdr saddr/daddr (network byte order) */
};

/*
 * Value of v4_lpm_val_map. Only @val is read by this program (in
 * filter_ipv4_lpm()); @key is stored alongside it but not read here.
 */
struct v4_lpm_val {
	struct v4_lpm_key key;
	__u8 val;	/* match result returned by filter_ipv4_lpm() */
};
  42
/* IPv6 source addresses of interest; looked up by filter_ipv6_addr(). */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, struct in6_addr);
	__type(value, bool);
} v6_addr_map SEC(".maps");

/*
 * IPv4 source addresses of interest; looked up by filter_ipv4_addr() with
 * the address exactly as read from the IP header (network byte order).
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, __u32);
	__type(value, bool);
} v4_addr_map SEC(".maps");

/*
 * IPv4 prefixes, queried by filter_ipv4_lpm() with full /32 keys.
 * LPM trie maps must be created with BPF_F_NO_PREALLOC.
 */
struct {
	__uint(type, BPF_MAP_TYPE_LPM_TRIE);
	__uint(max_entries, 16);
	__uint(key_size, sizeof(struct v4_lpm_key));
	__uint(value_size, sizeof(struct v4_lpm_val));
	__uint(map_flags, BPF_F_NO_PREALLOC);
} v4_lpm_val_map SEC(".maps");

/*
 * Per-TCP-destination-port match byte, indexed directly by port number.
 * With only 16 slots, ports >= 16 have no entry and filter_tcp_port()
 * returns 0 for them.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u8);
} tcp_port_map SEC(".maps");

/*
 * Per-UDP-port match value, indexed directly by port number (source or
 * destination). As with tcp_port_map, only ports 0..15 can have an entry.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u16);
} udp_port_map SEC(".maps");
  78
  79enum ip_type { V4 = 1, V6 = 2 };
  80
/*
 * Per-packet results of the firewall map lookups. edgewall() zero-initialises
 * it, so fields whose lookup is never performed for a packet stay 0.
 * NOTE(review): only is_tcp and is_tcp_syn feed edgewall()'s verdict.
 */
struct fw_match_info {
	__u8 v4_src_ip_match;		/* v4_addr_map hit for the IPv4 source */
	__u8 v6_src_ip_match;		/* v6_addr_map hit for the IPv6 source */
	__u8 v4_src_prefix_match;	/* v4_lpm_val_map result for the IPv4 source */
	__u8 v4_dst_prefix_match;	/* v4_lpm_val_map result for the IPv4 destination */
	__u8 tcp_dp_match;		/* tcp_port_map value for the TCP destination port */
	__u16 udp_sp_match;		/* udp_port_map value for the UDP source port */
	__u16 udp_dp_match;		/* udp_port_map value for the UDP destination port */
	bool is_tcp;			/* transport header parsed as TCP */
	bool is_tcp_syn;		/* TCP with only SYN set among ACK/RST/SYN/FIN */
};

/* Parsed view of the (outer or, if tunnelled, inner) packet. */
struct pkt_info {
	enum ip_type type;		/* which member of @ip is valid */
	union {
		struct iphdr *ipv4;
		struct ipv6hdr *ipv6;
	} ip;				/* IP header inside the packet buffer */
	int sport;			/* transport source port, host byte order */
	int dport;			/* transport destination port, host byte order */
	__u16 trans_hdr_offset;		/* byte offset from packet start to transport header */
	__u8 proto;			/* IPv6 nexthdr / IPv4 protocol of @ip */
	__u8 flags;			/* OR of enum pkt_flag bits */
};
 105
 106static __always_inline struct ethhdr *parse_ethhdr(void *data, void *data_end)
 107{
 108        struct ethhdr *eth = data;
 109
 110        if (eth + 1 > data_end)
 111                return NULL;
 112
 113        return eth;
 114}
 115
 116static __always_inline __u8 filter_ipv6_addr(const struct in6_addr *ipv6addr)
 117{
 118        __u8 *leaf;
 119
 120        leaf = bpf_map_lookup_elem(&v6_addr_map, ipv6addr);
 121
 122        return leaf ? *leaf : 0;
 123}
 124
 125static __always_inline __u8 filter_ipv4_addr(const __u32 ipaddr)
 126{
 127        __u8 *leaf;
 128
 129        leaf = bpf_map_lookup_elem(&v4_addr_map, &ipaddr);
 130
 131        return leaf ? *leaf : 0;
 132}
 133
 134static __always_inline __u8 filter_ipv4_lpm(const __u32 ipaddr)
 135{
 136        struct v4_lpm_key v4_key = {};
 137        struct v4_lpm_val *lpm_val;
 138
 139        v4_key.src = ipaddr;
 140        v4_key.prefixlen = 32;
 141
 142        lpm_val = bpf_map_lookup_elem(&v4_lpm_val_map, &v4_key);
 143
 144        return lpm_val ? lpm_val->val : 0;
 145}
 146
 147
 148static __always_inline void
 149filter_src_dst_ip(struct pkt_info* info, struct fw_match_info* match_info)
 150{
 151        if (info->type == V6) {
 152                match_info->v6_src_ip_match =
 153                        filter_ipv6_addr(&info->ip.ipv6->saddr);
 154        } else if (info->type == V4) {
 155                match_info->v4_src_ip_match =
 156                        filter_ipv4_addr(info->ip.ipv4->saddr);
 157                match_info->v4_src_prefix_match =
 158                        filter_ipv4_lpm(info->ip.ipv4->saddr);
 159                match_info->v4_dst_prefix_match =
 160                        filter_ipv4_lpm(info->ip.ipv4->daddr);
 161        }
 162}
 163
/*
 * Return a pointer @offset bytes past @data (the start of the transport
 * header). Returns NULL if @offset exceeds 255 or the pointer would land
 * beyond @data_end.
 *
 * Only "pointer <= data_end" is guaranteed. Callers must still
 * bounds-check the header they read through it, as parse_tcp() and
 * parse_udp() do.
 */
static __always_inline void *
get_transport_hdr(__u16 offset, void *data, void *data_end)
{
	/*
	 * The offset > 255 test gives an explicit upper bound on the scalar
	 * that is added to the packet pointer. Offsets built in this file are
	 * at most Ethernet + two IPv6 headers + UDP.
	 * NOTE(review): presumably this bound is what lets the BPF verifier
	 * accept "data + offset". Confirm before reordering or removing it.
	 */
	if (offset > 255 || data + offset > data_end)
		return NULL;

	return data + offset;
}
 172
 173static __always_inline bool tcphdr_only_contains_flag(struct tcphdr *tcp,
 174                                                      __u32 FLAG)
 175{
 176        return (tcp_flag_word(tcp) &
 177                (TCP_FLAG_ACK | TCP_FLAG_RST | TCP_FLAG_SYN | TCP_FLAG_FIN)) == FLAG;
 178}
 179
 180static __always_inline void set_tcp_flags(struct pkt_info *info,
 181                                          struct tcphdr *tcp) {
 182        if (tcphdr_only_contains_flag(tcp, TCP_FLAG_SYN))
 183                info->flags |= TCP_SYN;
 184        else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_ACK))
 185                info->flags |= TCP_ACK;
 186        else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_RST))
 187                info->flags |= TCP_RST;
 188}
 189
 190static __always_inline bool
 191parse_tcp(struct pkt_info *info, void *transport_hdr, void *data_end)
 192{
 193        struct tcphdr *tcp = transport_hdr;
 194
 195        if (tcp + 1 > data_end)
 196                return false;
 197
 198        info->sport = bpf_ntohs(tcp->source);
 199        info->dport = bpf_ntohs(tcp->dest);
 200        set_tcp_flags(info, tcp);
 201
 202        return true;
 203}
 204
 205static __always_inline bool
 206parse_udp(struct pkt_info *info, void *transport_hdr, void *data_end)
 207{
 208        struct udphdr *udp = transport_hdr;
 209
 210        if (udp + 1 > data_end)
 211                return false;
 212
 213        info->sport = bpf_ntohs(udp->source);
 214        info->dport = bpf_ntohs(udp->dest);
 215
 216        return true;
 217}
 218
 219static __always_inline __u8 filter_tcp_port(int port)
 220{
 221        __u8 *leaf = bpf_map_lookup_elem(&tcp_port_map, &port);
 222
 223        return leaf ? *leaf : 0;
 224}
 225
 226static __always_inline __u16 filter_udp_port(int port)
 227{
 228        __u16 *leaf = bpf_map_lookup_elem(&udp_port_map, &port);
 229
 230        return leaf ? *leaf : 0;
 231}
 232
 233static __always_inline bool
 234filter_transport_hdr(void *transport_hdr, void *data_end,
 235                     struct pkt_info *info, struct fw_match_info *match_info)
 236{
 237        if (info->proto == IPPROTO_TCP) {
 238                if (!parse_tcp(info, transport_hdr, data_end))
 239                        return false;
 240
 241                match_info->is_tcp = true;
 242                match_info->is_tcp_syn = (info->flags & TCP_SYN) > 0;
 243
 244                match_info->tcp_dp_match = filter_tcp_port(info->dport);
 245        } else if (info->proto == IPPROTO_UDP) {
 246                if (!parse_udp(info, transport_hdr, data_end))
 247                        return false;
 248
 249                match_info->udp_dp_match = filter_udp_port(info->dport);
 250                match_info->udp_sp_match = filter_udp_port(info->sport);
 251        }
 252
 253        return true;
 254}
 255
/*
 * Inspect the UDP header that directly follows the outer IPv6 header
 * @ip6h (extension headers are not walked). If its destination port is
 * 6666, treat the UDP payload as an encapsulated IPv4 or IPv6 packet:
 * set TUNNEL and repoint @info at the inner header.
 *
 * Expects @ip6h to be already bounds-checked and info->trans_hdr_offset
 * to point just past the outer IPv6 header.
 *
 * Returns NO_ERR for non-tunnel UDP or a successfully parsed inner
 * header. Returns BAD_IP6_HDR if the outer UDP header is truncated, and
 * BAD_IP6GUE_HDR if the inner header is truncated.
 */
static __always_inline __u8
parse_gue_v6(struct pkt_info *info, struct ipv6hdr *ip6h, void *data_end)
{
	struct udphdr *udp = (struct udphdr *)(ip6h + 1);
	void *encap_data = udp + 1;

	if (udp + 1 > data_end)
		return BAD_IP6_HDR;

	/* Only UDP to port 6666 is treated as encapsulated traffic. */
	if (udp->dest != bpf_htons(6666))
		return NO_ERR;

	info->flags |= TUNNEL;

	/* Need at least the first inner byte to read the IP version. */
	if (encap_data + 1 > data_end)
		return BAD_IP6GUE_HDR;

	/*
	 * The high nibble of the first byte is the IP version. Masking with
	 * 0x30 separates 6 (0x6?, bit 0x20 set) from 4 (0x4?, neither bit).
	 * NOTE(review): other version values are not rejected. E.g. 0x1? or
	 * 0x3? would be taken as IPv6.
	 */
	if (*(__u8 *)encap_data & 0x30) {
		struct ipv6hdr *inner_ip6h = encap_data;

		if (inner_ip6h + 1 > data_end)
			return BAD_IP6GUE_HDR;

		info->type = V6;
		info->proto = inner_ip6h->nexthdr;
		info->ip.ipv6 = inner_ip6h;
		/* Skip outer UDP + fixed inner IPv6 header (no ext. headers). */
		info->trans_hdr_offset += sizeof(struct ipv6hdr) + sizeof(struct udphdr);
	} else {
		struct iphdr *inner_ip4h = encap_data;

		if (inner_ip4h + 1 > data_end)
			return BAD_IP6GUE_HDR;

		info->type = V4;
		info->proto = inner_ip4h->protocol;
		info->ip.ipv4 = inner_ip4h;
		/*
		 * Skip outer UDP + sizeof(struct iphdr). The header length
		 * field is not consulted, so inner IPv4 options are assumed
		 * absent.
		 */
		info->trans_hdr_offset += sizeof(struct iphdr) + sizeof(struct udphdr);
	}

	return NO_ERR;
}
 297
 298static __always_inline __u8 parse_ipv6_gue(struct pkt_info *info,
 299                                           void *data, void *data_end)
 300{
 301        struct ipv6hdr *ip6h = data + sizeof(struct ethhdr);
 302
 303        if (ip6h + 1 > data_end)
 304                return BAD_IP6_HDR;
 305
 306        info->proto = ip6h->nexthdr;
 307        info->ip.ipv6 = ip6h;
 308        info->type = V6;
 309        info->trans_hdr_offset = sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
 310
 311        if (info->proto == IPPROTO_UDP)
 312                return parse_gue_v6(info, ip6h, data_end);
 313
 314        return NO_ERR;
 315}
 316
 317SEC("xdp")
 318int edgewall(struct xdp_md *ctx)
 319{
 320        void *data_end = (void *)(long)(ctx->data_end);
 321        void *data = (void *)(long)(ctx->data);
 322        struct fw_match_info match_info = {};
 323        struct pkt_info info = {};
 324        __u8 parse_err = NO_ERR;
 325        void *transport_hdr;
 326        struct ethhdr *eth;
 327        bool filter_res;
 328        __u32 proto;
 329
 330        eth = parse_ethhdr(data, data_end);
 331        if (!eth)
 332                return XDP_DROP;
 333
 334        proto = eth->h_proto;
 335        if (proto != bpf_htons(ETH_P_IPV6))
 336                return XDP_DROP;
 337
 338        if (parse_ipv6_gue(&info, data, data_end))
 339                return XDP_DROP;
 340
 341        if (info.proto == IPPROTO_ICMPV6)
 342                return XDP_PASS;
 343
 344        if (info.proto != IPPROTO_TCP && info.proto != IPPROTO_UDP)
 345                return XDP_DROP;
 346
 347        filter_src_dst_ip(&info, &match_info);
 348
 349        transport_hdr = get_transport_hdr(info.trans_hdr_offset, data,
 350                                          data_end);
 351        if (!transport_hdr)
 352                return XDP_DROP;
 353
 354        filter_res = filter_transport_hdr(transport_hdr, data_end,
 355                                          &info, &match_info);
 356        if (!filter_res)
 357                return XDP_DROP;
 358
 359        if (match_info.is_tcp && !match_info.is_tcp_syn)
 360                return XDP_PASS;
 361
 362        return XDP_DROP;
 363}
 364
/* Program license string, emitted into the ELF section named "license". */
char LICENSE[] SEC("license") = "GPL";
 366