/* linux/tools/testing/selftests/bpf/prog_tests/sockopt_inherit.c */
// SPDX-License-Identifier: GPL-2.0
#include <test_progs.h>
#include <sys/socket.h>
#include "cgroup_helpers.h"

   5#define SOL_CUSTOM                      0xdeadbeef
   6#define CUSTOM_INHERIT1                 0
   7#define CUSTOM_INHERIT2                 1
   8#define CUSTOM_LISTENER                 2
   9
  10static int connect_to_server(int server_fd)
  11{
  12        struct sockaddr_storage addr;
  13        socklen_t len = sizeof(addr);
  14        int fd;
  15
  16        fd = socket(AF_INET, SOCK_STREAM, 0);
  17        if (fd < 0) {
  18                log_err("Failed to create client socket");
  19                return -1;
  20        }
  21
  22        if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
  23                log_err("Failed to get server addr");
  24                goto out;
  25        }
  26
  27        if (connect(fd, (const struct sockaddr *)&addr, len) < 0) {
  28                log_err("Fail to connect to server");
  29                goto out;
  30        }
  31
  32        return fd;
  33
  34out:
  35        close(fd);
  36        return -1;
  37}
  38
  39static int verify_sockopt(int fd, int optname, const char *msg, char expected)
  40{
  41        socklen_t optlen = 1;
  42        char buf = 0;
  43        int err;
  44
  45        err = getsockopt(fd, SOL_CUSTOM, optname, &buf, &optlen);
  46        if (err) {
  47                log_err("%s: failed to call getsockopt", msg);
  48                return 1;
  49        }
  50
  51        printf("%s %d: got=0x%x ? expected=0x%x\n", msg, optname, buf, expected);
  52
  53        if (buf != expected) {
  54                log_err("%s: unexpected getsockopt value %d != %d", msg,
  55                        buf, expected);
  56                return 1;
  57        }
  58
  59        return 0;
  60}
  61
  62static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER;
  63static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER;
  64
  65static void *server_thread(void *arg)
  66{
  67        struct sockaddr_storage addr;
  68        socklen_t len = sizeof(addr);
  69        int fd = *(int *)arg;
  70        int client_fd;
  71        int err = 0;
  72
  73        err = listen(fd, 1);
  74
  75        pthread_mutex_lock(&server_started_mtx);
  76        pthread_cond_signal(&server_started);
  77        pthread_mutex_unlock(&server_started_mtx);
  78
  79        if (CHECK_FAIL(err < 0)) {
  80                perror("Failed to listed on socket");
  81                return NULL;
  82        }
  83
  84        err += verify_sockopt(fd, CUSTOM_INHERIT1, "listen", 1);
  85        err += verify_sockopt(fd, CUSTOM_INHERIT2, "listen", 1);
  86        err += verify_sockopt(fd, CUSTOM_LISTENER, "listen", 1);
  87
  88        client_fd = accept(fd, (struct sockaddr *)&addr, &len);
  89        if (CHECK_FAIL(client_fd < 0)) {
  90                perror("Failed to accept client");
  91                return NULL;
  92        }
  93
  94        err += verify_sockopt(client_fd, CUSTOM_INHERIT1, "accept", 1);
  95        err += verify_sockopt(client_fd, CUSTOM_INHERIT2, "accept", 1);
  96        err += verify_sockopt(client_fd, CUSTOM_LISTENER, "accept", 0);
  97
  98        close(client_fd);
  99
 100        return (void *)(long)err;
 101}
 102
 103static int start_server(void)
 104{
 105        struct sockaddr_in addr = {
 106                .sin_family = AF_INET,
 107                .sin_addr.s_addr = htonl(INADDR_LOOPBACK),
 108        };
 109        char buf;
 110        int err;
 111        int fd;
 112        int i;
 113
 114        fd = socket(AF_INET, SOCK_STREAM, 0);
 115        if (fd < 0) {
 116                log_err("Failed to create server socket");
 117                return -1;
 118        }
 119
 120        for (i = CUSTOM_INHERIT1; i <= CUSTOM_LISTENER; i++) {
 121                buf = 0x01;
 122                err = setsockopt(fd, SOL_CUSTOM, i, &buf, 1);
 123                if (err) {
 124                        log_err("Failed to call setsockopt(%d)", i);
 125                        close(fd);
 126                        return -1;
 127                }
 128        }
 129
 130        if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) {
 131                log_err("Failed to bind socket");
 132                close(fd);
 133                return -1;
 134        }
 135
 136        return fd;
 137}
 138
 139static int prog_attach(struct bpf_object *obj, int cgroup_fd, const char *title)
 140{
 141        enum bpf_attach_type attach_type;
 142        enum bpf_prog_type prog_type;
 143        struct bpf_program *prog;
 144        int err;
 145
 146        err = libbpf_prog_type_by_name(title, &prog_type, &attach_type);
 147        if (err) {
 148                log_err("Failed to deduct types for %s BPF program", title);
 149                return -1;
 150        }
 151
 152        prog = bpf_object__find_program_by_title(obj, title);
 153        if (!prog) {
 154                log_err("Failed to find %s BPF program", title);
 155                return -1;
 156        }
 157
 158        err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd,
 159                              attach_type, 0);
 160        if (err) {
 161                log_err("Failed to attach %s BPF program", title);
 162                return -1;
 163        }
 164
 165        return 0;
 166}
 167
 168static void run_test(int cgroup_fd)
 169{
 170        struct bpf_prog_load_attr attr = {
 171                .file = "./sockopt_inherit.o",
 172        };
 173        int server_fd = -1, client_fd;
 174        struct bpf_object *obj;
 175        void *server_err;
 176        pthread_t tid;
 177        int ignored;
 178        int err;
 179
 180        err = bpf_prog_load_xattr(&attr, &obj, &ignored);
 181        if (CHECK_FAIL(err))
 182                return;
 183
 184        err = prog_attach(obj, cgroup_fd, "cgroup/getsockopt");
 185        if (CHECK_FAIL(err))
 186                goto close_bpf_object;
 187
 188        err = prog_attach(obj, cgroup_fd, "cgroup/setsockopt");
 189        if (CHECK_FAIL(err))
 190                goto close_bpf_object;
 191
 192        server_fd = start_server();
 193        if (CHECK_FAIL(server_fd < 0))
 194                goto close_bpf_object;
 195
 196        pthread_mutex_lock(&server_started_mtx);
 197        if (CHECK_FAIL(pthread_create(&tid, NULL, server_thread,
 198                                      (void *)&server_fd))) {
 199                pthread_mutex_unlock(&server_started_mtx);
 200                goto close_server_fd;
 201        }
 202        pthread_cond_wait(&server_started, &server_started_mtx);
 203        pthread_mutex_unlock(&server_started_mtx);
 204
 205        client_fd = connect_to_server(server_fd);
 206        if (CHECK_FAIL(client_fd < 0))
 207                goto close_server_fd;
 208
 209        CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT1, "connect", 0));
 210        CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT2, "connect", 0));
 211        CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_LISTENER, "connect", 0));
 212
 213        pthread_join(tid, &server_err);
 214
 215        err = (int)(long)server_err;
 216        CHECK_FAIL(err);
 217
 218        close(client_fd);
 219
 220close_server_fd:
 221        close(server_fd);
 222close_bpf_object:
 223        bpf_object__close(obj);
 224}
 225
 226void test_sockopt_inherit(void)
 227{
 228        int cgroup_fd;
 229
 230        cgroup_fd = test__join_cgroup("/sockopt_inherit");
 231        if (CHECK_FAIL(cgroup_fd < 0))
 232                return;
 233
 234        run_test(cgroup_fd);
 235        close(cgroup_fd);
 236}
 237