linux/tools/testing/selftests/bpf/network_helpers.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2#include <errno.h>
   3#include <stdbool.h>
   4#include <stdio.h>
   5#include <string.h>
   6#include <unistd.h>
   7
   8#include <arpa/inet.h>
   9
  10#include <linux/err.h>
  11#include <linux/in.h>
  12#include <linux/in6.h>
  13
  14#include "bpf_util.h"
  15#include "network_helpers.h"
  16
  17#define clean_errno() (errno == 0 ? "None" : strerror(errno))
  18#define log_err(MSG, ...) ({                                            \
  19                        int __save = errno;                             \
  20                        fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \
  21                                __FILE__, __LINE__, clean_errno(),      \
  22                                ##__VA_ARGS__);                         \
  23                        errno = __save;                                 \
  24})
  25
  26struct ipv4_packet pkt_v4 = {
  27        .eth.h_proto = __bpf_constant_htons(ETH_P_IP),
  28        .iph.ihl = 5,
  29        .iph.protocol = IPPROTO_TCP,
  30        .iph.tot_len = __bpf_constant_htons(MAGIC_BYTES),
  31        .tcp.urg_ptr = 123,
  32        .tcp.doff = 5,
  33};
  34
  35struct ipv6_packet pkt_v6 = {
  36        .eth.h_proto = __bpf_constant_htons(ETH_P_IPV6),
  37        .iph.nexthdr = IPPROTO_TCP,
  38        .iph.payload_len = __bpf_constant_htons(MAGIC_BYTES),
  39        .tcp.urg_ptr = 123,
  40        .tcp.doff = 5,
  41};
  42
  43static int settimeo(int fd, int timeout_ms)
  44{
  45        struct timeval timeout = { .tv_sec = 3 };
  46
  47        if (timeout_ms > 0) {
  48                timeout.tv_sec = timeout_ms / 1000;
  49                timeout.tv_usec = (timeout_ms % 1000) * 1000;
  50        }
  51
  52        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
  53                       sizeof(timeout))) {
  54                log_err("Failed to set SO_RCVTIMEO");
  55                return -1;
  56        }
  57
  58        if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout,
  59                       sizeof(timeout))) {
  60                log_err("Failed to set SO_SNDTIMEO");
  61                return -1;
  62        }
  63
  64        return 0;
  65}
  66
  67#define save_errno_close(fd) ({ int __save = errno; close(fd); errno = __save; })
  68
  69int start_server(int family, int type, const char *addr_str, __u16 port,
  70                 int timeout_ms)
  71{
  72        struct sockaddr_storage addr = {};
  73        socklen_t len;
  74        int fd;
  75
  76        if (make_sockaddr(family, addr_str, port, &addr, &len))
  77                return -1;
  78
  79        fd = socket(family, type, 0);
  80        if (fd < 0) {
  81                log_err("Failed to create server socket");
  82                return -1;
  83        }
  84
  85        if (settimeo(fd, timeout_ms))
  86                goto error_close;
  87
  88        if (bind(fd, (const struct sockaddr *)&addr, len) < 0) {
  89                log_err("Failed to bind socket");
  90                goto error_close;
  91        }
  92
  93        if (type == SOCK_STREAM) {
  94                if (listen(fd, 1) < 0) {
  95                        log_err("Failed to listed on socket");
  96                        goto error_close;
  97                }
  98        }
  99
 100        return fd;
 101
 102error_close:
 103        save_errno_close(fd);
 104        return -1;
 105}
 106
 107int fastopen_connect(int server_fd, const char *data, unsigned int data_len,
 108                     int timeout_ms)
 109{
 110        struct sockaddr_storage addr;
 111        socklen_t addrlen = sizeof(addr);
 112        struct sockaddr_in *addr_in;
 113        int fd, ret;
 114
 115        if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
 116                log_err("Failed to get server addr");
 117                return -1;
 118        }
 119
 120        addr_in = (struct sockaddr_in *)&addr;
 121        fd = socket(addr_in->sin_family, SOCK_STREAM, 0);
 122        if (fd < 0) {
 123                log_err("Failed to create client socket");
 124                return -1;
 125        }
 126
 127        if (settimeo(fd, timeout_ms))
 128                goto error_close;
 129
 130        ret = sendto(fd, data, data_len, MSG_FASTOPEN, (struct sockaddr *)&addr,
 131                     addrlen);
 132        if (ret != data_len) {
 133                log_err("sendto(data, %u) != %d\n", data_len, ret);
 134                goto error_close;
 135        }
 136
 137        return fd;
 138
 139error_close:
 140        save_errno_close(fd);
 141        return -1;
 142}
 143
 144static int connect_fd_to_addr(int fd,
 145                              const struct sockaddr_storage *addr,
 146                              socklen_t addrlen)
 147{
 148        if (connect(fd, (const struct sockaddr *)addr, addrlen)) {
 149                log_err("Failed to connect to server");
 150                return -1;
 151        }
 152
 153        return 0;
 154}
 155
 156int connect_to_fd(int server_fd, int timeout_ms)
 157{
 158        struct sockaddr_storage addr;
 159        struct sockaddr_in *addr_in;
 160        socklen_t addrlen, optlen;
 161        int fd, type;
 162
 163        optlen = sizeof(type);
 164        if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) {
 165                log_err("getsockopt(SOL_TYPE)");
 166                return -1;
 167        }
 168
 169        addrlen = sizeof(addr);
 170        if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
 171                log_err("Failed to get server addr");
 172                return -1;
 173        }
 174
 175        addr_in = (struct sockaddr_in *)&addr;
 176        fd = socket(addr_in->sin_family, type, 0);
 177        if (fd < 0) {
 178                log_err("Failed to create client socket");
 179                return -1;
 180        }
 181
 182        if (settimeo(fd, timeout_ms))
 183                goto error_close;
 184
 185        if (connect_fd_to_addr(fd, &addr, addrlen))
 186                goto error_close;
 187
 188        return fd;
 189
 190error_close:
 191        save_errno_close(fd);
 192        return -1;
 193}
 194
 195int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms)
 196{
 197        struct sockaddr_storage addr;
 198        socklen_t len = sizeof(addr);
 199
 200        if (settimeo(client_fd, timeout_ms))
 201                return -1;
 202
 203        if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
 204                log_err("Failed to get server addr");
 205                return -1;
 206        }
 207
 208        if (connect_fd_to_addr(client_fd, &addr, len))
 209                return -1;
 210
 211        return 0;
 212}
 213
 214int make_sockaddr(int family, const char *addr_str, __u16 port,
 215                  struct sockaddr_storage *addr, socklen_t *len)
 216{
 217        if (family == AF_INET) {
 218                struct sockaddr_in *sin = (void *)addr;
 219
 220                sin->sin_family = AF_INET;
 221                sin->sin_port = htons(port);
 222                if (addr_str &&
 223                    inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) {
 224                        log_err("inet_pton(AF_INET, %s)", addr_str);
 225                        return -1;
 226                }
 227                if (len)
 228                        *len = sizeof(*sin);
 229                return 0;
 230        } else if (family == AF_INET6) {
 231                struct sockaddr_in6 *sin6 = (void *)addr;
 232
 233                sin6->sin6_family = AF_INET6;
 234                sin6->sin6_port = htons(port);
 235                if (addr_str &&
 236                    inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) {
 237                        log_err("inet_pton(AF_INET6, %s)", addr_str);
 238                        return -1;
 239                }
 240                if (len)
 241                        *len = sizeof(*sin6);
 242                return 0;
 243        }
 244        return -1;
 245}
 246