linux/tools/testing/selftests/net/ipsec.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * ipsec.c - Check xfrm on veth inside a net-ns.
   4 * Copyright (c) 2018 Dmitry Safonov
   5 */
   6
   7#define _GNU_SOURCE
   8
   9#include <arpa/inet.h>
  10#include <asm/types.h>
  11#include <errno.h>
  12#include <fcntl.h>
  13#include <limits.h>
  14#include <linux/limits.h>
  15#include <linux/netlink.h>
  16#include <linux/random.h>
  17#include <linux/rtnetlink.h>
  18#include <linux/veth.h>
  19#include <linux/xfrm.h>
  20#include <netinet/in.h>
  21#include <net/if.h>
  22#include <sched.h>
  23#include <stdbool.h>
  24#include <stdint.h>
  25#include <stdio.h>
  26#include <stdlib.h>
  27#include <string.h>
  28#include <sys/mman.h>
  29#include <sys/socket.h>
  30#include <sys/stat.h>
  31#include <sys/syscall.h>
  32#include <sys/types.h>
  33#include <sys/wait.h>
  34#include <time.h>
  35#include <unistd.h>
  36
  37#include "../kselftest.h"
  38
  39#define printk(fmt, ...)                                                \
  40        ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)
  41
  42#define pr_err(fmt, ...)        printk(fmt ": %m", ##__VA_ARGS__)
  43
  44#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
  45#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))
  46
  47#define IPV4_STR_SZ     16      /* xxx.xxx.xxx.xxx is longest + \0 */
  48#define MAX_PAYLOAD     2048
  49#define XFRM_ALGO_KEY_BUF_SIZE  512
  50#define MAX_PROCESSES   (1 << 14) /* /16 mask divided by /30 subnets */
  51#define INADDR_A        ((in_addr_t) 0x0a000000) /* 10.0.0.0 */
  52#define INADDR_B        ((in_addr_t) 0xc0a80000) /* 192.168.0.0 */
  53
  54/* /30 mask for one veth connection */
  55#define PREFIX_LEN      30
  56#define child_ip(nr)    (4*nr + 1)
  57#define grchild_ip(nr)  (4*nr + 2)
  58
  59#define VETH_FMT        "ktst-%d"
  60#define VETH_LEN        12
  61
  62static int nsfd_parent  = -1;
  63static int nsfd_childa  = -1;
  64static int nsfd_childb  = -1;
  65static long page_size;
  66
  67/*
  68 * ksft_cnt is static in kselftest, so isn't shared with children.
  69 * We have to send a test result back to parent and count there.
  70 * results_fd is a pipe with test feedback from children.
  71 */
  72static int results_fd[2];
  73
  74const unsigned int ping_delay_nsec      = 50 * 1000 * 1000;
  75const unsigned int ping_timeout         = 300;
  76const unsigned int ping_count           = 100;
  77const unsigned int ping_success         = 80;
  78
  79static void randomize_buffer(void *buf, size_t buflen)
  80{
  81        int *p = (int *)buf;
  82        size_t words = buflen / sizeof(int);
  83        size_t leftover = buflen % sizeof(int);
  84
  85        if (!buflen)
  86                return;
  87
  88        while (words--)
  89                *p++ = rand();
  90
  91        if (leftover) {
  92                int tmp = rand();
  93
  94                memcpy(buf + buflen - leftover, &tmp, leftover);
  95        }
  96
  97        return;
  98}
  99
 100static int unshare_open(void)
 101{
 102        const char *netns_path = "/proc/self/ns/net";
 103        int fd;
 104
 105        if (unshare(CLONE_NEWNET) != 0) {
 106                pr_err("unshare()");
 107                return -1;
 108        }
 109
 110        fd = open(netns_path, O_RDONLY);
 111        if (fd <= 0) {
 112                pr_err("open(%s)", netns_path);
 113                return -1;
 114        }
 115
 116        return fd;
 117}
 118
 119static int switch_ns(int fd)
 120{
 121        if (setns(fd, CLONE_NEWNET)) {
 122                pr_err("setns()");
 123                return -1;
 124        }
 125        return 0;
 126}
 127
 128/*
 129 * Running the test inside a new parent net namespace to bother less
 130 * about cleanup on error-path.
 131 */
 132static int init_namespaces(void)
 133{
 134        nsfd_parent = unshare_open();
 135        if (nsfd_parent <= 0)
 136                return -1;
 137
 138        nsfd_childa = unshare_open();
 139        if (nsfd_childa <= 0)
 140                return -1;
 141
 142        if (switch_ns(nsfd_parent))
 143                return -1;
 144
 145        nsfd_childb = unshare_open();
 146        if (nsfd_childb <= 0)
 147                return -1;
 148
 149        if (switch_ns(nsfd_parent))
 150                return -1;
 151        return 0;
 152}
 153
 154static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
 155{
 156        if (*sock > 0) {
 157                seq_nr++;
 158                return 0;
 159        }
 160
 161        *sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
 162        if (*sock <= 0) {
 163                pr_err("socket(AF_NETLINK)");
 164                return -1;
 165        }
 166
 167        randomize_buffer(seq_nr, sizeof(*seq_nr));
 168
 169        return 0;
 170}
 171
 172static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
 173{
 174        return (struct rtattr *)((char *)(nh) + RTA_ALIGN((nh)->nlmsg_len));
 175}
 176
 177static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
 178                unsigned short rta_type, const void *payload, size_t size)
 179{
 180        /* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
 181        struct rtattr *attr = rtattr_hdr(nh);
 182        size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);
 183
 184        if (req_sz < nl_size) {
 185                printk("req buf is too small: %zu < %zu", req_sz, nl_size);
 186                return -1;
 187        }
 188        nh->nlmsg_len = nl_size;
 189
 190        attr->rta_len = RTA_LENGTH(size);
 191        attr->rta_type = rta_type;
 192        memcpy(RTA_DATA(attr), payload, size);
 193
 194        return 0;
 195}
 196
 197static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
 198                unsigned short rta_type, const void *payload, size_t size)
 199{
 200        struct rtattr *ret = rtattr_hdr(nh);
 201
 202        if (rtattr_pack(nh, req_sz, rta_type, payload, size))
 203                return 0;
 204
 205        return ret;
 206}
 207
 208static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
 209                unsigned short rta_type)
 210{
 211        return _rtattr_begin(nh, req_sz, rta_type, 0, 0);
 212}
 213
 214static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
 215{
 216        char *nlmsg_end = (char *)nh + nh->nlmsg_len;
 217
 218        attr->rta_len = nlmsg_end - (char *)attr;
 219}
 220
 221static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
 222                const char *peer, int ns)
 223{
 224        struct ifinfomsg pi;
 225        struct rtattr *peer_attr;
 226
 227        memset(&pi, 0, sizeof(pi));
 228        pi.ifi_family   = AF_UNSPEC;
 229        pi.ifi_change   = 0xFFFFFFFF;
 230
 231        peer_attr = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &pi, sizeof(pi));
 232        if (!peer_attr)
 233                return -1;
 234
 235        if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)))
 236                return -1;
 237
 238        if (rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
 239                return -1;
 240
 241        rtattr_end(nh, peer_attr);
 242
 243        return 0;
 244}
 245
 246static int netlink_check_answer(int sock)
 247{
 248        struct nlmsgerror {
 249                struct nlmsghdr hdr;
 250                int error;
 251                struct nlmsghdr orig_msg;
 252        } answer;
 253
 254        if (recv(sock, &answer, sizeof(answer), 0) < 0) {
 255                pr_err("recv()");
 256                return -1;
 257        } else if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
 258                printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
 259                return -1;
 260        } else if (answer.error) {
 261                printk("NLMSG_ERROR: %d: %s",
 262                        answer.error, strerror(-answer.error));
 263                return answer.error;
 264        }
 265
 266        return 0;
 267}
 268
 269static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
 270                const char *peerb, int ns_b)
 271{
 272        uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
 273        struct {
 274                struct nlmsghdr         nh;
 275                struct ifinfomsg        info;
 276                char                    attrbuf[MAX_PAYLOAD];
 277        } req;
 278        const char veth_type[] = "veth";
 279        struct rtattr *link_info, *info_data;
 280
 281        memset(&req, 0, sizeof(req));
 282        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
 283        req.nh.nlmsg_type       = RTM_NEWLINK;
 284        req.nh.nlmsg_flags      = flags;
 285        req.nh.nlmsg_seq        = seq;
 286        req.info.ifi_family     = AF_UNSPEC;
 287        req.info.ifi_change     = 0xFFFFFFFF;
 288
 289        if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
 290                return -1;
 291
 292        if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
 293                return -1;
 294
 295        link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
 296        if (!link_info)
 297                return -1;
 298
 299        if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
 300                return -1;
 301
 302        info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
 303        if (!info_data)
 304                return -1;
 305
 306        if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
 307                return -1;
 308
 309        rtattr_end(&req.nh, info_data);
 310        rtattr_end(&req.nh, link_info);
 311
 312        if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 313                pr_err("send()");
 314                return -1;
 315        }
 316        return netlink_check_answer(sock);
 317}
 318
 319static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
 320                struct in_addr addr, uint8_t prefix)
 321{
 322        uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
 323        struct {
 324                struct nlmsghdr         nh;
 325                struct ifaddrmsg        info;
 326                char                    attrbuf[MAX_PAYLOAD];
 327        } req;
 328
 329        memset(&req, 0, sizeof(req));
 330        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
 331        req.nh.nlmsg_type       = RTM_NEWADDR;
 332        req.nh.nlmsg_flags      = flags;
 333        req.nh.nlmsg_seq        = seq;
 334        req.info.ifa_family     = AF_INET;
 335        req.info.ifa_prefixlen  = prefix;
 336        req.info.ifa_index      = if_nametoindex(intf);
 337
 338#ifdef DEBUG
 339        {
 340                char addr_str[IPV4_STR_SZ] = {};
 341
 342                strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
 343
 344                printk("ip addr set %s", addr_str);
 345        }
 346#endif
 347
 348        if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
 349                return -1;
 350
 351        if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
 352                return -1;
 353
 354        if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 355                pr_err("send()");
 356                return -1;
 357        }
 358        return netlink_check_answer(sock);
 359}
 360
 361static int link_set_up(int sock, uint32_t seq, const char *intf)
 362{
 363        struct {
 364                struct nlmsghdr         nh;
 365                struct ifinfomsg        info;
 366                char                    attrbuf[MAX_PAYLOAD];
 367        } req;
 368
 369        memset(&req, 0, sizeof(req));
 370        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
 371        req.nh.nlmsg_type       = RTM_NEWLINK;
 372        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
 373        req.nh.nlmsg_seq        = seq;
 374        req.info.ifi_family     = AF_UNSPEC;
 375        req.info.ifi_change     = 0xFFFFFFFF;
 376        req.info.ifi_index      = if_nametoindex(intf);
 377        req.info.ifi_flags      = IFF_UP;
 378        req.info.ifi_change     = IFF_UP;
 379
 380        if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 381                pr_err("send()");
 382                return -1;
 383        }
 384        return netlink_check_answer(sock);
 385}
 386
 387static int ip4_route_set(int sock, uint32_t seq, const char *intf,
 388                struct in_addr src, struct in_addr dst)
 389{
 390        struct {
 391                struct nlmsghdr nh;
 392                struct rtmsg    rt;
 393                char            attrbuf[MAX_PAYLOAD];
 394        } req;
 395        unsigned int index = if_nametoindex(intf);
 396
 397        memset(&req, 0, sizeof(req));
 398        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.rt));
 399        req.nh.nlmsg_type       = RTM_NEWROUTE;
 400        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
 401        req.nh.nlmsg_seq        = seq;
 402        req.rt.rtm_family       = AF_INET;
 403        req.rt.rtm_dst_len      = 32;
 404        req.rt.rtm_table        = RT_TABLE_MAIN;
 405        req.rt.rtm_protocol     = RTPROT_BOOT;
 406        req.rt.rtm_scope        = RT_SCOPE_LINK;
 407        req.rt.rtm_type         = RTN_UNICAST;
 408
 409        if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
 410                return -1;
 411
 412        if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
 413                return -1;
 414
 415        if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
 416                return -1;
 417
 418        if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 419                pr_err("send()");
 420                return -1;
 421        }
 422
 423        return netlink_check_answer(sock);
 424}
 425
 426static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
 427                struct in_addr tunsrc, struct in_addr tundst)
 428{
 429        if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
 430                        tunsrc, PREFIX_LEN)) {
 431                printk("Failed to set ipv4 addr");
 432                return -1;
 433        }
 434
 435        if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
 436                printk("Failed to set ipv4 route");
 437                return -1;
 438        }
 439
 440        return 0;
 441}
 442
 443static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
 444{
 445        struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
 446        struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
 447        struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
 448        int route_sock = -1, ret = -1;
 449        uint32_t route_seq;
 450
 451        if (switch_ns(nsfd))
 452                return -1;
 453
 454        if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
 455                printk("Failed to open netlink route socket in child");
 456                return -1;
 457        }
 458
 459        if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
 460                printk("Failed to set ipv4 addr");
 461                goto err;
 462        }
 463
 464        if (link_set_up(route_sock, route_seq++, veth)) {
 465                printk("Failed to bring up %s", veth);
 466                goto err;
 467        }
 468
 469        if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
 470                printk("Failed to add tunnel route on %s", veth);
 471                goto err;
 472        }
 473        ret = 0;
 474
 475err:
 476        close(route_sock);
 477        return ret;
 478}
 479
 480#define ALGO_LEN        64
 481enum desc_type {
 482        CREATE_TUNNEL   = 0,
 483        ALLOCATE_SPI,
 484        MONITOR_ACQUIRE,
 485        EXPIRE_STATE,
 486        EXPIRE_POLICY,
 487        SPDINFO_ATTRS,
 488};
 489const char *desc_name[] = {
 490        "create tunnel",
 491        "alloc spi",
 492        "monitor acquire",
 493        "expire state",
 494        "expire policy",
 495        "spdinfo attributes",
 496        ""
 497};
 498struct xfrm_desc {
 499        enum desc_type  type;
 500        uint8_t         proto;
 501        char            a_algo[ALGO_LEN];
 502        char            e_algo[ALGO_LEN];
 503        char            c_algo[ALGO_LEN];
 504        char            ae_algo[ALGO_LEN];
 505        unsigned int    icv_len;
 506        /* unsigned key_len; */
 507};
 508
 509enum msg_type {
 510        MSG_ACK         = 0,
 511        MSG_EXIT,
 512        MSG_PING,
 513        MSG_XFRM_PREPARE,
 514        MSG_XFRM_ADD,
 515        MSG_XFRM_DEL,
 516        MSG_XFRM_CLEANUP,
 517};
 518
 519struct test_desc {
 520        enum msg_type type;
 521        union {
 522                struct {
 523                        in_addr_t reply_ip;
 524                        unsigned int port;
 525                } ping;
 526                struct xfrm_desc xfrm_desc;
 527        } body;
 528};
 529
 530struct test_result {
 531        struct xfrm_desc desc;
 532        unsigned int res;
 533};
 534
 535static void write_test_result(unsigned int res, struct xfrm_desc *d)
 536{
 537        struct test_result tr = {};
 538        ssize_t ret;
 539
 540        tr.desc = *d;
 541        tr.res = res;
 542
 543        ret = write(results_fd[1], &tr, sizeof(tr));
 544        if (ret != sizeof(tr))
 545                pr_err("Failed to write the result in pipe %zd", ret);
 546}
 547
 548static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
 549{
 550        ssize_t bytes = write(fd, msg, sizeof(*msg));
 551
 552        /* Make sure that write/read is atomic to a pipe */
 553        BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
 554
 555        if (bytes < 0) {
 556                pr_err("write()");
 557                if (exit_of_fail)
 558                        exit(KSFT_FAIL);
 559        }
 560        if (bytes != sizeof(*msg)) {
 561                pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
 562                if (exit_of_fail)
 563                        exit(KSFT_FAIL);
 564        }
 565}
 566
 567static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
 568{
 569        ssize_t bytes = read(fd, msg, sizeof(*msg));
 570
 571        if (bytes < 0) {
 572                pr_err("read()");
 573                if (exit_of_fail)
 574                        exit(KSFT_FAIL);
 575        }
 576        if (bytes != sizeof(*msg)) {
 577                pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
 578                if (exit_of_fail)
 579                        exit(KSFT_FAIL);
 580        }
 581}
 582
 583static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
 584                unsigned int *server_port, int sock[2])
 585{
 586        struct sockaddr_in server;
 587        struct timeval t = { .tv_sec = 0, .tv_usec = u_timeout };
 588        socklen_t s_len = sizeof(server);
 589
 590        sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
 591        if (sock[0] < 0) {
 592                pr_err("socket()");
 593                return -1;
 594        }
 595
 596        server.sin_family       = AF_INET;
 597        server.sin_port         = 0;
 598        memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));
 599
 600        if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
 601                pr_err("bind()");
 602                goto err_close_server;
 603        }
 604
 605        if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
 606                pr_err("getsockname()");
 607                goto err_close_server;
 608        }
 609
 610        *server_port = ntohs(server.sin_port);
 611
 612        if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
 613                pr_err("setsockopt()");
 614                goto err_close_server;
 615        }
 616
 617        sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
 618        if (sock[1] < 0) {
 619                pr_err("socket()");
 620                goto err_close_server;
 621        }
 622
 623        return 0;
 624
 625err_close_server:
 626        close(sock[0]);
 627        return -1;
 628}
 629
 630static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
 631                char *buf, size_t buf_len)
 632{
 633        struct sockaddr_in server;
 634        const struct sockaddr *dest_addr = (struct sockaddr *)&server;
 635        char *sock_buf[buf_len];
 636        ssize_t r_bytes, s_bytes;
 637
 638        server.sin_family       = AF_INET;
 639        server.sin_port         = htons(port);
 640        server.sin_addr.s_addr  = dest_ip;
 641
 642        s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
 643        if (s_bytes < 0) {
 644                pr_err("sendto()");
 645                return -1;
 646        } else if (s_bytes != buf_len) {
 647                printk("send part of the message: %zd/%zu", s_bytes, sizeof(server));
 648                return -1;
 649        }
 650
 651        r_bytes = recv(sock[0], sock_buf, buf_len, 0);
 652        if (r_bytes < 0) {
 653                if (errno != EAGAIN)
 654                        pr_err("recv()");
 655                return -1;
 656        } else if (r_bytes == 0) { /* EOF */
 657                printk("EOF on reply to ping");
 658                return -1;
 659        } else if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
 660                printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
 661                return -1;
 662        }
 663
 664        return 0;
 665}
 666
 667static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
 668                char *buf, size_t buf_len)
 669{
 670        struct sockaddr_in server;
 671        const struct sockaddr *dest_addr = (struct sockaddr *)&server;
 672        char *sock_buf[buf_len];
 673        ssize_t r_bytes, s_bytes;
 674
 675        server.sin_family       = AF_INET;
 676        server.sin_port         = htons(port);
 677        server.sin_addr.s_addr  = dest_ip;
 678
 679        r_bytes = recv(sock[0], sock_buf, buf_len, 0);
 680        if (r_bytes < 0) {
 681                if (errno != EAGAIN)
 682                        pr_err("recv()");
 683                return -1;
 684        }
 685        if (r_bytes == 0) { /* EOF */
 686                printk("EOF on reply to ping");
 687                return -1;
 688        }
 689        if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
 690                printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
 691                return -1;
 692        }
 693
 694        s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
 695        if (s_bytes < 0) {
 696                pr_err("sendto()");
 697                return -1;
 698        } else if (s_bytes != buf_len) {
 699                printk("send part of the message: %zd/%zu", s_bytes, sizeof(server));
 700                return -1;
 701        }
 702
 703        return 0;
 704}
 705
 706typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
 707                char *buf, size_t buf_len);
 708static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
 709                bool init_side, int d_port, in_addr_t to, ping_f func)
 710{
 711        struct test_desc msg;
 712        unsigned int s_port, i, ping_succeeded = 0;
 713        int ping_sock[2];
 714        char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
 715
 716        if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
 717                printk("Failed to init ping");
 718                return -1;
 719        }
 720
 721        memset(&msg, 0, sizeof(msg));
 722        msg.type                = MSG_PING;
 723        msg.body.ping.port      = s_port;
 724        memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
 725
 726        write_msg(cmd_fd, &msg, 0);
 727        if (init_side) {
 728                /* The other end sends ip to ping */
 729                read_msg(cmd_fd, &msg, 0);
 730                if (msg.type != MSG_PING)
 731                        return -1;
 732                to = msg.body.ping.reply_ip;
 733                d_port = msg.body.ping.port;
 734        }
 735
 736        for (i = 0; i < ping_count ; i++) {
 737                struct timespec sleep_time = {
 738                        .tv_sec = 0,
 739                        .tv_nsec = ping_delay_nsec,
 740                };
 741
 742                ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
 743                nanosleep(&sleep_time, 0);
 744        }
 745
 746        close(ping_sock[0]);
 747        close(ping_sock[1]);
 748
 749        strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
 750        strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
 751
 752        if (ping_succeeded < ping_success) {
 753                printk("ping (%s) %s->%s failed %u/%u times",
 754                        init_side ? "send" : "reply", from_str, to_str,
 755                        ping_count - ping_succeeded, ping_count);
 756                return -1;
 757        }
 758
 759#ifdef DEBUG
 760        printk("ping (%s) %s->%s succeeded %u/%u times",
 761                init_side ? "send" : "reply", from_str, to_str,
 762                ping_succeeded, ping_count);
 763#endif
 764
 765        return 0;
 766}
 767
 768static int xfrm_fill_key(char *name, char *buf,
 769                size_t buf_len, unsigned int *key_len)
 770{
 771        /* TODO: use set/map instead */
 772        if (strncmp(name, "digest_null", ALGO_LEN) == 0)
 773                *key_len = 0;
 774        else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
 775                *key_len = 0;
 776        else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
 777                *key_len = 64;
 778        else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
 779                *key_len = 128;
 780        else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
 781                *key_len = 128;
 782        else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
 783                *key_len = 128;
 784        else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
 785                *key_len = 128;
 786        else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
 787                *key_len = 128;
 788        else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
 789                *key_len = 160;
 790        else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
 791                *key_len = 160;
 792        else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
 793                *key_len = 192;
 794        else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
 795                *key_len = 256;
 796        else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
 797                *key_len = 256;
 798        else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
 799                *key_len = 256;
 800        else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
 801                *key_len = 256;
 802        else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
 803                *key_len = 288;
 804        else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
 805                *key_len = 384;
 806        else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
 807                *key_len = 448;
 808        else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
 809                *key_len = 512;
 810        else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
 811                *key_len = 160;
 812        else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
 813                *key_len = 160;
 814        else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
 815                *key_len = 152;
 816        else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
 817                *key_len = 224;
 818        else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
 819                *key_len = 224;
 820        else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
 821                *key_len = 216;
 822        else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
 823                *key_len = 288;
 824        else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
 825                *key_len = 288;
 826        else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
 827                *key_len = 280;
 828        else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
 829                *key_len = 0;
 830
 831        if (*key_len > buf_len) {
 832                printk("Can't pack a key - too big for buffer");
 833                return -1;
 834        }
 835
 836        randomize_buffer(buf, *key_len);
 837
 838        return 0;
 839}
 840
 841static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
 842                struct xfrm_desc *desc)
 843{
 844        struct {
 845                union {
 846                        struct xfrm_algo        alg;
 847                        struct xfrm_algo_aead   aead;
 848                        struct xfrm_algo_auth   auth;
 849                } u;
 850                char buf[XFRM_ALGO_KEY_BUF_SIZE];
 851        } alg = {};
 852        size_t alen, elen, clen, aelen;
 853        unsigned short type;
 854
 855        alen = strlen(desc->a_algo);
 856        elen = strlen(desc->e_algo);
 857        clen = strlen(desc->c_algo);
 858        aelen = strlen(desc->ae_algo);
 859
 860        /* Verify desc */
 861        switch (desc->proto) {
 862        case IPPROTO_AH:
 863                if (!alen || elen || clen || aelen) {
 864                        printk("BUG: buggy ah desc");
 865                        return -1;
 866                }
 867                strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
 868                if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
 869                                sizeof(alg.buf), &alg.u.alg.alg_key_len))
 870                        return -1;
 871                type = XFRMA_ALG_AUTH;
 872                break;
 873        case IPPROTO_COMP:
 874                if (!clen || elen || alen || aelen) {
 875                        printk("BUG: buggy comp desc");
 876                        return -1;
 877                }
 878                strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
 879                if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
 880                                sizeof(alg.buf), &alg.u.alg.alg_key_len))
 881                        return -1;
 882                type = XFRMA_ALG_COMP;
 883                break;
 884        case IPPROTO_ESP:
 885                if (!((alen && elen) ^ aelen) || clen) {
 886                        printk("BUG: buggy esp desc");
 887                        return -1;
 888                }
 889                if (aelen) {
 890                        alg.u.aead.alg_icv_len = desc->icv_len;
 891                        strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
 892                        if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
 893                                                sizeof(alg.buf), &alg.u.aead.alg_key_len))
 894                                return -1;
 895                        type = XFRMA_ALG_AEAD;
 896                } else {
 897
 898                        strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
 899                        type = XFRMA_ALG_CRYPT;
 900                        if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
 901                                                sizeof(alg.buf), &alg.u.alg.alg_key_len))
 902                                return -1;
 903                        if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
 904                                return -1;
 905
 906                        strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
 907                        type = XFRMA_ALG_AUTH;
 908                        if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
 909                                                sizeof(alg.buf), &alg.u.alg.alg_key_len))
 910                                return -1;
 911                }
 912                break;
 913        default:
 914                printk("BUG: unknown proto in desc");
 915                return -1;
 916        }
 917
 918        if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
 919                return -1;
 920
 921        return 0;
 922}
 923
 924static inline uint32_t gen_spi(struct in_addr src)
 925{
 926        return htonl(inet_lnaof(src));
 927}
 928
 929static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
 930                struct in_addr src, struct in_addr dst,
 931                struct xfrm_desc *desc)
 932{
 933        struct {
 934                struct nlmsghdr         nh;
 935                struct xfrm_usersa_info info;
 936                char                    attrbuf[MAX_PAYLOAD];
 937        } req;
 938
 939        memset(&req, 0, sizeof(req));
 940        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
 941        req.nh.nlmsg_type       = XFRM_MSG_NEWSA;
 942        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
 943        req.nh.nlmsg_seq        = seq;
 944
 945        /* Fill selector. */
 946        memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
 947        memcpy(&req.info.sel.saddr, &src, sizeof(src));
 948        req.info.sel.family             = AF_INET;
 949        req.info.sel.prefixlen_d        = PREFIX_LEN;
 950        req.info.sel.prefixlen_s        = PREFIX_LEN;
 951
 952        /* Fill id */
 953        memcpy(&req.info.id.daddr, &dst, sizeof(dst));
 954        /* Note: zero-spi cannot be deleted */
 955        req.info.id.spi = spi;
 956        req.info.id.proto       = desc->proto;
 957
 958        memcpy(&req.info.saddr, &src, sizeof(src));
 959
 960        /* Fill lifteme_cfg */
 961        req.info.lft.soft_byte_limit    = XFRM_INF;
 962        req.info.lft.hard_byte_limit    = XFRM_INF;
 963        req.info.lft.soft_packet_limit  = XFRM_INF;
 964        req.info.lft.hard_packet_limit  = XFRM_INF;
 965
 966        req.info.family         = AF_INET;
 967        req.info.mode           = XFRM_MODE_TUNNEL;
 968
 969        if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
 970                return -1;
 971
 972        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
 973                pr_err("send()");
 974                return -1;
 975        }
 976
 977        return netlink_check_answer(xfrm_sock);
 978}
 979
 980static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
 981                struct in_addr src, struct in_addr dst,
 982                struct xfrm_desc *desc)
 983{
 984        if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
 985                return false;
 986
 987        if (memcmp(&info->sel.saddr, &src, sizeof(src)))
 988                return false;
 989
 990        if (info->sel.family != AF_INET                                 ||
 991                        info->sel.prefixlen_d != PREFIX_LEN             ||
 992                        info->sel.prefixlen_s != PREFIX_LEN)
 993                return false;
 994
 995        if (info->id.spi != spi || info->id.proto != desc->proto)
 996                return false;
 997
 998        if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
 999                return false;
1000
1001        if (memcmp(&info->saddr, &src, sizeof(src)))
1002                return false;
1003
1004        if (info->lft.soft_byte_limit != XFRM_INF                       ||
1005                        info->lft.hard_byte_limit != XFRM_INF           ||
1006                        info->lft.soft_packet_limit != XFRM_INF         ||
1007                        info->lft.hard_packet_limit != XFRM_INF)
1008                return false;
1009
1010        if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1011                return false;
1012
1013        /* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1014
1015        return true;
1016}
1017
1018static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1019                struct in_addr src, struct in_addr dst,
1020                struct xfrm_desc *desc)
1021{
1022        struct {
1023                struct nlmsghdr         nh;
1024                char                    attrbuf[MAX_PAYLOAD];
1025        } req;
1026        struct {
1027                struct nlmsghdr         nh;
1028                union {
1029                        struct xfrm_usersa_info info;
1030                        int error;
1031                };
1032                char                    attrbuf[MAX_PAYLOAD];
1033        } answer;
1034        struct xfrm_address_filter filter = {};
1035        bool found = false;
1036
1037
1038        memset(&req, 0, sizeof(req));
1039        req.nh.nlmsg_len        = NLMSG_LENGTH(0);
1040        req.nh.nlmsg_type       = XFRM_MSG_GETSA;
1041        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_DUMP;
1042        req.nh.nlmsg_seq        = seq;
1043
1044        /*
1045         * Add dump filter by source address as there may be other tunnels
1046         * in this netns (if tests run in parallel).
1047         */
1048        filter.family = AF_INET;
1049        filter.splen = 0x1f;    /* 0xffffffff mask see addr_match() */
1050        memcpy(&filter.saddr, &src, sizeof(src));
1051        if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1052                                &filter, sizeof(filter)))
1053                return -1;
1054
1055        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1056                pr_err("send()");
1057                return -1;
1058        }
1059
1060        while (1) {
1061                if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1062                        pr_err("recv()");
1063                        return -1;
1064                }
1065                if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1066                        printk("NLMSG_ERROR: %d: %s",
1067                                answer.error, strerror(-answer.error));
1068                        return -1;
1069                } else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1070                        if (found)
1071                                return 0;
1072                        printk("didn't find allocated xfrm state in dump");
1073                        return -1;
1074                } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1075                        if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1076                                found = true;
1077                }
1078        }
1079}
1080
/*
 * Install the two SAs of the @src <-> @dst tunnel, one per direction and
 * both keyed with gen_spi(@src). Then confirm through GETSA dumps that each
 * of them is present. @tunsrc and @tundst are not used.
 *
 * *@seq is advanced once per helper call. Returns 0 on success, -1 otherwise.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	uint32_t spi = gen_spi(src);
	int check_err;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Both dumps are always performed, even if the first one fails. */
	check_err = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	check_err |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (check_err) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1110
