linux/tools/testing/selftests/net/psock_snd.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2
   3#define _GNU_SOURCE
   4
   5#include <arpa/inet.h>
   6#include <errno.h>
   7#include <error.h>
   8#include <fcntl.h>
   9#include <limits.h>
  10#include <linux/filter.h>
  11#include <linux/bpf.h>
  12#include <linux/if_packet.h>
  13#include <linux/if_vlan.h>
  14#include <linux/virtio_net.h>
  15#include <net/if.h>
  16#include <net/ethernet.h>
  17#include <netinet/ip.h>
  18#include <netinet/udp.h>
  19#include <poll.h>
  20#include <sched.h>
  21#include <stdbool.h>
  22#include <stdint.h>
  23#include <stdio.h>
  24#include <stdlib.h>
  25#include <string.h>
  26#include <sys/mman.h>
  27#include <sys/socket.h>
  28#include <sys/stat.h>
  29#include <sys/types.h>
  30#include <unistd.h>
  31
  32#include "psock_lib.h"
  33
/* Test configuration; the flags are set from the command line by parse_opts(). */
static bool	cfg_use_bind;		/* -b: bind() the tx socket and send with write() */
static bool	cfg_use_csum_off;	/* -c: request checksum offload in the vnet header */
static bool	cfg_use_csum_off_bad;	/* -C: place the csum field one byte past the packet end */
static bool	cfg_use_dgram;		/* -d: SOCK_DGRAM tx socket; the Ethernet header is not sent */
static bool	cfg_use_gso;		/* -g: request UDP GSO in the vnet header */
static bool	cfg_use_qdisc_bypass;	/* -q: set PACKET_QDISC_BYPASS on the tx socket */
static bool	cfg_use_vlan;		/* -V: insert an 802.1Q tag after the Ethernet header */
static bool	cfg_use_vnet;		/* -v: set PACKET_VNET_HDR and send the vnet header */

static char	*cfg_ifname = "lo";	/* device for bind/sendto/sniffing; no option sets it */
static int	cfg_mtu = 1500;		/* used to derive the GSO segment size (gso_size) */
static int	cfg_payload_len = DATA_LEN;	/* -l: UDP payload length (DATA_LEN from psock_lib.h) */
static int	cfg_truncate_len = INT_MAX;	/* -t: cap on the built frame length, counted
						 * from the start of the vnet header */
static uint16_t	cfg_port = 8000;	/* UDP destination port; the receiver binds to it */

/* test sending up to max mtu + 1 */
#define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)

/*
 * tbuf: frame under construction, starting with the vnet header.
 * rbuf: receive buffer shared by the UDP receiver and the sniffer.
 * Both have static storage and so start zeroed; the header builders only
 * assign the fields they need and rely on the rest (e.g. the IPv4 check
 * field before it is summed) still being zero.
 */
