linux/tools/testing/selftests/bpf/prog_tests/sock_fields.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2019 Facebook */
   3
   4#include <netinet/in.h>
   5#include <arpa/inet.h>
   6#include <unistd.h>
   7#include <stdlib.h>
   8#include <string.h>
   9#include <errno.h>
  10
  11#include <bpf/bpf.h>
  12#include <bpf/libbpf.h>
  13#include <linux/compiler.h>
  14
  15#include "network_helpers.h"
  16#include "cgroup_helpers.h"
  17#include "test_progs.h"
  18#include "bpf_rlimit.h"
  19#include "test_sock_fields.skel.h"
  20
  21enum bpf_linum_array_idx {
  22        EGRESS_LINUM_IDX,
  23        INGRESS_LINUM_IDX,
  24        __NR_BPF_LINUM_ARRAY_IDX,
  25};
  26
  27struct bpf_spinlock_cnt {
  28        struct bpf_spin_lock lock;
  29        __u32 cnt;
  30};
  31
  32#define PARENT_CGROUP   "/test-bpf-sock-fields"
  33#define CHILD_CGROUP    "/test-bpf-sock-fields/child"
  34#define DATA "Hello BPF!"
  35#define DATA_LEN sizeof(DATA)
  36
  37static struct sockaddr_in6 srv_sa6, cli_sa6;
  38static int sk_pkt_out_cnt10_fd;
  39static struct test_sock_fields *skel;
  40static int sk_pkt_out_cnt_fd;
  41static __u64 parent_cg_id;
  42static __u64 child_cg_id;
  43static int linum_map_fd;
  44static __u32 duration;
  45
  46static __u32 egress_linum_idx = EGRESS_LINUM_IDX;
  47static __u32 ingress_linum_idx = INGRESS_LINUM_IDX;
  48
  49static void print_sk(const struct bpf_sock *sk, const char *prefix)
  50{
  51        char src_ip4[24], dst_ip4[24];
  52        char src_ip6[64], dst_ip6[64];
  53
  54        inet_ntop(AF_INET, &sk->src_ip4, src_ip4, sizeof(src_ip4));
  55        inet_ntop(AF_INET6, &sk->src_ip6, src_ip6, sizeof(src_ip6));
  56        inet_ntop(AF_INET, &sk->dst_ip4, dst_ip4, sizeof(dst_ip4));
  57        inet_ntop(AF_INET6, &sk->dst_ip6, dst_ip6, sizeof(dst_ip6));
  58
  59        printf("%s: state:%u bound_dev_if:%u family:%u type:%u protocol:%u mark:%u priority:%u "
  60               "src_ip4:%x(%s) src_ip6:%x:%x:%x:%x(%s) src_port:%u "
  61               "dst_ip4:%x(%s) dst_ip6:%x:%x:%x:%x(%s) dst_port:%u\n",
  62               prefix,
  63               sk->state, sk->bound_dev_if, sk->family, sk->type, sk->protocol,
  64               sk->mark, sk->priority,
  65               sk->src_ip4, src_ip4,
  66               sk->src_ip6[0], sk->src_ip6[1], sk->src_ip6[2], sk->src_ip6[3],
  67               src_ip6, sk->src_port,
  68               sk->dst_ip4, dst_ip4,
  69               sk->dst_ip6[0], sk->dst_ip6[1], sk->dst_ip6[2], sk->dst_ip6[3],
  70               dst_ip6, ntohs(sk->dst_port));
  71}
  72
  73static void print_tp(const struct bpf_tcp_sock *tp, const char *prefix)
  74{
  75        printf("%s: snd_cwnd:%u srtt_us:%u rtt_min:%u snd_ssthresh:%u rcv_nxt:%u "
  76               "snd_nxt:%u snd:una:%u mss_cache:%u ecn_flags:%u "
  77               "rate_delivered:%u rate_interval_us:%u packets_out:%u "
  78               "retrans_out:%u total_retrans:%u segs_in:%u data_segs_in:%u "
  79               "segs_out:%u data_segs_out:%u lost_out:%u sacked_out:%u "
  80               "bytes_received:%llu bytes_acked:%llu\n",
  81               prefix,
  82               tp->snd_cwnd, tp->srtt_us, tp->rtt_min, tp->snd_ssthresh,
  83               tp->rcv_nxt, tp->snd_nxt, tp->snd_una, tp->mss_cache,
  84               tp->ecn_flags, tp->rate_delivered, tp->rate_interval_us,
  85               tp->packets_out, tp->retrans_out, tp->total_retrans,
  86               tp->segs_in, tp->data_segs_in, tp->segs_out,
  87               tp->data_segs_out, tp->lost_out, tp->sacked_out,
  88               tp->bytes_received, tp->bytes_acked);
  89}
  90
  91static void check_result(void)
  92{
  93        struct bpf_tcp_sock srv_tp, cli_tp, listen_tp;
  94        struct bpf_sock srv_sk, cli_sk, listen_sk;
  95        __u32 ingress_linum, egress_linum;
  96        int err;
  97
  98        err = bpf_map_lookup_elem(linum_map_fd, &egress_linum_idx,
  99                                  &egress_linum);
 100        CHECK(err < 0, "bpf_map_lookup_elem(linum_map_fd)",
 101              "err:%d errno:%d\n", err, errno);
 102
 103        err = bpf_map_lookup_elem(linum_map_fd, &ingress_linum_idx,
 104                                  &ingress_linum);
 105        CHECK(err < 0, "bpf_map_lookup_elem(linum_map_fd)",
 106              "err:%d errno:%d\n", err, errno);
 107
 108        memcpy(&srv_sk, &skel->bss->srv_sk, sizeof(srv_sk));
 109        memcpy(&srv_tp, &skel->bss->srv_tp, sizeof(srv_tp));
 110        memcpy(&cli_sk, &skel->bss->cli_sk, sizeof(cli_sk));
 111        memcpy(&cli_tp, &skel->bss->cli_tp, sizeof(cli_tp));
 112        memcpy(&listen_sk, &skel->bss->listen_sk, sizeof(listen_sk));
 113        memcpy(&listen_tp, &skel->bss->listen_tp, sizeof(listen_tp));
 114
 115        print_sk(&listen_sk, "listen_sk");
 116        print_sk(&srv_sk, "srv_sk");
 117        print_sk(&cli_sk, "cli_sk");
 118        print_tp(&listen_tp, "listen_tp");
 119        print_tp(&srv_tp, "srv_tp");
 120        print_tp(&cli_tp, "cli_tp");
 121
 122        CHECK(listen_sk.state != 10 ||
 123              listen_sk.family != AF_INET6 ||
 124              listen_sk.protocol != IPPROTO_TCP ||
 125              memcmp(listen_sk.src_ip6, &in6addr_loopback,
 126                     sizeof(listen_sk.src_ip6)) ||
 127              listen_sk.dst_ip6[0] || listen_sk.dst_ip6[1] ||
 128              listen_sk.dst_ip6[2] || listen_sk.dst_ip6[3] ||
 129              listen_sk.src_port != ntohs(srv_sa6.sin6_port) ||
 130              listen_sk.dst_port,
 131              "listen_sk",
 132              "Unexpected. Check listen_sk output. ingress_linum:%u\n",
 133              ingress_linum);
 134
 135        CHECK(srv_sk.state == 10 ||
 136              !srv_sk.state ||
 137              srv_sk.family != AF_INET6 ||
 138              srv_sk.protocol != IPPROTO_TCP ||
 139              memcmp(srv_sk.src_ip6, &in6addr_loopback,
 140                     sizeof(srv_sk.src_ip6)) ||
 141              memcmp(srv_sk.dst_ip6, &in6addr_loopback,
 142                     sizeof(srv_sk.dst_ip6)) ||
 143              srv_sk.src_port != ntohs(srv_sa6.sin6_port) ||
 144              srv_sk.dst_port != cli_sa6.sin6_port,
 145              "srv_sk", "Unexpected. Check srv_sk output. egress_linum:%u\n",
 146              egress_linum);
 147
 148        CHECK(!skel->bss->lsndtime, "srv_tp", "Unexpected lsndtime:0\n");
 149
 150        CHECK(cli_sk.state == 10 ||
 151              !cli_sk.state ||
 152              cli_sk.family != AF_INET6 ||
 153              cli_sk.protocol != IPPROTO_TCP ||
 154              memcmp(cli_sk.src_ip6, &in6addr_loopback,
 155                     sizeof(cli_sk.src_ip6)) ||
 156              memcmp(cli_sk.dst_ip6, &in6addr_loopback,
 157                     sizeof(cli_sk.dst_ip6)) ||
 158              cli_sk.src_port != ntohs(cli_sa6.sin6_port) ||
 159              cli_sk.dst_port != srv_sa6.sin6_port,
 160              "cli_sk", "Unexpected. Check cli_sk output. egress_linum:%u\n",
 161              egress_linum);
 162
 163        CHECK(listen_tp.data_segs_out ||
 164              listen_tp.data_segs_in ||
 165              listen_tp.total_retrans ||
 166              listen_tp.bytes_acked,
 167              "listen_tp",
 168              "Unexpected. Check listen_tp output. ingress_linum:%u\n",
 169              ingress_linum);
 170
 171        CHECK(srv_tp.data_segs_out != 2 ||
 172              srv_tp.data_segs_in ||
 173              srv_tp.snd_cwnd != 10 ||
 174              srv_tp.total_retrans ||
 175              srv_tp.bytes_acked < 2 * DATA_LEN,
 176              "srv_tp", "Unexpected. Check srv_tp output. egress_linum:%u\n",
 177              egress_linum);
 178
 179        CHECK(cli_tp.data_segs_out ||
 180              cli_tp.data_segs_in != 2 ||
 181              cli_tp.snd_cwnd != 10 ||
 182              cli_tp.total_retrans ||
 183              cli_tp.bytes_received < 2 * DATA_LEN,
 184              "cli_tp", "Unexpected. Check cli_tp output. egress_linum:%u\n",
 185              egress_linum);
 186
 187        CHECK(skel->bss->parent_cg_id != parent_cg_id,
 188              "parent_cg_id", "%zu != %zu\n",
 189              (size_t)skel->bss->parent_cg_id, (size_t)parent_cg_id);
 190
 191        CHECK(skel->bss->child_cg_id != child_cg_id,
 192              "child_cg_id", "%zu != %zu\n",
 193               (size_t)skel->bss->child_cg_id, (size_t)child_cg_id);
 194}
 195
 196static void check_sk_pkt_out_cnt(int accept_fd, int cli_fd)
 197{
 198        struct bpf_spinlock_cnt pkt_out_cnt = {}, pkt_out_cnt10 = {};
 199        int err;
 200
 201        pkt_out_cnt.cnt = ~0;
 202        pkt_out_cnt10.cnt = ~0;
 203        err = bpf_map_lookup_elem(sk_pkt_out_cnt_fd, &accept_fd, &pkt_out_cnt);
 204        if (!err)
 205                err = bpf_map_lookup_elem(sk_pkt_out_cnt10_fd, &accept_fd,
 206                                          &pkt_out_cnt10);
 207
 208        /* The bpf prog only counts for fullsock and
 209         * passive connection did not become fullsock until 3WHS
 210         * had been finished, so the bpf prog only counted two data
 211         * packet out.
 212         */
 213        CHECK(err || pkt_out_cnt.cnt < 0xeB9F + 2 ||
 214              pkt_out_cnt10.cnt < 0xeB9F + 20,
 215              "bpf_map_lookup_elem(sk_pkt_out_cnt, &accept_fd)",
 216              "err:%d errno:%d pkt_out_cnt:%u pkt_out_cnt10:%u\n",
 217              err, errno, pkt_out_cnt.cnt, pkt_out_cnt10.cnt);
 218
 219        pkt_out_cnt.cnt = ~0;
 220        pkt_out_cnt10.cnt = ~0;
 221        err = bpf_map_lookup_elem(sk_pkt_out_cnt_fd, &cli_fd, &pkt_out_cnt);
 222        if (!err)
 223                err = bpf_map_lookup_elem(sk_pkt_out_cnt10_fd, &cli_fd,
 224                                          &pkt_out_cnt10);
 225        /* Active connection is fullsock from the beginning.
 226         * 1 SYN and 1 ACK during 3WHS
 227         * 2 Acks on data packet.
 228         *
 229         * The bpf_prog initialized it to 0xeB9F.
 230         */
 231        CHECK(err || pkt_out_cnt.cnt < 0xeB9F + 4 ||
 232              pkt_out_cnt10.cnt < 0xeB9F + 40,
 233              "bpf_map_lookup_elem(sk_pkt_out_cnt, &cli_fd)",
 234              "err:%d errno:%d pkt_out_cnt:%u pkt_out_cnt10:%u\n",
 235              err, errno, pkt_out_cnt.cnt, pkt_out_cnt10.cnt);
 236}
 237
 238static int init_sk_storage(int sk_fd, __u32 pkt_out_cnt)
 239{
 240        struct bpf_spinlock_cnt scnt = {};
 241        int err;
 242
 243        scnt.cnt = pkt_out_cnt;
 244        err = bpf_map_update_elem(sk_pkt_out_cnt_fd, &sk_fd, &scnt,
 245                                  BPF_NOEXIST);
 246        if (CHECK(err, "bpf_map_update_elem(sk_pkt_out_cnt_fd)",
 247                  "err:%d errno:%d\n", err, errno))
 248                return err;
 249
 250        err = bpf_map_update_elem(sk_pkt_out_cnt10_fd, &sk_fd, &scnt,
 251                                  BPF_NOEXIST);
 252        if (CHECK(err, "bpf_map_update_elem(sk_pkt_out_cnt10_fd)",
 253                  "err:%d errno:%d\n", err, errno))
 254                return err;
 255
 256        return 0;
 257}
 258
 259static void test(void)
 260{
 261        int listen_fd = -1, cli_fd = -1, accept_fd = -1, err, i;
 262        socklen_t addrlen = sizeof(struct sockaddr_in6);
 263        char buf[DATA_LEN];
 264
 265        /* Prepare listen_fd */
 266        listen_fd = start_server(AF_INET6, SOCK_STREAM, "::1", 0, 0);
 267        /* start_server() has logged the error details */
 268        if (CHECK_FAIL(listen_fd == -1))
 269                goto done;
 270
 271        err = getsockname(listen_fd, (struct sockaddr *)&srv_sa6, &addrlen);
 272        if (CHECK(err, "getsockname(listen_fd)", "err:%d errno:%d\n", err,
 273                  errno))
 274                goto done;
 275        memcpy(&skel->bss->srv_sa6, &srv_sa6, sizeof(srv_sa6));
 276
 277        cli_fd = connect_to_fd(listen_fd, 0);
 278        if (CHECK_FAIL(cli_fd == -1))
 279                goto done;
 280
 281        err = getsockname(cli_fd, (struct sockaddr *)&cli_sa6, &addrlen);
 282        if (CHECK(err, "getsockname(cli_fd)", "err:%d errno:%d\n",
 283                  err, errno))
 284                goto done;
 285
 286        accept_fd = accept(listen_fd, NULL, NULL);
 287        if (CHECK(accept_fd == -1, "accept(listen_fd)",
 288                  "accept_fd:%d errno:%d\n",
 289                  accept_fd, errno))
 290                goto done;
 291
 292        if (init_sk_storage(accept_fd, 0xeB9F))
 293                goto done;
 294
 295        for (i = 0; i < 2; i++) {
 296                /* Send some data from accept_fd to cli_fd.
 297                 * MSG_EOR to stop kernel from coalescing two pkts.
 298                 */
 299                err = send(accept_fd, DATA, DATA_LEN, MSG_EOR);
 300                if (CHECK(err != DATA_LEN, "send(accept_fd)",
 301                          "err:%d errno:%d\n", err, errno))
 302                        goto done;
 303
 304                err = recv(cli_fd, buf, DATA_LEN, 0);
 305                if (CHECK(err != DATA_LEN, "recv(cli_fd)", "err:%d errno:%d\n",
 306                          err, errno))
 307                        goto done;
 308        }
 309
 310        shutdown(cli_fd, SHUT_WR);
 311        err = recv(accept_fd, buf, 1, 0);
 312        if (CHECK(err, "recv(accept_fd) for fin", "err:%d errno:%d\n",
 313                  err, errno))
 314                goto done;
 315        shutdown(accept_fd, SHUT_WR);
 316        err = recv(cli_fd, buf, 1, 0);
 317        if (CHECK(err, "recv(cli_fd) for fin", "err:%d errno:%d\n",
 318                  err, errno))
 319                goto done;
 320        check_sk_pkt_out_cnt(accept_fd, cli_fd);
 321        check_result();
 322
 323done:
 324        if (accept_fd != -1)
 325                close(accept_fd);
 326        if (cli_fd != -1)
 327                close(cli_fd);
 328        if (listen_fd != -1)
 329                close(listen_fd);
 330}
 331
 332void test_sock_fields(void)
 333{
 334        struct bpf_link *egress_link = NULL, *ingress_link = NULL;
 335        int parent_cg_fd = -1, child_cg_fd = -1;
 336
 337        /* Create a cgroup, get fd, and join it */
 338        parent_cg_fd = test__join_cgroup(PARENT_CGROUP);
 339        if (CHECK_FAIL(parent_cg_fd < 0))
 340                return;
 341        parent_cg_id = get_cgroup_id(PARENT_CGROUP);
 342        if (CHECK_FAIL(!parent_cg_id))
 343                goto done;
 344
 345        child_cg_fd = test__join_cgroup(CHILD_CGROUP);
 346        if (CHECK_FAIL(child_cg_fd < 0))
 347                goto done;
 348        child_cg_id = get_cgroup_id(CHILD_CGROUP);
 349        if (CHECK_FAIL(!child_cg_id))
 350                goto done;
 351
 352        skel = test_sock_fields__open_and_load();
 353        if (CHECK(!skel, "test_sock_fields__open_and_load", "failed\n"))
 354                goto done;
 355
 356        egress_link = bpf_program__attach_cgroup(skel->progs.egress_read_sock_fields,
 357                                                 child_cg_fd);
 358        if (!ASSERT_OK_PTR(egress_link, "attach_cgroup(egress)"))
 359                goto done;
 360
 361        ingress_link = bpf_program__attach_cgroup(skel->progs.ingress_read_sock_fields,
 362                                                  child_cg_fd);
 363        if (!ASSERT_OK_PTR(ingress_link, "attach_cgroup(ingress)"))
 364                goto done;
 365
 366        linum_map_fd = bpf_map__fd(skel->maps.linum_map);
 367        sk_pkt_out_cnt_fd = bpf_map__fd(skel->maps.sk_pkt_out_cnt);
 368        sk_pkt_out_cnt10_fd = bpf_map__fd(skel->maps.sk_pkt_out_cnt10);
 369
 370        test();
 371
 372done:
 373        bpf_link__destroy(egress_link);
 374        bpf_link__destroy(ingress_link);
 375        test_sock_fields__destroy(skel);
 376        if (child_cg_fd >= 0)
 377                close(child_cg_fd);
 378        if (parent_cg_fd >= 0)
 379                close(parent_cg_fd);
 380}
 381