1111static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1112                struct in_addr src, struct in_addr dst, uint8_t dir,
1113                struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1114{
1115        struct {
1116                struct nlmsghdr                 nh;
1117                struct xfrm_userpolicy_info     info;
1118                char                            attrbuf[MAX_PAYLOAD];
1119        } req;
1120        struct xfrm_user_tmpl tmpl;
1121
1122        memset(&req, 0, sizeof(req));
1123        memset(&tmpl, 0, sizeof(tmpl));
1124        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
1125        req.nh.nlmsg_type       = XFRM_MSG_NEWPOLICY;
1126        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1127        req.nh.nlmsg_seq        = seq;
1128
1129        /* Fill selector. */
1130        memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1131        memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1132        req.info.sel.family             = AF_INET;
1133        req.info.sel.prefixlen_d        = PREFIX_LEN;
1134        req.info.sel.prefixlen_s        = PREFIX_LEN;
1135
1136        /* Fill lifteme_cfg */
1137        req.info.lft.soft_byte_limit    = XFRM_INF;
1138        req.info.lft.hard_byte_limit    = XFRM_INF;
1139        req.info.lft.soft_packet_limit  = XFRM_INF;
1140        req.info.lft.hard_packet_limit  = XFRM_INF;
1141
1142        req.info.dir = dir;
1143
1144        /* Fill tmpl */
1145        memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1146        /* Note: zero-spi cannot be deleted */
1147        tmpl.id.spi = spi;
1148        tmpl.id.proto   = proto;
1149        tmpl.family     = AF_INET;
1150        memcpy(&tmpl.saddr, &src, sizeof(src));
1151        tmpl.mode       = XFRM_MODE_TUNNEL;
1152        tmpl.aalgos = (~(uint32_t)0);
1153        tmpl.ealgos = (~(uint32_t)0);
1154        tmpl.calgos = (~(uint32_t)0);
1155
1156        if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1157                return -1;
1158
1159        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1160                pr_err("send()");
1161                return -1;
1162        }
1163
1164        return netlink_check_answer(xfrm_sock);
1165}
1166
/*
 * Install the policy pair for the @src <-> @dst tunnel: an OUT policy for
 * @src -> @dst and an IN policy for @dst -> @src. Both reference SPI
 * gen_spi(@src) and protocol @proto. Stops at the first failure.
 * Returns 0 on success, -1 otherwise.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	uint32_t spi = gen_spi(src);
	int failed;

	failed = xfrm_policy_add(xfrm_sock, (*seq)++, spi, src, dst,
				 XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (!failed)
		failed = xfrm_policy_add(xfrm_sock, (*seq)++, spi, dst, src,
					 XFRM_POLICY_IN, tunsrc, tundst, proto);

	if (failed) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1185
1186static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1187                struct in_addr src, struct in_addr dst, uint8_t dir,
1188                struct in_addr tunsrc, struct in_addr tundst)
1189{
1190        struct {
1191                struct nlmsghdr                 nh;
1192                struct xfrm_userpolicy_id       id;
1193                char                            attrbuf[MAX_PAYLOAD];
1194        } req;
1195
1196        memset(&req, 0, sizeof(req));
1197        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.id));
1198        req.nh.nlmsg_type       = XFRM_MSG_DELPOLICY;
1199        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1200        req.nh.nlmsg_seq        = seq;
1201
1202        /* Fill id */
1203        memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1204        memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1205        req.id.sel.family               = AF_INET;
1206        req.id.sel.prefixlen_d          = PREFIX_LEN;
1207        req.id.sel.prefixlen_s          = PREFIX_LEN;
1208        req.id.dir = dir;
1209
1210        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1211                pr_err("send()");
1212                return -1;
1213        }
1214
1215        return netlink_check_answer(xfrm_sock);
1216}
1217
/*
 * Remove the two policies that xfrm_prepare() installs: the OUT policy for
 * @src -> @dst and the IN policy for @dst -> @src. *@seq is advanced once
 * per deletion attempted. Returns 0 on success, -1 on the first failure.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		/* This is the removal path: report it as such. */
		printk("Failed to remove xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to remove xfrm policy");
		return -1;
	}

	return 0;
}
1236
1237static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1238                struct in_addr src, struct in_addr dst, uint8_t proto)
1239{
1240        struct {
1241                struct nlmsghdr         nh;
1242                struct xfrm_usersa_id   id;
1243                char                    attrbuf[MAX_PAYLOAD];
1244        } req;
1245        xfrm_address_t saddr = {};
1246
1247        memset(&req, 0, sizeof(req));
1248        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.id));
1249        req.nh.nlmsg_type       = XFRM_MSG_DELSA;
1250        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1251        req.nh.nlmsg_seq        = seq;
1252
1253        memcpy(&req.id.daddr, &dst, sizeof(dst));
1254        req.id.family           = AF_INET;
1255        req.id.proto            = proto;
1256        /* Note: zero-spi cannot be deleted */
1257        req.id.spi = spi;
1258
1259        memcpy(&saddr, &src, sizeof(src));
1260        if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1261                return -1;
1262
1263        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1264                pr_err("send()");
1265                return -1;
1266        }
1267
1268        return netlink_check_answer(xfrm_sock);
1269}
1270
/*
 * Remove the SA pair that xfrm_set() installs: @src -> @dst and
 * @dst -> @src, both keyed with gen_spi(@src) and @proto. @tunsrc and
 * @tundst are not used. Stops at the first failure.
 * Returns 0 on success, -1 otherwise.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	uint32_t spi = gen_spi(src);

	if (xfrm_state_del(xfrm_sock, (*seq)++, spi, src, dst, proto) ||
	    xfrm_state_del(xfrm_sock, (*seq)++, spi, dst, src, proto)) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}
1287
1288static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1289                uint32_t spi, uint8_t proto)
1290{
1291        struct {
1292                struct nlmsghdr                 nh;
1293                struct xfrm_userspi_info        spi;
1294        } req;
1295        struct {
1296                struct nlmsghdr                 nh;
1297                union {
1298                        struct xfrm_usersa_info info;
1299                        int error;
1300                };
1301        } answer;
1302
1303        memset(&req, 0, sizeof(req));
1304        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.spi));
1305        req.nh.nlmsg_type       = XFRM_MSG_ALLOCSPI;
1306        req.nh.nlmsg_flags      = NLM_F_REQUEST;
1307        req.nh.nlmsg_seq        = (*seq)++;
1308
1309        req.spi.info.family     = AF_INET;
1310        req.spi.min             = spi;
1311        req.spi.max             = spi;
1312        req.spi.info.id.proto   = proto;
1313
1314        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1315                pr_err("send()");
1316                return KSFT_FAIL;
1317        }
1318
1319        if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1320                pr_err("recv()");
1321                return KSFT_FAIL;
1322        } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1323                uint32_t new_spi = htonl(answer.info.id.spi);
1324
1325                if (new_spi != spi) {
1326                        printk("allocated spi is different from requested: %#x != %#x",
1327                                        new_spi, spi);
1328                        return KSFT_FAIL;
1329                }
1330                return KSFT_PASS;
1331        } else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1332                printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1333                return KSFT_FAIL;
1334        }
1335
1336        printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1337        return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1338}
1339
/*
 * Open a netlink socket of protocol @proto via netlink_sock() and bind it
 * to the multicast @groups. Then check that getsockname() reports a
 * well-formed AF_NETLINK address.
 *
 * On success *@sock holds the bound descriptor and 0 is returned.
 * On failure -1 is returned. If the socket had already been opened, it is
 * closed and *@sock is reset to -1, so the caller never keeps a stale fd.
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		printk("Wrong address length %u", (unsigned int)addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	*sock = -1;	/* don't leave the closed descriptor number behind */
	return -1;
}
1378
1379static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1380{
1381        struct {
1382                struct nlmsghdr nh;
1383                union {
1384                        struct xfrm_user_acquire acq;
1385                        int error;
1386                };
1387                char attrbuf[MAX_PAYLOAD];
1388        } req;
1389        struct xfrm_user_tmpl xfrm_tmpl = {};
1390        int xfrm_listen = -1, ret = KSFT_FAIL;
1391        uint32_t seq_listen;
1392
1393        if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1394                return KSFT_FAIL;
1395
1396        memset(&req, 0, sizeof(req));
1397        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.acq));
1398        req.nh.nlmsg_type       = XFRM_MSG_ACQUIRE;
1399        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1400        req.nh.nlmsg_seq        = (*seq)++;
1401
1402        req.acq.policy.sel.family       = AF_INET;
1403        req.acq.aalgos  = 0xfeed;
1404        req.acq.ealgos  = 0xbaad;
1405        req.acq.calgos  = 0xbabe;
1406
1407        xfrm_tmpl.family = AF_INET;
1408        xfrm_tmpl.id.proto = IPPROTO_ESP;
1409        if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1410                goto out_close;
1411
1412        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1413                pr_err("send()");
1414                goto out_close;
1415        }
1416
1417        if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1418                pr_err("recv()");
1419                goto out_close;
1420        } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1421                printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1422                goto out_close;
1423        }
1424
1425        if (req.error) {
1426                printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1427                ret = req.error;
1428                goto out_close;
1429        }
1430
1431        if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1432                pr_err("recv()");
1433                goto out_close;
1434        }
1435
1436        if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1437                        || req.acq.calgos != 0xbabe) {
1438                printk("xfrm_user_acquire has changed  %x %x %x",
1439                                req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1440                goto out_close;
1441        }
1442
1443        ret = KSFT_PASS;
1444out_close:
1445        close(xfrm_listen);
1446        return ret;
1447}
1448
1449static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1450                unsigned int nr, struct xfrm_desc *desc)
1451{
1452        struct {
1453                struct nlmsghdr nh;
1454                union {
1455                        struct xfrm_user_expire expire;
1456                        int error;
1457                };
1458        } req;
1459        struct in_addr src, dst;
1460        int xfrm_listen = -1, ret = KSFT_FAIL;
1461        uint32_t seq_listen;
1462
1463        src = inet_makeaddr(INADDR_B, child_ip(nr));
1464        dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1465
1466        if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1467                printk("Failed to add xfrm state");
1468                return KSFT_FAIL;
1469        }
1470
1471        if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1472                return KSFT_FAIL;
1473
1474        memset(&req, 0, sizeof(req));
1475        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.expire));
1476        req.nh.nlmsg_type       = XFRM_MSG_EXPIRE;
1477        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1478        req.nh.nlmsg_seq        = (*seq)++;
1479
1480        memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1481        req.expire.state.id.spi         = gen_spi(src);
1482        req.expire.state.id.proto       = desc->proto;
1483        req.expire.state.family         = AF_INET;
1484        req.expire.hard                 = 0xff;
1485
1486        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1487                pr_err("send()");
1488                goto out_close;
1489        }
1490
1491        if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1492                pr_err("recv()");
1493                goto out_close;
1494        } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1495                printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1496                goto out_close;
1497        }
1498
1499        if (req.error) {
1500                printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1501                ret = req.error;
1502                goto out_close;
1503        }
1504
1505        if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1506                pr_err("recv()");
1507                goto out_close;
1508        }
1509
1510        if (req.expire.hard != 0x1) {
1511                printk("expire.hard is not set: %x", req.expire.hard);
1512                goto out_close;
1513        }
1514
1515        ret = KSFT_PASS;
1516out_close:
1517        close(xfrm_listen);
1518        return ret;
1519}
1520
1521static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1522                unsigned int nr, struct xfrm_desc *desc)
1523{
1524        struct {
1525                struct nlmsghdr nh;
1526                union {
1527                        struct xfrm_user_polexpire expire;
1528                        int error;
1529                };
1530        } req;
1531        struct in_addr src, dst, tunsrc, tundst;
1532        int xfrm_listen = -1, ret = KSFT_FAIL;
1533        uint32_t seq_listen;
1534
1535        src = inet_makeaddr(INADDR_B, child_ip(nr));
1536        dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1537        tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1538        tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1539
1540        if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1541                                XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1542                printk("Failed to add xfrm policy");
1543                return KSFT_FAIL;
1544        }
1545
1546        if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1547                return KSFT_FAIL;
1548
1549        memset(&req, 0, sizeof(req));
1550        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.expire));
1551        req.nh.nlmsg_type       = XFRM_MSG_POLEXPIRE;
1552        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1553        req.nh.nlmsg_seq        = (*seq)++;
1554
1555        /* Fill selector. */
1556        memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1557        memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1558        req.expire.pol.sel.family       = AF_INET;
1559        req.expire.pol.sel.prefixlen_d  = PREFIX_LEN;
1560        req.expire.pol.sel.prefixlen_s  = PREFIX_LEN;
1561        req.expire.pol.dir              = XFRM_POLICY_OUT;
1562        req.expire.hard                 = 0xff;
1563
1564        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1565                pr_err("send()");
1566                goto out_close;
1567        }
1568
1569        if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1570                pr_err("recv()");
1571                goto out_close;
1572        } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1573                printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1574                goto out_close;
1575        }
1576
1577        if (req.error) {
1578                printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1579                ret = req.error;
1580                goto out_close;
1581        }
1582
1583        if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1584                pr_err("recv()");
1585                goto out_close;
1586        }
1587
1588        if (req.expire.hard != 0x1) {
1589                printk("expire.hard is not set: %x", req.expire.hard);
1590                goto out_close;
1591        }
1592
1593        ret = KSFT_PASS;
1594out_close:
1595        close(xfrm_listen);
1596        return ret;
1597}
1598
1599static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1600                unsigned thresh4_l, unsigned thresh4_r,
1601                unsigned thresh6_l, unsigned thresh6_r,
1602                bool add_bad_attr)
1603
1604{
1605        struct {
1606                struct nlmsghdr         nh;
1607                union {
1608                        uint32_t        unused;
1609                        int             error;
1610                };
1611                char                    attrbuf[MAX_PAYLOAD];
1612        } req;
1613        struct xfrmu_spdhthresh thresh;
1614
1615        memset(&req, 0, sizeof(req));
1616        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.unused));
1617        req.nh.nlmsg_type       = XFRM_MSG_NEWSPDINFO;
1618        req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1619        req.nh.nlmsg_seq        = (*seq)++;
1620
1621        thresh.lbits = thresh4_l;
1622        thresh.rbits = thresh4_r;
1623        if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1624                return -1;
1625
1626        thresh.lbits = thresh6_l;
1627        thresh.rbits = thresh6_r;
1628        if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1629                return -1;
1630
1631        if (add_bad_attr) {
1632                BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1633                if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1634                        pr_err("adding attribute failed: no space");
1635                        return -1;
1636                }
1637        }
1638
1639        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1640                pr_err("send()");
1641                return -1;
1642        }
1643
1644        if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1645                pr_err("recv()");
1646                return -1;
1647        } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1648                printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1649                return -1;
1650        }
1651
1652        if (req.error) {
1653                printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1654                return -1;
1655        }
1656
1657        return 0;
1658}
1659
1660static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1661{
1662        struct {
1663                struct nlmsghdr                 nh;
1664                union {
1665                        uint32_t        unused;
1666                        int             error;
1667                };
1668                char                    attrbuf[MAX_PAYLOAD];
1669        } req;
1670
1671        if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1672                pr_err("Can't set SPD HTHRESH");
1673                return KSFT_FAIL;
1674        }
1675
1676        memset(&req, 0, sizeof(req));
1677
1678        req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.unused));
1679        req.nh.nlmsg_type       = XFRM_MSG_GETSPDINFO;
1680        req.nh.nlmsg_flags      = NLM_F_REQUEST;
1681        req.nh.nlmsg_seq        = (*seq)++;
1682        if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1683                pr_err("send()");
1684                return KSFT_FAIL;
1685        }
1686
1687        if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1688                pr_err("recv()");
1689                return KSFT_FAIL;
1690        } else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1691                size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1692                struct rtattr *attr = (void *)req.attrbuf;
1693                int got_thresh = 0;
1694
1695                for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1696                        if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1697                                struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1698
1699                                got_thresh++;
1700                                if (t->lbits != 32 || t->rbits != 31) {
1701                                        pr_err("thresh differ: %u, %u",
1702                                                        t->lbits, t->rbits);
1703                                        return KSFT_FAIL;
1704                                }
1705                        }
1706                        if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1707                                struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1708
1709                                got_thresh++;
1710                                if (t->lbits != 120 || t->rbits != 16) {
1711                                        pr_err("thresh differ: %u, %u",
1712                                                        t->lbits, t->rbits);
1713                                        return KSFT_FAIL;
1714                                }
1715                        }
1716                }
1717                if (got_thresh != 2) {
1718                        pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1719                        return KSFT_FAIL;
1720                }
1721        } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1722                printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1723                return KSFT_FAIL;
1724        } else {
1725                printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1726                return -1;
1727        }
1728
1729        /* Restore the default */
1730        if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1731                pr_err("Can't restore SPD HTHRESH");
1732                return KSFT_FAIL;
1733        }
1734
1735        /*
1736         * At this moment xfrm uses nlmsg_parse_deprecated(), which
1737         * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1738         * (type > maxtype). nla_parse_depricated_strict() would enforce
1739         * it. Or even stricter nla_parse().
1740         * Right now it's not expected to fail, but to be ignored.
1741         */
1742        if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1743                return KSFT_PASS;
1744
1745        return KSFT_PASS;
1746}
1747
1748static int child_serv(int xfrm_sock, uint32_t *seq,
1749                unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1750{
1751        struct in_addr src, dst, tunsrc, tundst;
1752        struct test_desc msg;
1753        int ret = KSFT_FAIL;
1754
1755        src = inet_makeaddr(INADDR_B, child_ip(nr));
1756        dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1757        tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1758        tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1759
1760        /* UDP pinging without xfrm */
1761        if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1762                printk("ping failed before setting xfrm");
1763                return KSFT_FAIL;
1764        }
1765
1766        memset(&msg, 0, sizeof(msg));
1767        msg.type = MSG_XFRM_PREPARE;
1768        memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1769        write_msg(cmd_fd, &msg, 1);
1770
1771        if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1772                printk("failed to prepare xfrm");
1773                goto cleanup;
1774        }
1775
1776        memset(&msg, 0, sizeof(msg));
1777        msg.type = MSG_XFRM_ADD;
1778        memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1779        write_msg(cmd_fd, &msg, 1);
1780        if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1781                printk("failed to set xfrm");
1782                goto delete;
1783        }
1784
1785        /* UDP pinging with xfrm tunnel */
1786        if (do_ping(cmd_fd, buf, page_size, tunsrc,
1787                                true, 0, 0, udp_ping_send)) {
1788                printk("ping failed for xfrm");
1789                goto delete;
1790        }
1791
1792        ret = KSFT_PASS;
1793delete:
1794        /* xfrm delete */
1795        memset(&msg, 0, sizeof(msg));
1796        msg.type = MSG_XFRM_DEL;
1797        memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1798        write_msg(cmd_fd, &msg, 1);
1799
1800        if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1801                printk("failed ping to remove xfrm");
1802                ret = KSFT_FAIL;
1803        }
1804
1805cleanup:
1806        memset(&msg, 0, sizeof(msg));
1807        msg.type = MSG_XFRM_CLEANUP;
1808        memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1809        write_msg(cmd_fd, &msg, 1);
1810        if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1811                printk("failed ping to cleanup xfrm");
1812                ret = KSFT_FAIL;
1813        }
1814        return ret;
1815}
1816
1817static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1818{
1819        struct xfrm_desc desc;
1820        struct test_desc msg;
1821        int xfrm_sock = -1;
1822        uint32_t seq;
1823
1824        if (switch_ns(nsfd_childa))
1825                exit(KSFT_FAIL);
1826
1827        if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1828                printk("Failed to open xfrm netlink socket");
1829                exit(KSFT_FAIL);
1830        }
1831
1832        /* Check that seq sock is ready, just for sure. */
1833        memset(&msg, 0, sizeof(msg));
1834        msg.type = MSG_ACK;
1835        write_msg(cmd_fd, &msg, 1);
1836        read_msg(cmd_fd, &msg, 1);
1837        if (msg.type != MSG_ACK) {
1838                printk("Ack failed");
1839                exit(KSFT_FAIL);
1840        }
1841
1842        for (;;) {
1843                ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1844                int ret;
1845
1846                if (received == 0) /* EOF */
1847                        break;
1848
1849                if (received != sizeof(desc)) {
1850                        pr_err("read() returned %zd", received);
1851                        exit(KSFT_FAIL);
1852                }
1853
1854                switch (desc.type) {
1855                case CREATE_TUNNEL:
1856                        ret = child_serv(xfrm_sock, &seq, nr,
1857                                         cmd_fd, buf, &desc);
1858                        break;
1859                case ALLOCATE_SPI:
1860                        ret = xfrm_state_allocspi(xfrm_sock, &seq,
1861                                                  -1, desc.proto);
1862                        break;
1863                case MONITOR_ACQUIRE:
1864                        ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1865                        break;
1866                case EXPIRE_STATE:
1867                        ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1868                        break;
1869                case EXPIRE_POLICY:
1870                        ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1871                        break;
1872                case SPDINFO_ATTRS:
1873                        ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1874                        break;
1875                default:
1876                        printk("Unknown desc type %d", desc.type);
1877                        exit(KSFT_FAIL);
1878                }
1879                write_test_result(ret, &desc);
1880        }
1881
1882        close(xfrm_sock);
1883
1884        msg.type = MSG_EXIT;
1885        write_msg(cmd_fd, &msg, 1);
1886        exit(KSFT_PASS);
1887}
1888
1889static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1890                struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1891{
1892        struct in_addr src, dst, tunsrc, tundst;
1893        bool tun_reply;
1894        struct xfrm_desc *desc = &msg->body.xfrm_desc;
1895
1896        src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1897        dst = inet_makeaddr(INADDR_B, child_ip(nr));
1898        tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1899        tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1900
1901        switch (msg->type) {
1902        case MSG_EXIT:
1903                exit(KSFT_PASS);
1904        case MSG_ACK:
1905                write_msg(cmd_fd, msg, 1);
1906                break;
1907        case MSG_PING:
1908                tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1909                /* UDP pinging without xfrm */
1910                if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1911                                false, msg->body.ping.port,
1912                                msg->body.ping.reply_ip, udp_ping_reply)) {
1913                        printk("ping failed before setting xfrm");
1914                }
1915                break;
1916        case MSG_XFRM_PREPARE:
1917                if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1918                                        desc->proto)) {
1919                        xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1920                        printk("failed to prepare xfrm");
1921                }
1922                break;
1923        case MSG_XFRM_ADD:
1924                if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1925                        xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1926                        printk("failed to set xfrm");
1927                }
1928                break;
1929        case MSG_XFRM_DEL:
1930                if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1931                                        desc->proto)) {
1932                        xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1933                        printk("failed to remove xfrm");
1934                }
1935                break;
1936        case MSG_XFRM_CLEANUP:
1937                if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1938                        printk("failed to cleanup xfrm");
1939                }
1940                break;
1941        default:
1942                printk("got unknown msg type %d", msg->type);
1943        }
1944}
1945
1946static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1947{
1948        struct test_desc msg;
1949        int xfrm_sock = -1;
1950        uint32_t seq;
1951
1952        if (switch_ns(nsfd_childb))
1953                exit(KSFT_FAIL);
1954
1955        if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1956                printk("Failed to open xfrm netlink socket");
1957                exit(KSFT_FAIL);
1958        }
1959
1960        do {
1961                read_msg(cmd_fd, &msg, 1);
1962                grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1963        } while (1);
1964
1965        close(xfrm_sock);
1966        exit(KSFT_FAIL);
1967}
1968
1969static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
1970{
1971        int cmd_sock[2];
1972        void *data_map;
1973        pid_t child;
1974
1975        if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
1976                return -1;
1977
1978        if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
1979                return -1;
1980
1981        child = fork();
1982        if (child < 0) {
1983                pr_err("fork()");
1984                return -1;
1985        } else if (child) {
1986                /* in parent - selftest */
1987                return switch_ns(nsfd_parent);
1988        }
1989
1990        if (close(test_desc_fd[1])) {
1991                pr_err("close()");
1992                return -1;
1993        }
1994
1995        /* child */
1996        data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
1997                        MAP_SHARED | MAP_ANONYMOUS, -1, 0);
1998        if (data_map == MAP_FAILED) {
1999                pr_err("mmap()");
2000                return -1;
2001        }
2002
2003        randomize_buffer(data_map, page_size);
2004
2005        if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
2006                pr_err("socketpair()");
2007                return -1;
2008        }
2009
2010        child = fork();
2011        if (child < 0) {
2012                pr_err("fork()");
2013                return -1;
2014        } else if (child) {
2015                if (close(cmd_sock[0])) {
2016                        pr_err("close()");
2017                        return -1;
2018                }
2019                return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
2020        }
2021        if (close(cmd_sock[1])) {
2022                pr_err("close()");
2023                return -1;
2024        }
2025        return grand_child_f(nr, cmd_sock[0], data_map);
2026}
2027
2028static void exit_usage(char **argv)
2029{
2030        printk("Usage: %s [nr_process]", argv[0]);
2031        exit(KSFT_FAIL);
2032}
2033
2034static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2035{
2036        ssize_t ret;
2037
2038        ret = write(test_desc_fd, desc, sizeof(*desc));
2039
2040        if (ret == sizeof(*desc))
2041                return 0;
2042
2043        pr_err("Writing test's desc failed %ld", ret);
2044
2045        return -1;
2046}
2047
2048static int write_desc(int proto, int test_desc_fd,
2049                char *a, char *e, char *c, char *ae)
2050{
2051        struct xfrm_desc desc = {};
2052
2053        desc.type = CREATE_TUNNEL;
2054        desc.proto = proto;
2055
2056        if (a)
2057                strncpy(desc.a_algo, a, ALGO_LEN - 1);
2058        if (e)
2059                strncpy(desc.e_algo, e, ALGO_LEN - 1);
2060        if (c)
2061                strncpy(desc.c_algo, c, ALGO_LEN - 1);
2062        if (ae)
2063                strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2064
2065        return __write_desc(test_desc_fd, &desc);
2066}
2067
2068int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
2069char *ah_list[] = {
2070        "digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
2071        "hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
2072        "xcbc(aes)", "cmac(aes)"
2073};
2074char *comp_list[] = {
2075        "deflate",
2076#if 0
2077        /* No compression backend realization */
2078        "lzs", "lzjh"
2079#endif
2080};
2081char *e_list[] = {
2082        "ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
2083        "cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
2084        "cbc(twofish)", "rfc3686(ctr(aes))"
2085};
2086char *ae_list[] = {
2087#if 0
2088        /* not implemented */
2089        "rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
2090        "rfc7539esp(chacha20,poly1305)"
2091#endif
2092};
2093
2094const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
2095                                + (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
2096                                + ARRAY_SIZE(ae_list);
2097
2098static int write_proto_plan(int fd, int proto)
2099{
2100        unsigned int i;
2101
2102        switch (proto) {
2103        case IPPROTO_AH:
2104                for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2105                        if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2106                                return -1;
2107                }
2108                break;
2109        case IPPROTO_COMP:
2110                for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2111                        if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2112                                return -1;
2113                }
2114                break;
2115        case IPPROTO_ESP:
2116                for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2117                        int j;
2118
2119                        for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2120                                if (write_desc(proto, fd, ah_list[i],
2121                                                        e_list[j], 0, 0))
2122                                        return -1;
2123                        }
2124                }
2125                for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2126                        if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2127                                return -1;
2128                }
2129                break;
2130        default:
2131                printk("BUG: Specified unknown proto %d", proto);
2132                return -1;
2133        }
2134
2135        return 0;
2136}
2137
/*
 * Some structures in the xfrm uapi header differ in size between the
 * 64-bit and 32-bit ABIs:
 *
 *             32-bit UABI               |            64-bit UABI
 *  -------------------------------------|-------------------------------------
 *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
 *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
 *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
 *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
 *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
 *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
 *
 * Check the structures affected by this UABI difference.
 * Also check the translation of xfrm_set_spdinfo: it has its own attributes,
 * which need to be copied correctly but not translated.
 */