static char tbuf[TEST_SZ], rbuf[TEST_SZ];
  53
  54static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
  55{
  56        unsigned long sum = 0;
  57        int i;
  58
  59        for (i = 0; i < num_u16; i++)
  60                sum += start[i];
  61
  62        return sum;
  63}
  64
  65static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
  66                              unsigned long sum)
  67{
  68        sum += add_csum_hword(start, num_u16);
  69
  70        while (sum >> 16)
  71                sum = (sum & 0xffff) + (sum >> 16);
  72
  73        return ~sum;
  74}
  75
  76static int build_vnet_header(void *header)
  77{
  78        struct virtio_net_hdr *vh = header;
  79
  80        vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
  81
  82        if (cfg_use_csum_off) {
  83                vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
  84                vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
  85                vh->csum_offset = __builtin_offsetof(struct udphdr, check);
  86
  87                /* position check field exactly one byte beyond end of packet */
  88                if (cfg_use_csum_off_bad)
  89                        vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
  90                                          vh->csum_offset - 1;
  91        }
  92
  93        if (cfg_use_gso) {
  94                vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
  95                vh->gso_size = cfg_mtu - sizeof(struct iphdr);
  96        }
  97
  98        return sizeof(*vh);
  99}
 100
 101static int build_eth_header(void *header)
 102{
 103        struct ethhdr *eth = header;
 104
 105        if (cfg_use_vlan) {
 106                uint16_t *tag = header + ETH_HLEN;
 107
 108                eth->h_proto = htons(ETH_P_8021Q);
 109                tag[1] = htons(ETH_P_IP);
 110                return ETH_HLEN + 4;
 111        }
 112
 113        eth->h_proto = htons(ETH_P_IP);
 114        return ETH_HLEN;
 115}
 116
 117static int build_ipv4_header(void *header, int payload_len)
 118{
 119        struct iphdr *iph = header;
 120
 121        iph->ihl = 5;
 122        iph->version = 4;
 123        iph->ttl = 8;
 124        iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
 125        iph->id = htons(1337);
 126        iph->protocol = IPPROTO_UDP;
 127        iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
 128        iph->daddr = htonl((172 << 24) | (17 << 16) | 1);
 129        iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);
 130
 131        return iph->ihl << 2;
 132}
 133
 134static int build_udp_header(void *header, int payload_len)
 135{
 136        const int alen = sizeof(uint32_t);
 137        struct udphdr *udph = header;
 138        int len = sizeof(*udph) + payload_len;
 139
 140        udph->source = htons(9);
 141        udph->dest = htons(cfg_port);
 142        udph->len = htons(len);
 143
 144        if (cfg_use_csum_off)
 145                udph->check = build_ip_csum(header - (2 * alen), alen,
 146                                            htons(IPPROTO_UDP) + udph->len);
 147        else
 148                udph->check = 0;
 149
 150        return sizeof(*udph);
 151}
 152
 153static int build_packet(int payload_len)
 154{
 155        int off = 0;
 156
 157        off += build_vnet_header(tbuf);
 158        off += build_eth_header(tbuf + off);
 159        off += build_ipv4_header(tbuf + off, payload_len);
 160        off += build_udp_header(tbuf + off, payload_len);
 161
 162        if (off + payload_len > sizeof(tbuf))
 163                error(1, 0, "payload length exceeds max");
 164
 165        memset(tbuf + off, DATA_CHAR, payload_len);
 166
 167        return off + payload_len;
 168}
 169
 170static void do_bind(int fd)
 171{
 172        struct sockaddr_ll laddr = {0};
 173
 174        laddr.sll_family = AF_PACKET;
 175        laddr.sll_protocol = htons(ETH_P_IP);
 176        laddr.sll_ifindex = if_nametoindex(cfg_ifname);
 177        if (!laddr.sll_ifindex)
 178                error(1, errno, "if_nametoindex");
 179
 180        if (bind(fd, (void *)&laddr, sizeof(laddr)))
 181                error(1, errno, "bind");
 182}
 183
 184static void do_send(int fd, char *buf, int len)
 185{
 186        int ret;
 187
 188        if (!cfg_use_vnet) {
 189                buf += sizeof(struct virtio_net_hdr);
 190                len -= sizeof(struct virtio_net_hdr);
 191        }
 192        if (cfg_use_dgram) {
 193                buf += ETH_HLEN;
 194                len -= ETH_HLEN;
 195        }
 196
 197        if (cfg_use_bind) {
 198                ret = write(fd, buf, len);
 199        } else {
 200                struct sockaddr_ll laddr = {0};
 201
 202                laddr.sll_protocol = htons(ETH_P_IP);
 203                laddr.sll_ifindex = if_nametoindex(cfg_ifname);
 204                if (!laddr.sll_ifindex)
 205                        error(1, errno, "if_nametoindex");
 206
 207                ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
 208        }
 209
 210        if (ret == -1)
 211                error(1, errno, "write");
 212        if (ret != len)
 213                error(1, 0, "write: %u %u", ret, len);
 214
 215        fprintf(stderr, "tx: %u\n", ret);
 216}
 217
 218static int do_tx(void)
 219{
 220        const int one = 1;
 221        int fd, len;
 222
 223        fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
 224        if (fd == -1)
 225                error(1, errno, "socket t");
 226
 227        if (cfg_use_bind)
 228                do_bind(fd);
 229
 230        if (cfg_use_qdisc_bypass &&
 231            setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
 232                error(1, errno, "setsockopt qdisc bypass");
 233
 234        if (cfg_use_vnet &&
 235            setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
 236                error(1, errno, "setsockopt vnet");
 237
 238        len = build_packet(cfg_payload_len);
 239
 240        if (cfg_truncate_len < len)
 241                len = cfg_truncate_len;
 242
 243        do_send(fd, tbuf, len);
 244
 245        if (close(fd))
 246                error(1, errno, "close t");
 247
 248        return len;
 249}
 250
 251static int setup_rx(void)
 252{
 253        struct timeval tv = { .tv_usec = 100 * 1000 };
 254        struct sockaddr_in raddr = {0};
 255        int fd;
 256
 257        fd = socket(PF_INET, SOCK_DGRAM, 0);
 258        if (fd == -1)
 259                error(1, errno, "socket r");
 260
 261        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
 262                error(1, errno, "setsockopt rcv timeout");
 263
 264        raddr.sin_family = AF_INET;
 265        raddr.sin_port = htons(cfg_port);
 266        raddr.sin_addr.s_addr = htonl(INADDR_ANY);
 267
 268        if (bind(fd, (void *)&raddr, sizeof(raddr)))
 269                error(1, errno, "bind r");
 270
 271        return fd;
 272}
 273
 274static void do_rx(int fd, int expected_len, char *expected)
 275{
 276        int ret;
 277
 278        ret = recv(fd, rbuf, sizeof(rbuf), 0);
 279        if (ret == -1)
 280                error(1, errno, "recv");
 281        if (ret != expected_len)
 282                error(1, 0, "recv: %u != %u", ret, expected_len);
 283
 284        if (memcmp(rbuf, expected, ret))
 285                error(1, 0, "recv: data mismatch");
 286
 287        fprintf(stderr, "rx: %u\n", ret);
 288}
 289
 290static int setup_sniffer(void)
 291{
 292        struct timeval tv = { .tv_usec = 100 * 1000 };
 293        int fd;
 294
 295        fd = socket(PF_PACKET, SOCK_RAW, 0);
 296        if (fd == -1)
 297                error(1, errno, "socket p");
 298
 299        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
 300                error(1, errno, "setsockopt rcv timeout");
 301
 302        pair_udp_setfilter(fd);
 303        do_bind(fd);
 304
 305        return fd;
 306}
 307
 308static void parse_opts(int argc, char **argv)
 309{
 310        int c;
 311
 312        while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
 313                switch (c) {
 314                case 'b':
 315                        cfg_use_bind = true;
 316                        break;
 317                case 'c':
 318                        cfg_use_csum_off = true;
 319                        break;
 320                case 'C':
 321                        cfg_use_csum_off_bad = true;
 322                        break;
 323                case 'd':
 324                        cfg_use_dgram = true;
 325                        break;
 326                case 'g':
 327                        cfg_use_gso = true;
 328                        break;
 329                case 'l':
 330                        cfg_payload_len = strtoul(optarg, NULL, 0);
 331                        break;
 332                case 'q':
 333                        cfg_use_qdisc_bypass = true;
 334                        break;
 335                case 't':
 336                        cfg_truncate_len = strtoul(optarg, NULL, 0);
 337                        break;
 338                case 'v':
 339                        cfg_use_vnet = true;
 340                        break;
 341                case 'V':
 342                        cfg_use_vlan = true;
 343                        break;
 344                default:
 345                        error(1, 0, "%s: parse error", argv[0]);
 346                }
 347        }
 348
 349        if (cfg_use_vlan && cfg_use_dgram)
 350                error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
 351
 352        if (cfg_use_csum_off && !cfg_use_vnet)
 353                error(1, 0, "option csum offload (-c) requires vnet (-v)");
 354
 355        if (cfg_use_csum_off_bad && !cfg_use_csum_off)
 356                error(1, 0, "option csum bad (-C) requires csum offload (-c)");
 357
 358        if (cfg_use_gso && !cfg_use_csum_off)
 359                error(1, 0, "option gso (-g) requires csum offload (-c)");
 360}
 361
 362static void run_test(void)
 363{
 364        int fdr, fds, total_len;
 365
 366        fdr = setup_rx();
 367        fds = setup_sniffer();
 368
 369        total_len = do_tx();
 370
 371        /* BPF filter accepts only this length, vlan changes MAC */
 372        if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
 373                do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
 374                      tbuf + sizeof(struct virtio_net_hdr));
 375
 376        do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
 377
 378        if (close(fds))
 379                error(1, errno, "close s");
 380        if (close(fdr))
 381                error(1, errno, "close r");
 382}
 383
 384int main(int argc, char **argv)
 385{
 386        parse_opts(argc, argv);
 387
 388        if (system("ip link set dev lo mtu 1500"))
 389                error(1, errno, "ip link set mtu");
 390        if (system("ip addr add dev lo 172.17.0.1/24"))
 391                error(1, errno, "ip addr add");
 392
 393        run_test();
 394
 395        fprintf(stderr, "OK\n\n");
 396        return 0;
 397}
 398