// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2018 Facebook */
/* Source path: linux/tools/testing/selftests/bpf/prog_tests/select_reuseport.c */
   3
#include <stdlib.h>
#include <stddef.h>
#include <unistd.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>
#include <assert.h>
#include <fcntl.h>
#include <linux/bpf.h>
#include <linux/err.h>
#include <linux/types.h>
#include <linux/if_ether.h>
#include <sys/types.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <bpf/bpf.h>
#include <bpf/libbpf.h>
#include "bpf_util.h"

#include "test_progs.h"
#include "test_select_reuseport_common.h"
  25
  26#define MAX_TEST_NAME 80
  27#define MIN_TCPHDR_LEN 20
  28#define UDPHDR_LEN 8
  29
  30#define TCP_SYNCOOKIE_SYSCTL "/proc/sys/net/ipv4/tcp_syncookies"
  31#define TCP_FO_SYSCTL "/proc/sys/net/ipv4/tcp_fastopen"
  32#define REUSEPORT_ARRAY_SIZE 32
  33
  34static int result_map, tmp_index_ovr_map, linum_map, data_check_map;
  35static __u32 expected_results[NR_RESULTS];
  36static int sk_fds[REUSEPORT_ARRAY_SIZE];
  37static int reuseport_array = -1, outer_map = -1;
  38static enum bpf_map_type inner_map_type;
  39static int select_by_skb_data_prog;
  40static int saved_tcp_syncookie = -1;
  41static struct bpf_object *obj;
  42static int saved_tcp_fo = -1;
  43static __u32 index_zero;
  44static int epfd;
  45
  46static union sa46 {
  47        struct sockaddr_in6 v6;
  48        struct sockaddr_in v4;
  49        sa_family_t family;
  50} srv_sa;
  51
  52#define RET_IF(condition, tag, format...) ({                            \
  53        if (CHECK_FAIL(condition)) {                                    \
  54                printf(tag " " format);                                 \
  55                return;                                                 \
  56        }                                                               \
  57})
  58
  59#define RET_ERR(condition, tag, format...) ({                           \
  60        if (CHECK_FAIL(condition)) {                                    \
  61                printf(tag " " format);                                 \
  62                return -1;                                              \
  63        }                                                               \
  64})
  65
  66static int create_maps(enum bpf_map_type inner_type)
  67{
  68        LIBBPF_OPTS(bpf_map_create_opts, opts);
  69
  70        inner_map_type = inner_type;
  71
  72        /* Creating reuseport_array */
  73        reuseport_array = bpf_map_create(inner_type, "reuseport_array",
  74                                         sizeof(__u32), sizeof(__u32), REUSEPORT_ARRAY_SIZE, NULL);
  75        RET_ERR(reuseport_array < 0, "creating reuseport_array",
  76                "reuseport_array:%d errno:%d\n", reuseport_array, errno);
  77
  78        /* Creating outer_map */
  79        opts.inner_map_fd = reuseport_array;
  80        outer_map = bpf_map_create(BPF_MAP_TYPE_ARRAY_OF_MAPS, "outer_map",
  81                                   sizeof(__u32), sizeof(__u32), 1, &opts);
  82        RET_ERR(outer_map < 0, "creating outer_map",
  83                "outer_map:%d errno:%d\n", outer_map, errno);
  84
  85        return 0;
  86}
  87
  88static int prepare_bpf_obj(void)
  89{
  90        struct bpf_program *prog;
  91        struct bpf_map *map;
  92        int err;
  93
  94        obj = bpf_object__open("test_select_reuseport_kern.o");
  95        err = libbpf_get_error(obj);
  96        RET_ERR(err, "open test_select_reuseport_kern.o",
  97                "obj:%p PTR_ERR(obj):%d\n", obj, err);
  98
  99        map = bpf_object__find_map_by_name(obj, "outer_map");
 100        RET_ERR(!map, "find outer_map", "!map\n");
 101        err = bpf_map__reuse_fd(map, outer_map);
 102        RET_ERR(err, "reuse outer_map", "err:%d\n", err);
 103
 104        err = bpf_object__load(obj);
 105        RET_ERR(err, "load bpf_object", "err:%d\n", err);
 106
 107        prog = bpf_object__next_program(obj, NULL);
 108        RET_ERR(!prog, "get first bpf_program", "!prog\n");
 109        select_by_skb_data_prog = bpf_program__fd(prog);
 110        RET_ERR(select_by_skb_data_prog < 0, "get prog fd",
 111                "select_by_skb_data_prog:%d\n", select_by_skb_data_prog);
 112
 113        map = bpf_object__find_map_by_name(obj, "result_map");
 114        RET_ERR(!map, "find result_map", "!map\n");
 115        result_map = bpf_map__fd(map);
 116        RET_ERR(result_map < 0, "get result_map fd",
 117                "result_map:%d\n", result_map);
 118
 119        map = bpf_object__find_map_by_name(obj, "tmp_index_ovr_map");
 120        RET_ERR(!map, "find tmp_index_ovr_map\n", "!map");
 121        tmp_index_ovr_map = bpf_map__fd(map);
 122        RET_ERR(tmp_index_ovr_map < 0, "get tmp_index_ovr_map fd",
 123                "tmp_index_ovr_map:%d\n", tmp_index_ovr_map);
 124
 125        map = bpf_object__find_map_by_name(obj, "linum_map");
 126        RET_ERR(!map, "find linum_map", "!map\n");
 127        linum_map = bpf_map__fd(map);
 128        RET_ERR(linum_map < 0, "get linum_map fd",
 129                "linum_map:%d\n", linum_map);
 130
 131        map = bpf_object__find_map_by_name(obj, "data_check_map");
 132        RET_ERR(!map, "find data_check_map", "!map\n");
 133        data_check_map = bpf_map__fd(map);
 134        RET_ERR(data_check_map < 0, "get data_check_map fd",
 135                "data_check_map:%d\n", data_check_map);
 136
 137        return 0;
 138}
 139
 140static void sa46_init_loopback(union sa46 *sa, sa_family_t family)
 141{
 142        memset(sa, 0, sizeof(*sa));
 143        sa->family = family;
 144        if (sa->family == AF_INET6)
 145                sa->v6.sin6_addr = in6addr_loopback;
 146        else
 147                sa->v4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
 148}
 149
 150static void sa46_init_inany(union sa46 *sa, sa_family_t family)
 151{
 152        memset(sa, 0, sizeof(*sa));
 153        sa->family = family;
 154        if (sa->family == AF_INET6)
 155                sa->v6.sin6_addr = in6addr_any;
 156        else
 157                sa->v4.sin_addr.s_addr = INADDR_ANY;
 158}
 159
 160static int read_int_sysctl(const char *sysctl)
 161{
 162        char buf[16];
 163        int fd, ret;
 164
 165        fd = open(sysctl, 0);
 166        RET_ERR(fd == -1, "open(sysctl)",
 167                "sysctl:%s fd:%d errno:%d\n", sysctl, fd, errno);
 168
 169        ret = read(fd, buf, sizeof(buf));
 170        RET_ERR(ret <= 0, "read(sysctl)",
 171                "sysctl:%s ret:%d errno:%d\n", sysctl, ret, errno);
 172
 173        close(fd);
 174        return atoi(buf);
 175}
 176
 177static int write_int_sysctl(const char *sysctl, int v)
 178{
 179        int fd, ret, size;
 180        char buf[16];
 181
 182        fd = open(sysctl, O_RDWR);
 183        RET_ERR(fd == -1, "open(sysctl)",
 184                "sysctl:%s fd:%d errno:%d\n", sysctl, fd, errno);
 185
 186        size = snprintf(buf, sizeof(buf), "%d", v);
 187        ret = write(fd, buf, size);
 188        RET_ERR(ret != size, "write(sysctl)",
 189                "sysctl:%s ret:%d size:%d errno:%d\n",
 190                sysctl, ret, size, errno);
 191
 192        close(fd);
 193        return 0;
 194}
 195
 196static void restore_sysctls(void)
 197{
 198        if (saved_tcp_fo != -1)
 199                write_int_sysctl(TCP_FO_SYSCTL, saved_tcp_fo);
 200        if (saved_tcp_syncookie != -1)
 201                write_int_sysctl(TCP_SYNCOOKIE_SYSCTL, saved_tcp_syncookie);
 202}
 203
 204static int enable_fastopen(void)
 205{
 206        int fo;
 207
 208        fo = read_int_sysctl(TCP_FO_SYSCTL);
 209        if (fo < 0)
 210                return -1;
 211
 212        return write_int_sysctl(TCP_FO_SYSCTL, fo | 7);
 213}
 214
 215static int enable_syncookie(void)
 216{
 217        return write_int_sysctl(TCP_SYNCOOKIE_SYSCTL, 2);
 218}
 219
 220static int disable_syncookie(void)
 221{
 222        return write_int_sysctl(TCP_SYNCOOKIE_SYSCTL, 0);
 223}
 224
 225static long get_linum(void)
 226{
 227        __u32 linum;
 228        int err;
 229
 230        err = bpf_map_lookup_elem(linum_map, &index_zero, &linum);
 231        RET_ERR(err < 0, "lookup_elem(linum_map)", "err:%d errno:%d\n",
 232                err, errno);
 233
 234        return linum;
 235}
 236
 237static void check_data(int type, sa_family_t family, const struct cmd *cmd,
 238                       int cli_fd)
 239{
 240        struct data_check expected = {}, result;
 241        union sa46 cli_sa;
 242        socklen_t addrlen;
 243        int err;
 244
 245        addrlen = sizeof(cli_sa);
 246        err = getsockname(cli_fd, (struct sockaddr *)&cli_sa,
 247                          &addrlen);
 248        RET_IF(err < 0, "getsockname(cli_fd)", "err:%d errno:%d\n",
 249               err, errno);
 250
 251        err = bpf_map_lookup_elem(data_check_map, &index_zero, &result);
 252        RET_IF(err < 0, "lookup_elem(data_check_map)", "err:%d errno:%d\n",
 253               err, errno);
 254
 255        if (type == SOCK_STREAM) {
 256                expected.len = MIN_TCPHDR_LEN;
 257                expected.ip_protocol = IPPROTO_TCP;
 258        } else {
 259                expected.len = UDPHDR_LEN;
 260                expected.ip_protocol = IPPROTO_UDP;
 261        }
 262
 263        if (family == AF_INET6) {
 264                expected.eth_protocol = htons(ETH_P_IPV6);
 265                expected.bind_inany = !srv_sa.v6.sin6_addr.s6_addr32[3] &&
 266                        !srv_sa.v6.sin6_addr.s6_addr32[2] &&
 267                        !srv_sa.v6.sin6_addr.s6_addr32[1] &&
 268                        !srv_sa.v6.sin6_addr.s6_addr32[0];
 269
 270                memcpy(&expected.skb_addrs[0], cli_sa.v6.sin6_addr.s6_addr32,
 271                       sizeof(cli_sa.v6.sin6_addr));
 272                memcpy(&expected.skb_addrs[4], &in6addr_loopback,
 273                       sizeof(in6addr_loopback));
 274                expected.skb_ports[0] = cli_sa.v6.sin6_port;
 275                expected.skb_ports[1] = srv_sa.v6.sin6_port;
 276        } else {
 277                expected.eth_protocol = htons(ETH_P_IP);
 278                expected.bind_inany = !srv_sa.v4.sin_addr.s_addr;
 279
 280                expected.skb_addrs[0] = cli_sa.v4.sin_addr.s_addr;
 281                expected.skb_addrs[1] = htonl(INADDR_LOOPBACK);
 282                expected.skb_ports[0] = cli_sa.v4.sin_port;
 283                expected.skb_ports[1] = srv_sa.v4.sin_port;
 284        }
 285
 286        if (memcmp(&result, &expected, offsetof(struct data_check,
 287                                                equal_check_end))) {
 288                printf("unexpected data_check\n");
 289                printf("  result: (0x%x, %u, %u)\n",
 290                       result.eth_protocol, result.ip_protocol,
 291                       result.bind_inany);
 292                printf("expected: (0x%x, %u, %u)\n",
 293                       expected.eth_protocol, expected.ip_protocol,
 294                       expected.bind_inany);
 295                RET_IF(1, "data_check result != expected",
 296                       "bpf_prog_linum:%ld\n", get_linum());
 297        }
 298
 299        RET_IF(!result.hash, "data_check result.hash empty",
 300               "result.hash:%u", result.hash);
 301
 302        expected.len += cmd ? sizeof(*cmd) : 0;
 303        if (type == SOCK_STREAM)
 304                RET_IF(expected.len > result.len, "expected.len > result.len",
 305                       "expected.len:%u result.len:%u bpf_prog_linum:%ld\n",
 306                       expected.len, result.len, get_linum());
 307        else
 308                RET_IF(expected.len != result.len, "expected.len != result.len",
 309                       "expected.len:%u result.len:%u bpf_prog_linum:%ld\n",
 310                       expected.len, result.len, get_linum());
 311}
 312
 313static const char *result_to_str(enum result res)
 314{
 315        switch (res) {
 316        case DROP_ERR_INNER_MAP:
 317                return "DROP_ERR_INNER_MAP";
 318        case DROP_ERR_SKB_DATA:
 319                return "DROP_ERR_SKB_DATA";
 320        case DROP_ERR_SK_SELECT_REUSEPORT:
 321                return "DROP_ERR_SK_SELECT_REUSEPORT";
 322        case DROP_MISC:
 323                return "DROP_MISC";
 324        case PASS:
 325                return "PASS";
 326        case PASS_ERR_SK_SELECT_REUSEPORT:
 327                return "PASS_ERR_SK_SELECT_REUSEPORT";
 328        default:
 329                return "UNKNOWN";
 330        }
 331}
 332
 333static void check_results(void)
 334{
 335        __u32 results[NR_RESULTS];
 336        __u32 i, broken = 0;
 337        int err;
 338
 339        for (i = 0; i < NR_RESULTS; i++) {
 340                err = bpf_map_lookup_elem(result_map, &i, &results[i]);
 341                RET_IF(err < 0, "lookup_elem(result_map)",
 342                       "i:%u err:%d errno:%d\n", i, err, errno);
 343        }
 344
 345        for (i = 0; i < NR_RESULTS; i++) {
 346                if (results[i] != expected_results[i]) {
 347                        broken = i;
 348                        break;
 349                }
 350        }
 351
 352        if (i == NR_RESULTS)
 353                return;
 354
 355        printf("unexpected result\n");
 356        printf(" result: [");
 357        printf("%u", results[0]);
 358        for (i = 1; i < NR_RESULTS; i++)
 359                printf(", %u", results[i]);
 360        printf("]\n");
 361
 362        printf("expected: [");
 363        printf("%u", expected_results[0]);
 364        for (i = 1; i < NR_RESULTS; i++)
 365                printf(", %u", expected_results[i]);
 366        printf("]\n");
 367
 368        printf("mismatch on %s (bpf_prog_linum:%ld)\n", result_to_str(broken),
 369               get_linum());
 370
 371        CHECK_FAIL(true);
 372}
 373
 374static int send_data(int type, sa_family_t family, void *data, size_t len,
 375                     enum result expected)
 376{
 377        union sa46 cli_sa;
 378        int fd, err;
 379
 380        fd = socket(family, type, 0);
 381        RET_ERR(fd == -1, "socket()", "fd:%d errno:%d\n", fd, errno);
 382
 383        sa46_init_loopback(&cli_sa, family);
 384        err = bind(fd, (struct sockaddr *)&cli_sa, sizeof(cli_sa));
 385        RET_ERR(fd == -1, "bind(cli_sa)", "err:%d errno:%d\n", err, errno);
 386
 387        err = sendto(fd, data, len, MSG_FASTOPEN, (struct sockaddr *)&srv_sa,
 388                     sizeof(srv_sa));
 389        RET_ERR(err != len && expected >= PASS,
 390                "sendto()", "family:%u err:%d errno:%d expected:%d\n",
 391                family, err, errno, expected);
 392
 393        return fd;
 394}
 395
 396static void do_test(int type, sa_family_t family, struct cmd *cmd,
 397                    enum result expected)
 398{
 399        int nev, srv_fd, cli_fd;
 400        struct epoll_event ev;
 401        struct cmd rcv_cmd;
 402        ssize_t nread;
 403
 404        cli_fd = send_data(type, family, cmd, cmd ? sizeof(*cmd) : 0,
 405                           expected);
 406        if (cli_fd < 0)
 407                return;
 408        nev = epoll_wait(epfd, &ev, 1, expected >= PASS ? 5 : 0);
 409        RET_IF((nev <= 0 && expected >= PASS) ||
 410               (nev > 0 && expected < PASS),
 411               "nev <> expected",
 412               "nev:%d expected:%d type:%d family:%d data:(%d, %d)\n",
 413               nev, expected, type, family,
 414               cmd ? cmd->reuseport_index : -1,
 415               cmd ? cmd->pass_on_failure : -1);
 416        check_results();
 417        check_data(type, family, cmd, cli_fd);
 418
 419        if (expected < PASS)
 420                return;
 421
 422        RET_IF(expected != PASS_ERR_SK_SELECT_REUSEPORT &&
 423               cmd->reuseport_index != ev.data.u32,
 424               "check cmd->reuseport_index",
 425               "cmd:(%u, %u) ev.data.u32:%u\n",
 426               cmd->pass_on_failure, cmd->reuseport_index, ev.data.u32);
 427
 428        srv_fd = sk_fds[ev.data.u32];
 429        if (type == SOCK_STREAM) {
 430                int new_fd = accept(srv_fd, NULL, 0);
 431
 432                RET_IF(new_fd == -1, "accept(srv_fd)",
 433                       "ev.data.u32:%u new_fd:%d errno:%d\n",
 434                       ev.data.u32, new_fd, errno);
 435
 436                nread = recv(new_fd, &rcv_cmd, sizeof(rcv_cmd), MSG_DONTWAIT);
 437                RET_IF(nread != sizeof(rcv_cmd),
 438                       "recv(new_fd)",
 439                       "ev.data.u32:%u nread:%zd sizeof(rcv_cmd):%zu errno:%d\n",
 440                       ev.data.u32, nread, sizeof(rcv_cmd), errno);
 441
 442                close(new_fd);
 443        } else {
 444                nread = recv(srv_fd, &rcv_cmd, sizeof(rcv_cmd), MSG_DONTWAIT);
 445                RET_IF(nread != sizeof(rcv_cmd),
 446                       "recv(sk_fds)",
 447                       "ev.data.u32:%u nread:%zd sizeof(rcv_cmd):%zu errno:%d\n",
 448                       ev.data.u32, nread, sizeof(rcv_cmd), errno);
 449        }
 450
 451        close(cli_fd);
 452}
 453
 454static void test_err_inner_map(int type, sa_family_t family)
 455{
 456        struct cmd cmd = {
 457                .reuseport_index = 0,
 458                .pass_on_failure = 0,
 459        };
 460
 461        expected_results[DROP_ERR_INNER_MAP]++;
 462        do_test(type, family, &cmd, DROP_ERR_INNER_MAP);
 463}
 464
 465static void test_err_skb_data(int type, sa_family_t family)
 466{
 467        expected_results[DROP_ERR_SKB_DATA]++;
 468        do_test(type, family, NULL, DROP_ERR_SKB_DATA);
 469}
 470
 471static void test_err_sk_select_port(int type, sa_family_t family)
 472{
 473        struct cmd cmd = {
 474                .reuseport_index = REUSEPORT_ARRAY_SIZE,
 475                .pass_on_failure = 0,
 476        };
 477
 478        expected_results[DROP_ERR_SK_SELECT_REUSEPORT]++;
 479        do_test(type, family, &cmd, DROP_ERR_SK_SELECT_REUSEPORT);
 480}
 481
 482static void test_pass(int type, sa_family_t family)
 483{
 484        struct cmd cmd;
 485        int i;
 486
 487        cmd.pass_on_failure = 0;
 488        for (i = 0; i < REUSEPORT_ARRAY_SIZE; i++) {
 489                expected_results[PASS]++;
 490                cmd.reuseport_index = i;
 491                do_test(type, family, &cmd, PASS);
 492        }
 493}
 494
 495static void test_syncookie(int type, sa_family_t family)
 496{
 497        int err, tmp_index = 1;
 498        struct cmd cmd = {
 499                .reuseport_index = 0,
 500                .pass_on_failure = 0,
 501        };
 502
 503        /*
 504         * +1 for TCP-SYN and
 505         * +1 for the TCP-ACK (ack the syncookie)
 506         */
 507        expected_results[PASS] += 2;
 508        enable_syncookie();
 509        /*
 510         * Simulate TCP-SYN and TCP-ACK are handled by two different sk:
 511         * TCP-SYN: select sk_fds[tmp_index = 1] tmp_index is from the
 512         *          tmp_index_ovr_map
 513         * TCP-ACK: select sk_fds[reuseport_index = 0] reuseport_index
 514         *          is from the cmd.reuseport_index
 515         */
 516        err = bpf_map_update_elem(tmp_index_ovr_map, &index_zero,
 517                                  &tmp_index, BPF_ANY);
 518        RET_IF(err < 0, "update_elem(tmp_index_ovr_map, 0, 1)",
 519               "err:%d errno:%d\n", err, errno);
 520        do_test(type, family, &cmd, PASS);
 521        err = bpf_map_lookup_elem(tmp_index_ovr_map, &index_zero,
 522                                  &tmp_index);
 523        RET_IF(err < 0 || tmp_index >= 0,
 524               "lookup_elem(tmp_index_ovr_map)",
 525               "err:%d errno:%d tmp_index:%d\n",
 526               err, errno, tmp_index);
 527        disable_syncookie();
 528}
 529
 530static void test_pass_on_err(int type, sa_family_t family)
 531{
 532        struct cmd cmd = {
 533                .reuseport_index = REUSEPORT_ARRAY_SIZE,
 534                .pass_on_failure = 1,
 535        };
 536
 537        expected_results[PASS_ERR_SK_SELECT_REUSEPORT] += 1;
 538        do_test(type, family, &cmd, PASS_ERR_SK_SELECT_REUSEPORT);
 539}
 540
 541static void test_detach_bpf(int type, sa_family_t family)
 542{
 543#ifdef SO_DETACH_REUSEPORT_BPF
 544        __u32 nr_run_before = 0, nr_run_after = 0, tmp, i;
 545        struct epoll_event ev;
 546        int cli_fd, err, nev;
 547        struct cmd cmd = {};
 548        int optvalue = 0;
 549
 550        err = setsockopt(sk_fds[0], SOL_SOCKET, SO_DETACH_REUSEPORT_BPF,
 551                         &optvalue, sizeof(optvalue));
 552        RET_IF(err == -1, "setsockopt(SO_DETACH_REUSEPORT_BPF)",
 553               "err:%d errno:%d\n", err, errno);
 554
 555        err = setsockopt(sk_fds[1], SOL_SOCKET, SO_DETACH_REUSEPORT_BPF,
 556                         &optvalue, sizeof(optvalue));
 557        RET_IF(err == 0 || errno != ENOENT,
 558               "setsockopt(SO_DETACH_REUSEPORT_BPF)",
 559               "err:%d errno:%d\n", err, errno);
 560
 561        for (i = 0; i < NR_RESULTS; i++) {
 562                err = bpf_map_lookup_elem(result_map, &i, &tmp);
 563                RET_IF(err < 0, "lookup_elem(result_map)",
 564                       "i:%u err:%d errno:%d\n", i, err, errno);
 565                nr_run_before += tmp;
 566        }
 567
 568        cli_fd = send_data(type, family, &cmd, sizeof(cmd), PASS);
 569        if (cli_fd < 0)
 570                return;
 571        nev = epoll_wait(epfd, &ev, 1, 5);
 572        RET_IF(nev <= 0, "nev <= 0",
 573               "nev:%d expected:1 type:%d family:%d data:(0, 0)\n",
 574               nev,  type, family);
 575
 576        for (i = 0; i < NR_RESULTS; i++) {
 577                err = bpf_map_lookup_elem(result_map, &i, &tmp);
 578                RET_IF(err < 0, "lookup_elem(result_map)",
 579                       "i:%u err:%d errno:%d\n", i, err, errno);
 580                nr_run_after += tmp;
 581        }
 582
 583        RET_IF(nr_run_before != nr_run_after,
 584               "nr_run_before != nr_run_after",
 585               "nr_run_before:%u nr_run_after:%u\n",
 586               nr_run_before, nr_run_after);
 587
 588        close(cli_fd);
 589#else
 590        test__skip();
 591#endif
 592}
 593
 594static void prepare_sk_fds(int type, sa_family_t family, bool inany)
 595{
 596        const int first = REUSEPORT_ARRAY_SIZE - 1;
 597        int i, err, optval = 1;
 598        struct epoll_event ev;
 599        socklen_t addrlen;
 600
 601        if (inany)
 602                sa46_init_inany(&srv_sa, family);
 603        else
 604                sa46_init_loopback(&srv_sa, family);
 605        addrlen = sizeof(srv_sa);
 606
 607        /*
 608         * The sk_fds[] is filled from the back such that the order
 609         * is exactly opposite to the (struct sock_reuseport *)reuse->socks[].
 610         */
 611        for (i = first; i >= 0; i--) {
 612                sk_fds[i] = socket(family, type, 0);
 613                RET_IF(sk_fds[i] == -1, "socket()", "sk_fds[%d]:%d errno:%d\n",
 614                       i, sk_fds[i], errno);
 615                err = setsockopt(sk_fds[i], SOL_SOCKET, SO_REUSEPORT,
 616                                 &optval, sizeof(optval));
 617                RET_IF(err == -1, "setsockopt(SO_REUSEPORT)",
 618                       "sk_fds[%d] err:%d errno:%d\n",
 619                       i, err, errno);
 620
 621                if (i == first) {
 622                        err = setsockopt(sk_fds[i], SOL_SOCKET,
 623                                         SO_ATTACH_REUSEPORT_EBPF,
 624                                         &select_by_skb_data_prog,
 625                                         sizeof(select_by_skb_data_prog));
 626                        RET_IF(err < 0, "setsockopt(SO_ATTACH_REUEPORT_EBPF)",
 627                               "err:%d errno:%d\n", err, errno);
 628                }
 629
 630                err = bind(sk_fds[i], (struct sockaddr *)&srv_sa, addrlen);
 631                RET_IF(err < 0, "bind()", "sk_fds[%d] err:%d errno:%d\n",
 632                       i, err, errno);
 633
 634                if (type == SOCK_STREAM) {
 635                        err = listen(sk_fds[i], 10);
 636                        RET_IF(err < 0, "listen()",
 637                               "sk_fds[%d] err:%d errno:%d\n",
 638                               i, err, errno);
 639                }
 640
 641                err = bpf_map_update_elem(reuseport_array, &i, &sk_fds[i],
 642                                          BPF_NOEXIST);
 643                RET_IF(err < 0, "update_elem(reuseport_array)",
 644                       "sk_fds[%d] err:%d errno:%d\n", i, err, errno);
 645
 646                if (i == first) {
 647                        socklen_t addrlen = sizeof(srv_sa);
 648
 649                        err = getsockname(sk_fds[i], (struct sockaddr *)&srv_sa,
 650                                          &addrlen);
 651                        RET_IF(err == -1, "getsockname()",
 652                               "sk_fds[%d] err:%d errno:%d\n", i, err, errno);
 653                }
 654        }
 655
 656        epfd = epoll_create(1);
 657        RET_IF(epfd == -1, "epoll_create(1)",
 658               "epfd:%d errno:%d\n", epfd, errno);
 659
 660        ev.events = EPOLLIN;
 661        for (i = 0; i < REUSEPORT_ARRAY_SIZE; i++) {
 662                ev.data.u32 = i;
 663                err = epoll_ctl(epfd, EPOLL_CTL_ADD, sk_fds[i], &ev);
 664                RET_IF(err, "epoll_ctl(EPOLL_CTL_ADD)", "sk_fds[%d]\n", i);
 665        }
 666}
 667
 668static void setup_per_test(int type, sa_family_t family, bool inany,
 669                           bool no_inner_map)
 670{
 671        int ovr = -1, err;
 672
 673        prepare_sk_fds(type, family, inany);
 674        err = bpf_map_update_elem(tmp_index_ovr_map, &index_zero, &ovr,
 675                                  BPF_ANY);
 676        RET_IF(err < 0, "update_elem(tmp_index_ovr_map, 0, -1)",
 677               "err:%d errno:%d\n", err, errno);
 678
 679        /* Install reuseport_array to outer_map? */
 680        if (no_inner_map)
 681                return;
 682
 683        err = bpf_map_update_elem(outer_map, &index_zero, &reuseport_array,
 684                                  BPF_ANY);
 685        RET_IF(err < 0, "update_elem(outer_map, 0, reuseport_array)",
 686               "err:%d errno:%d\n", err, errno);
 687}
 688
 689static void cleanup_per_test(bool no_inner_map)
 690{
 691        int i, err, zero = 0;
 692
 693        memset(expected_results, 0, sizeof(expected_results));
 694
 695        for (i = 0; i < NR_RESULTS; i++) {
 696                err = bpf_map_update_elem(result_map, &i, &zero, BPF_ANY);
 697                RET_IF(err, "reset elem in result_map",
 698                       "i:%u err:%d errno:%d\n", i, err, errno);
 699        }
 700
 701        err = bpf_map_update_elem(linum_map, &zero, &zero, BPF_ANY);
 702        RET_IF(err, "reset line number in linum_map", "err:%d errno:%d\n",
 703               err, errno);
 704
 705        for (i = 0; i < REUSEPORT_ARRAY_SIZE; i++)
 706                close(sk_fds[i]);
 707        close(epfd);
 708
 709        /* Delete reuseport_array from outer_map? */
 710        if (no_inner_map)
 711                return;
 712
 713        err = bpf_map_delete_elem(outer_map, &index_zero);
 714        RET_IF(err < 0, "delete_elem(outer_map)",
 715               "err:%d errno:%d\n", err, errno);
 716}
 717
 718static void cleanup(void)
 719{
 720        if (outer_map >= 0) {
 721                close(outer_map);
 722                outer_map = -1;
 723        }
 724
 725        if (reuseport_array >= 0) {
 726                close(reuseport_array);
 727                reuseport_array = -1;
 728        }
 729
 730        if (obj) {
 731                bpf_object__close(obj);
 732                obj = NULL;
 733        }
 734
 735        memset(expected_results, 0, sizeof(expected_results));
 736}
 737
 738static const char *maptype_str(enum bpf_map_type type)
 739{
 740        switch (type) {
 741        case BPF_MAP_TYPE_REUSEPORT_SOCKARRAY:
 742                return "reuseport_sockarray";
 743        case BPF_MAP_TYPE_SOCKMAP:
 744                return "sockmap";
 745        case BPF_MAP_TYPE_SOCKHASH:
 746                return "sockhash";
 747        default:
 748                return "unknown";
 749        }
 750}
 751
 752static const char *family_str(sa_family_t family)
 753{
 754        switch (family) {
 755        case AF_INET:
 756                return "IPv4";
 757        case AF_INET6:
 758                return "IPv6";
 759        default:
 760                return "unknown";
 761        }
 762}
 763
 764static const char *sotype_str(int sotype)
 765{
 766        switch (sotype) {
 767        case SOCK_STREAM:
 768                return "TCP";
 769        case SOCK_DGRAM:
 770                return "UDP";
 771        default:
 772                return "unknown";
 773        }
 774}
 775
 776#define TEST_INIT(fn_, ...) { .fn = fn_, .name = #fn_, __VA_ARGS__ }
 777
 778static void test_config(int sotype, sa_family_t family, bool inany)
 779{
 780        const struct test {
 781                void (*fn)(int sotype, sa_family_t family);
 782                const char *name;
 783                bool no_inner_map;
 784                int need_sotype;
 785        } tests[] = {
 786                TEST_INIT(test_err_inner_map,
 787                          .no_inner_map = true),
 788                TEST_INIT(test_err_skb_data),
 789                TEST_INIT(test_err_sk_select_port),
 790                TEST_INIT(test_pass),
 791                TEST_INIT(test_syncookie,
 792                          .need_sotype = SOCK_STREAM),
 793                TEST_INIT(test_pass_on_err),
 794                TEST_INIT(test_detach_bpf),
 795        };
 796        char s[MAX_TEST_NAME];
 797        const struct test *t;
 798
 799        for (t = tests; t < tests + ARRAY_SIZE(tests); t++) {
 800                if (t->need_sotype && t->need_sotype != sotype)
 801                        continue; /* test not compatible with socket type */
 802
 803                snprintf(s, sizeof(s), "%s %s/%s %s %s",
 804                         maptype_str(inner_map_type),
 805                         family_str(family), sotype_str(sotype),
 806                         inany ? "INANY" : "LOOPBACK", t->name);
 807
 808                if (!test__start_subtest(s))
 809                        continue;
 810
 811                setup_per_test(sotype, family, inany, t->no_inner_map);
 812                t->fn(sotype, family);
 813                cleanup_per_test(t->no_inner_map);
 814        }
 815}
 816
 817#define BIND_INANY true
 818
 819static void test_all(void)
 820{
 821        const struct config {
 822                int sotype;
 823                sa_family_t family;
 824                bool inany;
 825        } configs[] = {
 826                { SOCK_STREAM, AF_INET },
 827                { SOCK_STREAM, AF_INET, BIND_INANY },
 828                { SOCK_STREAM, AF_INET6 },
 829                { SOCK_STREAM, AF_INET6, BIND_INANY },
 830                { SOCK_DGRAM, AF_INET },
 831                { SOCK_DGRAM, AF_INET6 },
 832        };
 833        const struct config *c;
 834
 835        for (c = configs; c < configs + ARRAY_SIZE(configs); c++)
 836                test_config(c->sotype, c->family, c->inany);
 837}
 838
 839void test_map_type(enum bpf_map_type mt)
 840{
 841        if (create_maps(mt))
 842                goto out;
 843        if (prepare_bpf_obj())
 844                goto out;
 845
 846        test_all();
 847out:
 848        cleanup();
 849}
 850
 851void serial_test_select_reuseport(void)
 852{
 853        saved_tcp_fo = read_int_sysctl(TCP_FO_SYSCTL);
 854        if (saved_tcp_fo < 0)
 855                goto out;
 856        saved_tcp_syncookie = read_int_sysctl(TCP_SYNCOOKIE_SYSCTL);
 857        if (saved_tcp_syncookie < 0)
 858                goto out;
 859
 860        if (enable_fastopen())
 861                goto out;
 862        if (disable_syncookie())
 863                goto out;
 864
 865        test_map_type(BPF_MAP_TYPE_REUSEPORT_SOCKARRAY);
 866        test_map_type(BPF_MAP_TYPE_SOCKMAP);
 867        test_map_type(BPF_MAP_TYPE_SOCKHASH);
 868out:
 869        restore_sysctls();
 870}
 871