linux/tools/testing/selftests/net/tcp_fastopen_backup_key.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2
   3/*
   4 * Test key rotation for TFO.
   5 * New keys are 'rotated' in two steps:
   6 * 1) Add new key as the 'backup' key 'behind' the primary key
   7 * 2) Make new key the primary by swapping the backup and primary keys
   8 *
   9 * The rotation is done in stages using multiple sockets bound
  10 * to the same port via SO_REUSEPORT. This simulates key rotation
  11 * behind say a load balancer. We verify that across the rotation
  12 * there are no cases in which a cookie is not accepted by verifying
  13 * that TcpExtTCPFastOpenPassiveFail remains 0.
  14 */
  15#define _GNU_SOURCE
  16#include <arpa/inet.h>
  17#include <errno.h>
  18#include <error.h>
  19#include <stdbool.h>
  20#include <stdio.h>
  21#include <stdlib.h>
  22#include <string.h>
  23#include <sys/epoll.h>
  24#include <unistd.h>
  25#include <netinet/tcp.h>
  26#include <fcntl.h>
  27#include <time.h>
  28
  29#ifndef TCP_FASTOPEN_KEY
  30#define TCP_FASTOPEN_KEY 33
  31#endif
  32
  33#define N_LISTEN 10
  34#define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key"
  35#define KEY_LENGTH 16
  36
  37#ifndef ARRAY_SIZE
  38#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
  39#endif
  40
  41static bool do_ipv6;
  42static bool do_sockopt;
  43static bool do_rotate;
  44static int key_len = KEY_LENGTH;
  45static int rcv_fds[N_LISTEN];
  46static int proc_fd;
  47static const char *IP4_ADDR = "127.0.0.1";
  48static const char *IP6_ADDR = "::1";
  49static const int PORT = 8891;
  50
  51static void get_keys(int fd, uint32_t *keys)
  52{
  53        char buf[128];
  54        socklen_t len = KEY_LENGTH * 2;
  55
  56        if (do_sockopt) {
  57                if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len))
  58                        error(1, errno, "Unable to get key");
  59                return;
  60        }
  61        lseek(proc_fd, 0, SEEK_SET);
  62        if (read(proc_fd, buf, sizeof(buf)) <= 0)
  63                error(1, errno, "Unable to read %s", PROC_FASTOPEN_KEY);
  64        if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x", keys, keys + 1, keys + 2,
  65            keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8)
  66                error(1, 0, "Unable to parse %s", PROC_FASTOPEN_KEY);
  67}
  68
  69static void set_keys(int fd, uint32_t *keys)
  70{
  71        char buf[128];
  72
  73        if (do_sockopt) {
  74                if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys,
  75                    key_len))
  76                        error(1, errno, "Unable to set key");
  77                return;
  78        }
  79        if (do_rotate)
  80                snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x",
  81                         keys[0], keys[1], keys[2], keys[3], keys[4], keys[5],
  82                         keys[6], keys[7]);
  83        else
  84                snprintf(buf, 128, "%08x-%08x-%08x-%08x",
  85                         keys[0], keys[1], keys[2], keys[3]);
  86        lseek(proc_fd, 0, SEEK_SET);
  87        if (write(proc_fd, buf, sizeof(buf)) <= 0)
  88                error(1, errno, "Unable to write %s", PROC_FASTOPEN_KEY);
  89}
  90
  91static void build_rcv_fd(int family, int proto, int *rcv_fds)
  92{
  93        struct sockaddr_in  addr4 = {0};
  94        struct sockaddr_in6 addr6 = {0};
  95        struct sockaddr *addr;
  96        int opt = 1, i, sz;
  97        int qlen = 100;
  98        uint32_t keys[8];
  99
 100        switch (family) {
 101        case AF_INET:
 102                addr4.sin_family = family;
 103                addr4.sin_addr.s_addr = htonl(INADDR_ANY);
 104                addr4.sin_port = htons(PORT);
 105                sz = sizeof(addr4);
 106                addr = (struct sockaddr *)&addr4;
 107                break;
 108        case AF_INET6:
 109                addr6.sin6_family = AF_INET6;
 110                addr6.sin6_addr = in6addr_any;
 111                addr6.sin6_port = htons(PORT);
 112                sz = sizeof(addr6);
 113                addr = (struct sockaddr *)&addr6;
 114                break;
 115        default:
 116                error(1, 0, "Unsupported family %d", family);
 117                /* clang does not recognize error() above as terminating
 118                 * the program, so it complains that saddr, sz are
 119                 * not initialized when this code path is taken. Silence it.
 120                 */
 121                return;
 122        }
 123        for (i = 0; i < ARRAY_SIZE(keys); i++)
 124                keys[i] = rand();
 125        for (i = 0; i < N_LISTEN; i++) {
 126                rcv_fds[i] = socket(family, proto, 0);
 127                if (rcv_fds[i] < 0)
 128                        error(1, errno, "failed to create receive socket");
 129                if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt,
 130                               sizeof(opt)))
 131                        error(1, errno, "failed to set SO_REUSEPORT");
 132                if (bind(rcv_fds[i], addr, sz))
 133                        error(1, errno, "failed to bind receive socket");
 134                if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen,
 135                               sizeof(qlen)))
 136                        error(1, errno, "failed to set TCP_FASTOPEN");
 137                set_keys(rcv_fds[i], keys);
 138                if (proto == SOCK_STREAM && listen(rcv_fds[i], 10))
 139                        error(1, errno, "failed to listen on receive port");
 140        }
 141}
 142
 143static int connect_and_send(int family, int proto)
 144{
 145        struct sockaddr_in  saddr4 = {0};
 146        struct sockaddr_in  daddr4 = {0};
 147        struct sockaddr_in6 saddr6 = {0};
 148        struct sockaddr_in6 daddr6 = {0};
 149        struct sockaddr *saddr, *daddr;
 150        int fd, sz, ret;
 151        char data[1];
 152
 153        switch (family) {
 154        case AF_INET:
 155                saddr4.sin_family = AF_INET;
 156                saddr4.sin_addr.s_addr = htonl(INADDR_ANY);
 157                saddr4.sin_port = 0;
 158
 159                daddr4.sin_family = AF_INET;
 160                if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr))
 161                        error(1, errno, "inet_pton failed: %s", IP4_ADDR);
 162                daddr4.sin_port = htons(PORT);
 163
 164                sz = sizeof(saddr4);
 165                saddr = (struct sockaddr *)&saddr4;
 166                daddr = (struct sockaddr *)&daddr4;
 167                break;
 168        case AF_INET6:
 169                saddr6.sin6_family = AF_INET6;
 170                saddr6.sin6_addr = in6addr_any;
 171
 172                daddr6.sin6_family = AF_INET6;
 173                if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr))
 174                        error(1, errno, "inet_pton failed: %s", IP6_ADDR);
 175                daddr6.sin6_port = htons(PORT);
 176
 177                sz = sizeof(saddr6);
 178                saddr = (struct sockaddr *)&saddr6;
 179                daddr = (struct sockaddr *)&daddr6;
 180                break;
 181        default:
 182                error(1, 0, "Unsupported family %d", family);
 183                /* clang does not recognize error() above as terminating
 184                 * the program, so it complains that saddr, daddr, sz are
 185                 * not initialized when this code path is taken. Silence it.
 186                 */
 187                return -1;
 188        }
 189        fd = socket(family, proto, 0);
 190        if (fd < 0)
 191                error(1, errno, "failed to create send socket");
 192        if (bind(fd, saddr, sz))
 193                error(1, errno, "failed to bind send socket");
 194        data[0] = 'a';
 195        ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz);
 196        if (ret != 1)
 197                error(1, errno, "failed to sendto");
 198
 199        return fd;
 200}
 201
 202static bool is_listen_fd(int fd)
 203{
 204        int i;
 205
 206        for (i = 0; i < N_LISTEN; i++) {
 207                if (rcv_fds[i] == fd)
 208                        return true;
 209        }
 210        return false;
 211}
 212
 213static void rotate_key(int fd)
 214{
 215        static int iter;
 216        static uint32_t new_key[4];
 217        uint32_t keys[8];
 218        uint32_t tmp_key[4];
 219        int i;
 220
 221        if (iter < N_LISTEN) {
 222                /* first set new key as backups */
 223                if (iter == 0) {
 224                        for (i = 0; i < ARRAY_SIZE(new_key); i++)
 225                                new_key[i] = rand();
 226                }
 227                get_keys(fd, keys);
 228                memcpy(keys + 4, new_key, KEY_LENGTH);
 229                set_keys(fd, keys);
 230        } else {
 231                /* swap the keys */
 232                get_keys(fd, keys);
 233                memcpy(tmp_key, keys + 4, KEY_LENGTH);
 234                memcpy(keys + 4, keys, KEY_LENGTH);
 235                memcpy(keys, tmp_key, KEY_LENGTH);
 236                set_keys(fd, keys);
 237        }
 238        if (++iter >= (N_LISTEN * 2))
 239                iter = 0;
 240}
 241
 242static void run_one_test(int family)
 243{
 244        struct epoll_event ev;
 245        int i, send_fd;
 246        int n_loops = 10000;
 247        int rotate_key_fd = 0;
 248        int key_rotate_interval = 50;
 249        int fd, epfd;
 250        char buf[1];
 251
 252        build_rcv_fd(family, SOCK_STREAM, rcv_fds);
 253        epfd = epoll_create(1);
 254        if (epfd < 0)
 255                error(1, errno, "failed to create epoll");
 256        ev.events = EPOLLIN;
 257        for (i = 0; i < N_LISTEN; i++) {
 258                ev.data.fd = rcv_fds[i];
 259                if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev))
 260                        error(1, errno, "failed to register sock epoll");
 261        }
 262        while (n_loops--) {
 263                send_fd = connect_and_send(family, SOCK_STREAM);
 264                if (do_rotate && ((n_loops % key_rotate_interval) == 0)) {
 265                        rotate_key(rcv_fds[rotate_key_fd]);
 266                        if (++rotate_key_fd >= N_LISTEN)
 267                                rotate_key_fd = 0;
 268                }
 269                while (1) {
 270                        i = epoll_wait(epfd, &ev, 1, -1);
 271                        if (i < 0)
 272                                error(1, errno, "epoll_wait failed");
 273                        if (is_listen_fd(ev.data.fd)) {
 274                                fd = accept(ev.data.fd, NULL, NULL);
 275                                if (fd < 0)
 276                                        error(1, errno, "failed to accept");
 277                                ev.data.fd = fd;
 278                                if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev))
 279                                        error(1, errno, "failed epoll add");
 280                                continue;
 281                        }
 282                        i = recv(ev.data.fd, buf, sizeof(buf), 0);
 283                        if (i != 1)
 284                                error(1, errno, "failed recv data");
 285                        if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL))
 286                                error(1, errno, "failed epoll del");
 287                        close(ev.data.fd);
 288                        break;
 289                }
 290                close(send_fd);
 291        }
 292        for (i = 0; i < N_LISTEN; i++)
 293                close(rcv_fds[i]);
 294}
 295
 296static void parse_opts(int argc, char **argv)
 297{
 298        int c;
 299
 300        while ((c = getopt(argc, argv, "46sr")) != -1) {
 301                switch (c) {
 302                case '4':
 303                        do_ipv6 = false;
 304                        break;
 305                case '6':
 306                        do_ipv6 = true;
 307                        break;
 308                case 's':
 309                        do_sockopt = true;
 310                        break;
 311                case 'r':
 312                        do_rotate = true;
 313                        key_len = KEY_LENGTH * 2;
 314                        break;
 315                default:
 316                        error(1, 0, "%s: parse error", argv[0]);
 317                }
 318        }
 319}
 320
 321int main(int argc, char **argv)
 322{
 323        parse_opts(argc, argv);
 324        proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR);
 325        if (proc_fd < 0)
 326                error(1, errno, "Unable to open %s", PROC_FASTOPEN_KEY);
 327        srand(time(NULL));
 328        if (do_ipv6)
 329                run_one_test(AF_INET6);
 330        else
 331                run_one_test(AF_INET);
 332        close(proc_fd);
 333        fprintf(stderr, "PASS\n");
 334        return 0;
 335}
 336