2155const unsigned int compat_plan = 5;
2156static int write_compat_struct_tests(int test_desc_fd)
2157{
2158        struct xfrm_desc desc = {};
2159
2160        desc.type = ALLOCATE_SPI;
2161        desc.proto = IPPROTO_AH;
2162        strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2163
2164        if (__write_desc(test_desc_fd, &desc))
2165                return -1;
2166
2167        desc.type = MONITOR_ACQUIRE;
2168        if (__write_desc(test_desc_fd, &desc))
2169                return -1;
2170
2171        desc.type = EXPIRE_STATE;
2172        if (__write_desc(test_desc_fd, &desc))
2173                return -1;
2174
2175        desc.type = EXPIRE_POLICY;
2176        if (__write_desc(test_desc_fd, &desc))
2177                return -1;
2178
2179        desc.type = SPDINFO_ATTRS;
2180        if (__write_desc(test_desc_fd, &desc))
2181                return -1;
2182
2183        return 0;
2184}
2185
2186static int write_test_plan(int test_desc_fd)
2187{
2188        unsigned int i;
2189        pid_t child;
2190
2191        child = fork();
2192        if (child < 0) {
2193                pr_err("fork()");
2194                return -1;
2195        }
2196        if (child) {
2197                if (close(test_desc_fd))
2198                        printk("close(): %m");
2199                return 0;
2200        }
2201
2202        if (write_compat_struct_tests(test_desc_fd))
2203                exit(KSFT_FAIL);
2204
2205        for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2206                if (write_proto_plan(test_desc_fd, proto_list[i]))
2207                        exit(KSFT_FAIL);
2208        }
2209
2210        exit(KSFT_PASS);
2211}
2212
2213static int children_cleanup(void)
2214{
2215        unsigned ret = KSFT_PASS;
2216
2217        while (1) {
2218                int status;
2219                pid_t p = wait(&status);
2220
2221                if ((p < 0) && errno == ECHILD)
2222                        break;
2223
2224                if (p < 0) {
2225                        pr_err("wait()");
2226                        return KSFT_FAIL;
2227                }
2228
2229                if (!WIFEXITED(status)) {
2230                        ret = KSFT_FAIL;
2231                        continue;
2232                }
2233
2234                if (WEXITSTATUS(status) == KSFT_FAIL)
2235                        ret = KSFT_FAIL;
2236        }
2237
2238        return ret;
2239}
2240
2241typedef void (*print_res)(const char *, ...);
2242
2243static int check_results(void)
2244{
2245        struct test_result tr = {};
2246        struct xfrm_desc *d = &tr.desc;
2247        int ret = KSFT_PASS;
2248
2249        while (1) {
2250                ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2251                print_res result;
2252
2253                if (received == 0) /* EOF */
2254                        break;
2255
2256                if (received != sizeof(tr)) {
2257                        pr_err("read() returned %zd", received);
2258                        return KSFT_FAIL;
2259                }
2260
2261                switch (tr.res) {
2262                case KSFT_PASS:
2263                        result = ksft_test_result_pass;
2264                        break;
2265                case KSFT_FAIL:
2266                default:
2267                        result = ksft_test_result_fail;
2268                        ret = KSFT_FAIL;
2269                }
2270
2271                result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2272                       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2273                       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2274        }
2275
2276        return ret;
2277}
2278
2279int main(int argc, char **argv)
2280{
2281        unsigned int nr_process = 1;
2282        int route_sock = -1, ret = KSFT_SKIP;
2283        int test_desc_fd[2];
2284        uint32_t route_seq;
2285        unsigned int i;
2286
2287        if (argc > 2)
2288                exit_usage(argv);
2289
2290        if (argc > 1) {
2291                char *endptr;
2292
2293                errno = 0;
2294                nr_process = strtol(argv[1], &endptr, 10);
2295                if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2296                                || (errno != 0 && nr_process == 0)
2297                                || (endptr == argv[1]) || (*endptr != '\0')) {
2298                        printk("Failed to parse [nr_process]");
2299                        exit_usage(argv);
2300                }
2301
2302                if (nr_process > MAX_PROCESSES || !nr_process) {
2303                        printk("nr_process should be between [1; %u]",
2304                                        MAX_PROCESSES);
2305                        exit_usage(argv);
2306                }
2307        }
2308
2309        srand(time(NULL));
2310        page_size = sysconf(_SC_PAGESIZE);
2311        if (page_size < 1)
2312                ksft_exit_skip("sysconf(): %m\n");
2313
2314        if (pipe2(test_desc_fd, O_DIRECT) < 0)
2315                ksft_exit_skip("pipe(): %m\n");
2316
2317        if (pipe2(results_fd, O_DIRECT) < 0)
2318                ksft_exit_skip("pipe(): %m\n");
2319
2320        if (init_namespaces())
2321                ksft_exit_skip("Failed to create namespaces\n");
2322
2323        if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2324                ksft_exit_skip("Failed to open netlink route socket\n");
2325
2326        for (i = 0; i < nr_process; i++) {
2327                char veth[VETH_LEN];
2328
2329                snprintf(veth, VETH_LEN, VETH_FMT, i);
2330
2331                if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2332                        close(route_sock);
2333                        ksft_exit_fail_msg("Failed to create veth device");
2334                }
2335
2336                if (start_child(i, veth, test_desc_fd)) {
2337                        close(route_sock);
2338                        ksft_exit_fail_msg("Child %u failed to start", i);
2339                }
2340        }
2341
2342        if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2343                ksft_exit_fail_msg("close(): %m");
2344
2345        ksft_set_plan(proto_plan + compat_plan);
2346
2347        if (write_test_plan(test_desc_fd[1]))
2348                ksft_exit_fail_msg("Failed to write test plan to pipe");
2349
2350        ret = check_results();
2351
2352        if (children_cleanup() == KSFT_FAIL)
2353                exit(KSFT_FAIL);
2354
2355        exit(ret);
2356}
2357