/* linux/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c */
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2019 Facebook */
   3
   4#include <linux/err.h>
   5#include <test_progs.h>
   6#include "bpf_dctcp.skel.h"
   7#include "bpf_cubic.skel.h"
   8
   9#define min(a, b) ((a) < (b) ? (a) : (b))
  10
  11static const unsigned int total_bytes = 10 * 1024 * 1024;
  12static const struct timeval timeo_sec = { .tv_sec = 10 };
  13static const size_t timeo_optlen = sizeof(timeo_sec);
  14static int expected_stg = 0xeB9F;
  15static int stop, duration;
  16
  17static int settimeo(int fd)
  18{
  19        int err;
  20
  21        err = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec,
  22                         timeo_optlen);
  23        if (CHECK(err == -1, "setsockopt(fd, SO_RCVTIMEO)", "errno:%d\n",
  24                  errno))
  25                return -1;
  26
  27        err = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeo_sec,
  28                         timeo_optlen);
  29        if (CHECK(err == -1, "setsockopt(fd, SO_SNDTIMEO)", "errno:%d\n",
  30                  errno))
  31                return -1;
  32
  33        return 0;
  34}
  35
  36static int settcpca(int fd, const char *tcp_ca)
  37{
  38        int err;
  39
  40        err = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca, strlen(tcp_ca));
  41        if (CHECK(err == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n",
  42                  errno))
  43                return -1;
  44
  45        return 0;
  46}
  47
  48static void *server(void *arg)
  49{
  50        int lfd = (int)(long)arg, err = 0, fd;
  51        ssize_t nr_sent = 0, bytes = 0;
  52        char batch[1500];
  53
  54        fd = accept(lfd, NULL, NULL);
  55        while (fd == -1) {
  56                if (errno == EINTR)
  57                        continue;
  58                err = -errno;
  59                goto done;
  60        }
  61
  62        if (settimeo(fd)) {
  63                err = -errno;
  64                goto done;
  65        }
  66
  67        while (bytes < total_bytes && !READ_ONCE(stop)) {
  68                nr_sent = send(fd, &batch,
  69                               min(total_bytes - bytes, sizeof(batch)), 0);
  70                if (nr_sent == -1 && errno == EINTR)
  71                        continue;
  72                if (nr_sent == -1) {
  73                        err = -errno;
  74                        break;
  75                }
  76                bytes += nr_sent;
  77        }
  78
  79        CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n",
  80              bytes, total_bytes, nr_sent, errno);
  81
  82done:
  83        if (fd != -1)
  84                close(fd);
  85        if (err) {
  86                WRITE_ONCE(stop, 1);
  87                return ERR_PTR(err);
  88        }
  89        return NULL;
  90}
  91
  92static void do_test(const char *tcp_ca, const struct bpf_map *sk_stg_map)
  93{
  94        struct sockaddr_in6 sa6 = {};
  95        ssize_t nr_recv = 0, bytes = 0;
  96        int lfd = -1, fd = -1;
  97        pthread_t srv_thread;
  98        socklen_t addrlen = sizeof(sa6);
  99        void *thread_ret;
 100        char batch[1500];
 101        int err;
 102
 103        WRITE_ONCE(stop, 0);
 104
 105        lfd = socket(AF_INET6, SOCK_STREAM, 0);
 106        if (CHECK(lfd == -1, "socket", "errno:%d\n", errno))
 107                return;
 108        fd = socket(AF_INET6, SOCK_STREAM, 0);
 109        if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) {
 110                close(lfd);
 111                return;
 112        }
 113
 114        if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) ||
 115            settimeo(lfd) || settimeo(fd))
 116                goto done;
 117
 118        /* bind, listen and start server thread to accept */
 119        sa6.sin6_family = AF_INET6;
 120        sa6.sin6_addr = in6addr_loopback;
 121        err = bind(lfd, (struct sockaddr *)&sa6, addrlen);
 122        if (CHECK(err == -1, "bind", "errno:%d\n", errno))
 123                goto done;
 124        err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen);
 125        if (CHECK(err == -1, "getsockname", "errno:%d\n", errno))
 126                goto done;
 127        err = listen(lfd, 1);
 128        if (CHECK(err == -1, "listen", "errno:%d\n", errno))
 129                goto done;
 130
 131        if (sk_stg_map) {
 132                err = bpf_map_update_elem(bpf_map__fd(sk_stg_map), &fd,
 133                                          &expected_stg, BPF_NOEXIST);
 134                if (CHECK(err, "bpf_map_update_elem(sk_stg_map)",
 135                          "err:%d errno:%d\n", err, errno))
 136                        goto done;
 137        }
 138
 139        /* connect to server */
 140        err = connect(fd, (struct sockaddr *)&sa6, addrlen);
 141        if (CHECK(err == -1, "connect", "errno:%d\n", errno))
 142                goto done;
 143
 144        if (sk_stg_map) {
 145                int tmp_stg;
 146
 147                err = bpf_map_lookup_elem(bpf_map__fd(sk_stg_map), &fd,
 148                                          &tmp_stg);
 149                if (CHECK(!err || errno != ENOENT,
 150                          "bpf_map_lookup_elem(sk_stg_map)",
 151                          "err:%d errno:%d\n", err, errno))
 152                        goto done;
 153        }
 154
 155        err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd);
 156        if (CHECK(err != 0, "pthread_create", "err:%d errno:%d\n", err, errno))
 157                goto done;
 158
 159        /* recv total_bytes */
 160        while (bytes < total_bytes && !READ_ONCE(stop)) {
 161                nr_recv = recv(fd, &batch,
 162                               min(total_bytes - bytes, sizeof(batch)), 0);
 163                if (nr_recv == -1 && errno == EINTR)
 164                        continue;
 165                if (nr_recv == -1)
 166                        break;
 167                bytes += nr_recv;
 168        }
 169
 170        CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n",
 171              bytes, total_bytes, nr_recv, errno);
 172
 173        WRITE_ONCE(stop, 1);
 174        pthread_join(srv_thread, &thread_ret);
 175        CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld",
 176              PTR_ERR(thread_ret));
 177done:
 178        close(lfd);
 179        close(fd);
 180}
 181
 182static void test_cubic(void)
 183{
 184        struct bpf_cubic *cubic_skel;
 185        struct bpf_link *link;
 186
 187        cubic_skel = bpf_cubic__open_and_load();
 188        if (CHECK(!cubic_skel, "bpf_cubic__open_and_load", "failed\n"))
 189                return;
 190
 191        link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic);
 192        if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n",
 193                  PTR_ERR(link))) {
 194                bpf_cubic__destroy(cubic_skel);
 195                return;
 196        }
 197
 198        do_test("bpf_cubic", NULL);
 199
 200        bpf_link__destroy(link);
 201        bpf_cubic__destroy(cubic_skel);
 202}
 203
 204static void test_dctcp(void)
 205{
 206        struct bpf_dctcp *dctcp_skel;
 207        struct bpf_link *link;
 208
 209        dctcp_skel = bpf_dctcp__open_and_load();
 210        if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n"))
 211                return;
 212
 213        link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp);
 214        if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n",
 215                  PTR_ERR(link))) {
 216                bpf_dctcp__destroy(dctcp_skel);
 217                return;
 218        }
 219
 220        do_test("bpf_dctcp", dctcp_skel->maps.sk_stg_map);
 221        CHECK(dctcp_skel->bss->stg_result != expected_stg,
 222              "Unexpected stg_result", "stg_result (%x) != expected_stg (%x)\n",
 223              dctcp_skel->bss->stg_result, expected_stg);
 224
 225        bpf_link__destroy(link);
 226        bpf_dctcp__destroy(dctcp_skel);
 227}
 228
 229void test_bpf_tcp_ca(void)
 230{
 231        if (test__start_subtest("dctcp"))
 232                test_dctcp();
 233        if (test__start_subtest("cubic"))
 234                test_cubic();
 235}
 236