linux/tools/testing/selftests/bpf/test_tcp_check_syncookie_user.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2// Copyright (c) 2018 Facebook
   3// Copyright (c) 2019 Cloudflare
   4
#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <bpf/bpf.h>
#include <bpf/libbpf.h>

#include "bpf_rlimit.h"
#include "cgroup_helpers.h"
  20
  21static int start_server(const struct sockaddr *addr, socklen_t len)
  22{
  23        int fd;
  24
  25        fd = socket(addr->sa_family, SOCK_STREAM, 0);
  26        if (fd == -1) {
  27                log_err("Failed to create server socket");
  28                goto out;
  29        }
  30
  31        if (bind(fd, addr, len) == -1) {
  32                log_err("Failed to bind server socket");
  33                goto close_out;
  34        }
  35
  36        if (listen(fd, 128) == -1) {
  37                log_err("Failed to listen on server socket");
  38                goto close_out;
  39        }
  40
  41        goto out;
  42
  43close_out:
  44        close(fd);
  45        fd = -1;
  46out:
  47        return fd;
  48}
  49
  50static int connect_to_server(int server_fd)
  51{
  52        struct sockaddr_storage addr;
  53        socklen_t len = sizeof(addr);
  54        int fd = -1;
  55
  56        if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
  57                log_err("Failed to get server addr");
  58                goto out;
  59        }
  60
  61        fd = socket(addr.ss_family, SOCK_STREAM, 0);
  62        if (fd == -1) {
  63                log_err("Failed to create client socket");
  64                goto out;
  65        }
  66
  67        if (connect(fd, (const struct sockaddr *)&addr, len) == -1) {
  68                log_err("Fail to connect to server");
  69                goto close_out;
  70        }
  71
  72        goto out;
  73
  74close_out:
  75        close(fd);
  76        fd = -1;
  77out:
  78        return fd;
  79}
  80
  81static int get_map_fd_by_prog_id(int prog_id, bool *xdp)
  82{
  83        struct bpf_prog_info info = {};
  84        __u32 info_len = sizeof(info);
  85        __u32 map_ids[1];
  86        int prog_fd = -1;
  87        int map_fd = -1;
  88
  89        prog_fd = bpf_prog_get_fd_by_id(prog_id);
  90        if (prog_fd < 0) {
  91                log_err("Failed to get fd by prog id %d", prog_id);
  92                goto err;
  93        }
  94
  95        info.nr_map_ids = 1;
  96        info.map_ids = (__u64)(unsigned long)map_ids;
  97
  98        if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) {
  99                log_err("Failed to get info by prog fd %d", prog_fd);
 100                goto err;
 101        }
 102
 103        if (!info.nr_map_ids) {
 104                log_err("No maps found for prog fd %d", prog_fd);
 105                goto err;
 106        }
 107
 108        *xdp = info.type == BPF_PROG_TYPE_XDP;
 109
 110        map_fd = bpf_map_get_fd_by_id(map_ids[0]);
 111        if (map_fd < 0)
 112                log_err("Failed to get fd by map id %d", map_ids[0]);
 113err:
 114        if (prog_fd >= 0)
 115                close(prog_fd);
 116        return map_fd;
 117}
 118
 119static int run_test(int server_fd, int results_fd, bool xdp)
 120{
 121        int client = -1, srv_client = -1;
 122        int ret = 0;
 123        __u32 key = 0;
 124        __u32 key_gen = 1;
 125        __u32 key_mss = 2;
 126        __u32 value = 0;
 127        __u32 value_gen = 0;
 128        __u32 value_mss = 0;
 129
 130        if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) {
 131                log_err("Can't clear results");
 132                goto err;
 133        }
 134
 135        if (bpf_map_update_elem(results_fd, &key_gen, &value_gen, 0) < 0) {
 136                log_err("Can't clear results");
 137                goto err;
 138        }
 139
 140        if (bpf_map_update_elem(results_fd, &key_mss, &value_mss, 0) < 0) {
 141                log_err("Can't clear results");
 142                goto err;
 143        }
 144
 145        client = connect_to_server(server_fd);
 146        if (client == -1)
 147                goto err;
 148
 149        srv_client = accept(server_fd, NULL, 0);
 150        if (srv_client == -1) {
 151                log_err("Can't accept connection");
 152                goto err;
 153        }
 154
 155        if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) {
 156                log_err("Can't lookup result");
 157                goto err;
 158        }
 159
 160        if (value == 0) {
 161                log_err("Didn't match syncookie: %u", value);
 162                goto err;
 163        }
 164
 165        if (bpf_map_lookup_elem(results_fd, &key_gen, &value_gen) < 0) {
 166                log_err("Can't lookup result");
 167                goto err;
 168        }
 169
 170        if (xdp && value_gen == 0) {
 171                // SYN packets do not get passed through generic XDP, skip the
 172                // rest of the test.
 173                printf("Skipping XDP cookie check\n");
 174                goto out;
 175        }
 176
 177        if (bpf_map_lookup_elem(results_fd, &key_mss, &value_mss) < 0) {
 178                log_err("Can't lookup result");
 179                goto err;
 180        }
 181
 182        if (value != value_gen) {
 183                log_err("BPF generated cookie does not match kernel one");
 184                goto err;
 185        }
 186
 187        if (value_mss < 536 || value_mss > USHRT_MAX) {
 188                log_err("Unexpected MSS retrieved");
 189                goto err;
 190        }
 191
 192        goto out;
 193
 194err:
 195        ret = 1;
 196out:
 197        close(client);
 198        close(srv_client);
 199        return ret;
 200}
 201
 202int main(int argc, char **argv)
 203{
 204        struct sockaddr_in addr4;
 205        struct sockaddr_in6 addr6;
 206        int server = -1;
 207        int server_v6 = -1;
 208        int results = -1;
 209        int err = 0;
 210        bool xdp;
 211
 212        if (argc < 2) {
 213                fprintf(stderr, "Usage: %s prog_id\n", argv[0]);
 214                exit(1);
 215        }
 216
 217        results = get_map_fd_by_prog_id(atoi(argv[1]), &xdp);
 218        if (results < 0) {
 219                log_err("Can't get map");
 220                goto err;
 221        }
 222
 223        memset(&addr4, 0, sizeof(addr4));
 224        addr4.sin_family = AF_INET;
 225        addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
 226        addr4.sin_port = 0;
 227
 228        memset(&addr6, 0, sizeof(addr6));
 229        addr6.sin6_family = AF_INET6;
 230        addr6.sin6_addr = in6addr_loopback;
 231        addr6.sin6_port = 0;
 232
 233        server = start_server((const struct sockaddr *)&addr4, sizeof(addr4));
 234        if (server == -1)
 235                goto err;
 236
 237        server_v6 = start_server((const struct sockaddr *)&addr6,
 238                                 sizeof(addr6));
 239        if (server_v6 == -1)
 240                goto err;
 241
 242        if (run_test(server, results, xdp))
 243                goto err;
 244
 245        if (run_test(server_v6, results, xdp))
 246                goto err;
 247
 248        printf("ok\n");
 249        goto out;
 250err:
 251        err = 1;
 252out:
 253        close(server);
 254        close(server_v6);
 255        close(results);
 256        return err;
 257}
 258