linux/tools/testing/selftests/bpf/test_socket_cookie.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2// Copyright (c) 2018 Facebook
   3
   4#include <string.h>
   5#include <unistd.h>
   6
   7#include <arpa/inet.h>
   8#include <netinet/in.h>
   9#include <sys/types.h>
  10#include <sys/socket.h>
  11
  12#include <bpf/bpf.h>
  13#include <bpf/libbpf.h>
  14
  15#include "bpf_rlimit.h"
  16#include "cgroup_helpers.h"
  17
  18#define CG_PATH                 "/foo"
  19#define SOCKET_COOKIE_PROG      "./socket_cookie_prog.o"
  20
  21static int start_server(void)
  22{
  23        struct sockaddr_in6 addr;
  24        int fd;
  25
  26        fd = socket(AF_INET6, SOCK_STREAM, 0);
  27        if (fd == -1) {
  28                log_err("Failed to create server socket");
  29                goto out;
  30        }
  31
  32        memset(&addr, 0, sizeof(addr));
  33        addr.sin6_family = AF_INET6;
  34        addr.sin6_addr = in6addr_loopback;
  35        addr.sin6_port = 0;
  36
  37        if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) == -1) {
  38                log_err("Failed to bind server socket");
  39                goto close_out;
  40        }
  41
  42        if (listen(fd, 128) == -1) {
  43                log_err("Failed to listen on server socket");
  44                goto close_out;
  45        }
  46
  47        goto out;
  48
  49close_out:
  50        close(fd);
  51        fd = -1;
  52out:
  53        return fd;
  54}
  55
  56static int connect_to_server(int server_fd)
  57{
  58        struct sockaddr_storage addr;
  59        socklen_t len = sizeof(addr);
  60        int fd;
  61
  62        fd = socket(AF_INET6, SOCK_STREAM, 0);
  63        if (fd == -1) {
  64                log_err("Failed to create client socket");
  65                goto out;
  66        }
  67
  68        if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
  69                log_err("Failed to get server addr");
  70                goto close_out;
  71        }
  72
  73        if (connect(fd, (const struct sockaddr *)&addr, len) == -1) {
  74                log_err("Fail to connect to server");
  75                goto close_out;
  76        }
  77
  78        goto out;
  79
  80close_out:
  81        close(fd);
  82        fd = -1;
  83out:
  84        return fd;
  85}
  86
  87static int validate_map(struct bpf_map *map, int client_fd)
  88{
  89        __u32 cookie_expected_value;
  90        struct sockaddr_in6 addr;
  91        socklen_t len = sizeof(addr);
  92        __u32 cookie_value;
  93        __u64 cookie_key;
  94        int err = 0;
  95        int map_fd;
  96
  97        if (!map) {
  98                log_err("Map not found in BPF object");
  99                goto err;
 100        }
 101
 102        map_fd = bpf_map__fd(map);
 103
 104        err = bpf_map_get_next_key(map_fd, NULL, &cookie_key);
 105        if (err) {
 106                log_err("Can't get cookie key from map");
 107                goto out;
 108        }
 109
 110        err = bpf_map_lookup_elem(map_fd, &cookie_key, &cookie_value);
 111        if (err) {
 112                log_err("Can't get cookie value from map");
 113                goto out;
 114        }
 115
 116        err = getsockname(client_fd, (struct sockaddr *)&addr, &len);
 117        if (err) {
 118                log_err("Can't get client local addr");
 119                goto out;
 120        }
 121
 122        cookie_expected_value = (ntohs(addr.sin6_port) << 8) | 0xFF;
 123        if (cookie_value != cookie_expected_value) {
 124                log_err("Unexpected value in map: %x != %x", cookie_value,
 125                        cookie_expected_value);
 126                goto err;
 127        }
 128
 129        goto out;
 130err:
 131        err = -1;
 132out:
 133        return err;
 134}
 135
 136static int run_test(int cgfd)
 137{
 138        enum bpf_attach_type attach_type;
 139        struct bpf_prog_load_attr attr;
 140        struct bpf_program *prog;
 141        struct bpf_object *pobj;
 142        const char *prog_name;
 143        int server_fd = -1;
 144        int client_fd = -1;
 145        int prog_fd = -1;
 146        int err = 0;
 147
 148        memset(&attr, 0, sizeof(attr));
 149        attr.file = SOCKET_COOKIE_PROG;
 150        attr.prog_type = BPF_PROG_TYPE_UNSPEC;
 151
 152        err = bpf_prog_load_xattr(&attr, &pobj, &prog_fd);
 153        if (err) {
 154                log_err("Failed to load %s", attr.file);
 155                goto out;
 156        }
 157
 158        bpf_object__for_each_program(prog, pobj) {
 159                prog_name = bpf_program__title(prog, /*needs_copy*/ false);
 160
 161                if (strcmp(prog_name, "cgroup/connect6") == 0) {
 162                        attach_type = BPF_CGROUP_INET6_CONNECT;
 163                } else if (strcmp(prog_name, "sockops") == 0) {
 164                        attach_type = BPF_CGROUP_SOCK_OPS;
 165                } else {
 166                        log_err("Unexpected prog: %s", prog_name);
 167                        goto err;
 168                }
 169
 170                err = bpf_prog_attach(bpf_program__fd(prog), cgfd, attach_type,
 171                                      BPF_F_ALLOW_OVERRIDE);
 172                if (err) {
 173                        log_err("Failed to attach prog %s", prog_name);
 174                        goto out;
 175                }
 176        }
 177
 178        server_fd = start_server();
 179        if (server_fd == -1)
 180                goto err;
 181
 182        client_fd = connect_to_server(server_fd);
 183        if (client_fd == -1)
 184                goto err;
 185
 186        if (validate_map(bpf_map__next(NULL, pobj), client_fd))
 187                goto err;
 188
 189        goto out;
 190err:
 191        err = -1;
 192out:
 193        close(client_fd);
 194        close(server_fd);
 195        bpf_object__close(pobj);
 196        printf("%s\n", err ? "FAILED" : "PASSED");
 197        return err;
 198}
 199
 200int main(int argc, char **argv)
 201{
 202        int cgfd = -1;
 203        int err = 0;
 204
 205        if (setup_cgroup_environment())
 206                goto err;
 207
 208        cgfd = create_and_get_cgroup(CG_PATH);
 209        if (!cgfd)
 210                goto err;
 211
 212        if (join_cgroup(CG_PATH))
 213                goto err;
 214
 215        if (run_test(cgfd))
 216                goto err;
 217
 218        goto out;
 219err:
 220        err = -1;
 221out:
 222        close(cgfd);
 223        cleanup_cgroup_environment();
 224        return err;
 225}
 226