linux/tools/testing/selftests/net/mptcp/mptcp_connect.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2
   3#define _GNU_SOURCE
   4
   5#include <errno.h>
   6#include <limits.h>
   7#include <fcntl.h>
   8#include <string.h>
   9#include <stdarg.h>
  10#include <stdbool.h>
  11#include <stdint.h>
  12#include <stdio.h>
  13#include <stdlib.h>
  14#include <strings.h>
  15#include <signal.h>
  16#include <unistd.h>
  17
  18#include <sys/poll.h>
  19#include <sys/sendfile.h>
  20#include <sys/stat.h>
  21#include <sys/socket.h>
  22#include <sys/types.h>
  23#include <sys/mman.h>
  24
  25#include <netdb.h>
  26#include <netinet/in.h>
  27
  28#include <linux/tcp.h>
  29#include <linux/time_types.h>
  30
  31extern int optind;
  32
  33#ifndef IPPROTO_MPTCP
  34#define IPPROTO_MPTCP 262
  35#endif
  36#ifndef TCP_ULP
  37#define TCP_ULP 31
  38#endif
  39
  40static int  poll_timeout = 10 * 1000;
  41static bool listen_mode;
  42static bool quit;
  43
  44enum cfg_mode {
  45        CFG_MODE_POLL,
  46        CFG_MODE_MMAP,
  47        CFG_MODE_SENDFILE,
  48};
  49
  50enum cfg_peek {
  51        CFG_NONE_PEEK,
  52        CFG_WITH_PEEK,
  53        CFG_AFTER_PEEK,
  54};
  55
  56static enum cfg_mode cfg_mode = CFG_MODE_POLL;
  57static enum cfg_peek cfg_peek = CFG_NONE_PEEK;
  58static const char *cfg_host;
  59static const char *cfg_port     = "12000";
  60static int cfg_sock_proto       = IPPROTO_MPTCP;
  61static bool tcpulp_audit;
  62static int pf = AF_INET;
  63static int cfg_sndbuf;
  64static int cfg_rcvbuf;
  65static bool cfg_join;
  66static bool cfg_remove;
  67static unsigned int cfg_do_w;
  68static int cfg_wait;
  69static uint32_t cfg_mark;
  70
  71struct cfg_cmsg_types {
  72        unsigned int cmsg_enabled:1;
  73        unsigned int timestampns:1;
  74};
  75
  76static struct cfg_cmsg_types cfg_cmsg_types;
  77
  78static void die_usage(void)
  79{
  80        fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode]"
  81                "[-l] [-w sec] connect_address\n");
  82        fprintf(stderr, "\t-6 use ipv6\n");
  83        fprintf(stderr, "\t-t num -- set poll timeout to num\n");
  84        fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
  85        fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
  86        fprintf(stderr, "\t-p num -- use port num\n");
  87        fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp(default) or tcp sockets\n");
  88        fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll(default)/mmap+write/sendfile\n");
  89        fprintf(stderr, "\t-M mark -- set socket packet mark\n");
  90        fprintf(stderr, "\t-u -- check mptcp ulp\n");
  91        fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
  92        fprintf(stderr, "\t-c cmsg -- test cmsg type <cmsg>\n");
  93        fprintf(stderr,
  94                "\t-P [saveWithPeek|saveAfterPeek] -- save data with/after MSG_PEEK form tcp socket\n");
  95        exit(1);
  96}
  97
  98static void xerror(const char *fmt, ...)
  99{
 100        va_list ap;
 101
 102        va_start(ap, fmt);
 103        vfprintf(stderr, fmt, ap);
 104        va_end(ap);
 105        exit(1);
 106}
 107
 108static void handle_signal(int nr)
 109{
 110        quit = true;
 111}
 112
 113static const char *getxinfo_strerr(int err)
 114{
 115        if (err == EAI_SYSTEM)
 116                return strerror(errno);
 117
 118        return gai_strerror(err);
 119}
 120
 121static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
 122                         char *host, socklen_t hostlen,
 123                         char *serv, socklen_t servlen)
 124{
 125        int flags = NI_NUMERICHOST | NI_NUMERICSERV;
 126        int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
 127                              flags);
 128
 129        if (err) {
 130                const char *errstr = getxinfo_strerr(err);
 131
 132                fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
 133                exit(1);
 134        }
 135}
 136
 137static void xgetaddrinfo(const char *node, const char *service,
 138                         const struct addrinfo *hints,
 139                         struct addrinfo **res)
 140{
 141        int err = getaddrinfo(node, service, hints, res);
 142
 143        if (err) {
 144                const char *errstr = getxinfo_strerr(err);
 145
 146                fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
 147                        node ? node : "", service ? service : "", errstr);
 148                exit(1);
 149        }
 150}
 151
 152static void set_rcvbuf(int fd, unsigned int size)
 153{
 154        int err;
 155
 156        err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size));
 157        if (err) {
 158                perror("set SO_RCVBUF");
 159                exit(1);
 160        }
 161}
 162
 163static void set_sndbuf(int fd, unsigned int size)
 164{
 165        int err;
 166
 167        err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size));
 168        if (err) {
 169                perror("set SO_SNDBUF");
 170                exit(1);
 171        }
 172}
 173
 174static void set_mark(int fd, uint32_t mark)
 175{
 176        int err;
 177
 178        err = setsockopt(fd, SOL_SOCKET, SO_MARK, &mark, sizeof(mark));
 179        if (err) {
 180                perror("set SO_MARK");
 181                exit(1);
 182        }
 183}
 184
 185static int sock_listen_mptcp(const char * const listenaddr,
 186                             const char * const port)
 187{
 188        int sock;
 189        struct addrinfo hints = {
 190                .ai_protocol = IPPROTO_TCP,
 191                .ai_socktype = SOCK_STREAM,
 192                .ai_flags = AI_PASSIVE | AI_NUMERICHOST
 193        };
 194
 195        hints.ai_family = pf;
 196
 197        struct addrinfo *a, *addr;
 198        int one = 1;
 199
 200        xgetaddrinfo(listenaddr, port, &hints, &addr);
 201        hints.ai_family = pf;
 202
 203        for (a = addr; a; a = a->ai_next) {
 204                sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
 205                if (sock < 0)
 206                        continue;
 207
 208                if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
 209                                     sizeof(one)))
 210                        perror("setsockopt");
 211
 212                if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
 213                        break; /* success */
 214
 215                perror("bind");
 216                close(sock);
 217                sock = -1;
 218        }
 219
 220        freeaddrinfo(addr);
 221
 222        if (sock < 0) {
 223                fprintf(stderr, "Could not create listen socket\n");
 224                return sock;
 225        }
 226
 227        if (listen(sock, 20)) {
 228                perror("listen");
 229                close(sock);
 230                return -1;
 231        }
 232
 233        return sock;
 234}
 235
 236static bool sock_test_tcpulp(const char * const remoteaddr,
 237                             const char * const port)
 238{
 239        struct addrinfo hints = {
 240                .ai_protocol = IPPROTO_TCP,
 241                .ai_socktype = SOCK_STREAM,
 242        };
 243        struct addrinfo *a, *addr;
 244        int sock = -1, ret = 0;
 245        bool test_pass = false;
 246
 247        hints.ai_family = AF_INET;
 248
 249        xgetaddrinfo(remoteaddr, port, &hints, &addr);
 250        for (a = addr; a; a = a->ai_next) {
 251                sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
 252                if (sock < 0) {
 253                        perror("socket");
 254                        continue;
 255                }
 256                ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
 257                                 sizeof("mptcp"));
 258                if (ret == -1 && errno == EOPNOTSUPP)
 259                        test_pass = true;
 260                close(sock);
 261
 262                if (test_pass)
 263                        break;
 264                if (!ret)
 265                        fprintf(stderr,
 266                                "setsockopt(TCP_ULP) returned 0\n");
 267                else
 268                        perror("setsockopt(TCP_ULP)");
 269        }
 270        return test_pass;
 271}
 272
 273static int sock_connect_mptcp(const char * const remoteaddr,
 274                              const char * const port, int proto)
 275{
 276        struct addrinfo hints = {
 277                .ai_protocol = IPPROTO_TCP,
 278                .ai_socktype = SOCK_STREAM,
 279        };
 280        struct addrinfo *a, *addr;
 281        int sock = -1;
 282
 283        hints.ai_family = pf;
 284
 285        xgetaddrinfo(remoteaddr, port, &hints, &addr);
 286        for (a = addr; a; a = a->ai_next) {
 287                sock = socket(a->ai_family, a->ai_socktype, proto);
 288                if (sock < 0) {
 289                        perror("socket");
 290                        continue;
 291                }
 292
 293                if (cfg_mark)
 294                        set_mark(sock, cfg_mark);
 295
 296                if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
 297                        break; /* success */
 298
 299                perror("connect()");
 300                close(sock);
 301                sock = -1;
 302        }
 303
 304        freeaddrinfo(addr);
 305        return sock;
 306}
 307
 308static size_t do_rnd_write(const int fd, char *buf, const size_t len)
 309{
 310        static bool first = true;
 311        unsigned int do_w;
 312        ssize_t bw;
 313
 314        do_w = rand() & 0xffff;
 315        if (do_w == 0 || do_w > len)
 316                do_w = len;
 317
 318        if (cfg_join && first && do_w > 100)
 319                do_w = 100;
 320
 321        if (cfg_remove && do_w > cfg_do_w)
 322                do_w = cfg_do_w;
 323
 324        bw = write(fd, buf, do_w);
 325        if (bw < 0)
 326                perror("write");
 327
 328        /* let the join handshake complete, before going on */
 329        if (cfg_join && first) {
 330                usleep(200000);
 331                first = false;
 332        }
 333
 334        if (cfg_remove)
 335                usleep(200000);
 336
 337        return bw;
 338}
 339
 340static size_t do_write(const int fd, char *buf, const size_t len)
 341{
 342        size_t offset = 0;
 343
 344        while (offset < len) {
 345                size_t written;
 346                ssize_t bw;
 347
 348                bw = write(fd, buf + offset, len - offset);
 349                if (bw < 0) {
 350                        perror("write");
 351                        return 0;
 352                }
 353
 354                written = (size_t)bw;
 355                offset += written;
 356        }
 357
 358        return offset;
 359}
 360
 361static void process_cmsg(struct msghdr *msgh)
 362{
 363        struct __kernel_timespec ts;
 364        bool ts_found = false;
 365        struct cmsghdr *cmsg;
 366
 367        for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
 368                if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SO_TIMESTAMPNS_NEW) {
 369                        memcpy(&ts, CMSG_DATA(cmsg), sizeof(ts));
 370                        ts_found = true;
 371                        continue;
 372                }
 373        }
 374
 375        if (cfg_cmsg_types.timestampns) {
 376                if (!ts_found)
 377                        xerror("TIMESTAMPNS not present\n");
 378        }
 379}
 380
 381static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len)
 382{
 383        char msg_buf[8192];
 384        struct iovec iov = {
 385                .iov_base = buf,
 386                .iov_len = len,
 387        };
 388        struct msghdr msg = {
 389                .msg_iov = &iov,
 390                .msg_iovlen = 1,
 391                .msg_control = msg_buf,
 392                .msg_controllen = sizeof(msg_buf),
 393        };
 394        int flags = 0;
 395        int ret = recvmsg(fd, &msg, flags);
 396
 397        if (ret <= 0)
 398                return ret;
 399
 400        if (msg.msg_controllen && !cfg_cmsg_types.cmsg_enabled)
 401                xerror("got %lu bytes of cmsg data, expected 0\n",
 402                       (unsigned long)msg.msg_controllen);
 403
 404        if (msg.msg_controllen == 0 && cfg_cmsg_types.cmsg_enabled)
 405                xerror("%s\n", "got no cmsg data");
 406
 407        if (msg.msg_controllen)
 408                process_cmsg(&msg);
 409
 410        return ret;
 411}
 412
 413static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
 414{
 415        int ret = 0;
 416        char tmp[16384];
 417        size_t cap = rand();
 418
 419        cap &= 0xffff;
 420
 421        if (cap == 0)
 422                cap = 1;
 423        else if (cap > len)
 424                cap = len;
 425
 426        if (cfg_peek == CFG_WITH_PEEK) {
 427                ret = recv(fd, buf, cap, MSG_PEEK);
 428                ret = (ret < 0) ? ret : read(fd, tmp, ret);
 429        } else if (cfg_peek == CFG_AFTER_PEEK) {
 430                ret = recv(fd, buf, cap, MSG_PEEK);
 431                ret = (ret < 0) ? ret : read(fd, buf, cap);
 432        } else if (cfg_cmsg_types.cmsg_enabled) {
 433                ret = do_recvmsg_cmsg(fd, buf, cap);
 434        } else {
 435                ret = read(fd, buf, cap);
 436        }
 437
 438        return ret;
 439}
 440
 441static void set_nonblock(int fd)
 442{
 443        int flags = fcntl(fd, F_GETFL);
 444
 445        if (flags == -1)
 446                return;
 447
 448        fcntl(fd, F_SETFL, flags | O_NONBLOCK);
 449}
 450
 451static int copyfd_io_poll(int infd, int peerfd, int outfd)
 452{
 453        struct pollfd fds = {
 454                .fd = peerfd,
 455                .events = POLLIN | POLLOUT,
 456        };
 457        unsigned int woff = 0, wlen = 0;
 458        char wbuf[8192];
 459
 460        set_nonblock(peerfd);
 461
 462        for (;;) {
 463                char rbuf[8192];
 464                ssize_t len;
 465
 466                if (fds.events == 0)
 467                        break;
 468
 469                switch (poll(&fds, 1, poll_timeout)) {
 470                case -1:
 471                        if (errno == EINTR)
 472                                continue;
 473                        perror("poll");
 474                        return 1;
 475                case 0:
 476                        fprintf(stderr, "%s: poll timed out (events: "
 477                                "POLLIN %u, POLLOUT %u)\n", __func__,
 478                                fds.events & POLLIN, fds.events & POLLOUT);
 479                        return 2;
 480                }
 481
 482                if (fds.revents & POLLIN) {
 483                        len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
 484                        if (len == 0) {
 485                                /* no more data to receive:
 486                                 * peer has closed its write side
 487                                 */
 488                                fds.events &= ~POLLIN;
 489
 490                                if ((fds.events & POLLOUT) == 0)
 491                                        /* and nothing more to send */
 492                                        break;
 493
 494                        /* Else, still have data to transmit */
 495                        } else if (len < 0) {
 496                                perror("read");
 497                                return 3;
 498                        }
 499
 500                        do_write(outfd, rbuf, len);
 501                }
 502
 503                if (fds.revents & POLLOUT) {
 504                        if (wlen == 0) {
 505                                woff = 0;
 506                                wlen = read(infd, wbuf, sizeof(wbuf));
 507                        }
 508
 509                        if (wlen > 0) {
 510                                ssize_t bw;
 511
 512                                bw = do_rnd_write(peerfd, wbuf + woff, wlen);
 513                                if (bw < 0)
 514                                        return 111;
 515
 516                                woff += bw;
 517                                wlen -= bw;
 518                        } else if (wlen == 0) {
 519                                /* We have no more data to send. */
 520                                fds.events &= ~POLLOUT;
 521
 522                                if ((fds.events & POLLIN) == 0)
 523                                        /* ... and peer also closed already */
 524                                        break;
 525
 526                                /* ... but we still receive.
 527                                 * Close our write side, ev. give some time
 528                                 * for address notification and/or checking
 529                                 * the current status
 530                                 */
 531                                if (cfg_wait)
 532                                        usleep(cfg_wait);
 533                                shutdown(peerfd, SHUT_WR);
 534                        } else {
 535                                if (errno == EINTR)
 536                                        continue;
 537                                perror("read");
 538                                return 4;
 539                        }
 540                }
 541
 542                if (fds.revents & (POLLERR | POLLNVAL)) {
 543                        fprintf(stderr, "Unexpected revents: "
 544                                "POLLERR/POLLNVAL(%x)\n", fds.revents);
 545                        return 5;
 546                }
 547        }
 548
 549        /* leave some time for late join/announce */
 550        if (cfg_join || cfg_remove)
 551                usleep(cfg_wait);
 552
 553        close(peerfd);
 554        return 0;
 555}
 556
 557static int do_recvfile(int infd, int outfd)
 558{
 559        ssize_t r;
 560
 561        do {
 562                char buf[16384];
 563
 564                r = do_rnd_read(infd, buf, sizeof(buf));
 565                if (r > 0) {
 566                        if (write(outfd, buf, r) != r)
 567                                break;
 568                } else if (r < 0) {
 569                        perror("read");
 570                }
 571        } while (r > 0);
 572
 573        return (int)r;
 574}
 575
 576static int do_mmap(int infd, int outfd, unsigned int size)
 577{
 578        char *inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
 579        ssize_t ret = 0, off = 0;
 580        size_t rem;
 581
 582        if (inbuf == MAP_FAILED) {
 583                perror("mmap");
 584                return 1;
 585        }
 586
 587        rem = size;
 588
 589        while (rem > 0) {
 590                ret = write(outfd, inbuf + off, rem);
 591
 592                if (ret < 0) {
 593                        perror("write");
 594                        break;
 595                }
 596
 597                off += ret;
 598                rem -= ret;
 599        }
 600
 601        munmap(inbuf, size);
 602        return rem;
 603}
 604
 605static int get_infd_size(int fd)
 606{
 607        struct stat sb;
 608        ssize_t count;
 609        int err;
 610
 611        err = fstat(fd, &sb);
 612        if (err < 0) {
 613                perror("fstat");
 614                return -1;
 615        }
 616
 617        if ((sb.st_mode & S_IFMT) != S_IFREG) {
 618                fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
 619                return -2;
 620        }
 621
 622        count = sb.st_size;
 623        if (count > INT_MAX) {
 624                fprintf(stderr, "File too large: %zu\n", count);
 625                return -3;
 626        }
 627
 628        return (int)count;
 629}
 630
 631static int do_sendfile(int infd, int outfd, unsigned int count)
 632{
 633        while (count > 0) {
 634                ssize_t r;
 635
 636                r = sendfile(outfd, infd, NULL, count);
 637                if (r < 0) {
 638                        perror("sendfile");
 639                        return 3;
 640                }
 641
 642                count -= r;
 643        }
 644
 645        return 0;
 646}
 647
 648static int copyfd_io_mmap(int infd, int peerfd, int outfd,
 649                          unsigned int size)
 650{
 651        int err;
 652
 653        if (listen_mode) {
 654                err = do_recvfile(peerfd, outfd);
 655                if (err)
 656                        return err;
 657
 658                err = do_mmap(infd, peerfd, size);
 659        } else {
 660                err = do_mmap(infd, peerfd, size);
 661                if (err)
 662                        return err;
 663
 664                shutdown(peerfd, SHUT_WR);
 665
 666                err = do_recvfile(peerfd, outfd);
 667        }
 668
 669        return err;
 670}
 671
 672static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
 673                              unsigned int size)
 674{
 675        int err;
 676
 677        if (listen_mode) {
 678                err = do_recvfile(peerfd, outfd);
 679                if (err)
 680                        return err;
 681
 682                err = do_sendfile(infd, peerfd, size);
 683        } else {
 684                err = do_sendfile(infd, peerfd, size);
 685                if (err)
 686                        return err;
 687                err = do_recvfile(peerfd, outfd);
 688        }
 689
 690        return err;
 691}
 692
 693static int copyfd_io(int infd, int peerfd, int outfd)
 694{
 695        int file_size;
 696
 697        switch (cfg_mode) {
 698        case CFG_MODE_POLL:
 699                return copyfd_io_poll(infd, peerfd, outfd);
 700        case CFG_MODE_MMAP:
 701                file_size = get_infd_size(infd);
 702                if (file_size < 0)
 703                        return file_size;
 704                return copyfd_io_mmap(infd, peerfd, outfd, file_size);
 705        case CFG_MODE_SENDFILE:
 706                file_size = get_infd_size(infd);
 707                if (file_size < 0)
 708                        return file_size;
 709                return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
 710        }
 711
 712        fprintf(stderr, "Invalid mode %d\n", cfg_mode);
 713
 714        die_usage();
 715        return 1;
 716}
 717
 718static void check_sockaddr(int pf, struct sockaddr_storage *ss,
 719                           socklen_t salen)
 720{
 721        struct sockaddr_in6 *sin6;
 722        struct sockaddr_in *sin;
 723        socklen_t wanted_size = 0;
 724
 725        switch (pf) {
 726        case AF_INET:
 727                wanted_size = sizeof(*sin);
 728                sin = (void *)ss;
 729                if (!sin->sin_port)
 730                        fprintf(stderr, "accept: something wrong: ip connection from port 0");
 731                break;
 732        case AF_INET6:
 733                wanted_size = sizeof(*sin6);
 734                sin6 = (void *)ss;
 735                if (!sin6->sin6_port)
 736                        fprintf(stderr, "accept: something wrong: ipv6 connection from port 0");
 737                break;
 738        default:
 739                fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
 740                return;
 741        }
 742
 743        if (salen != wanted_size)
 744                fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
 745                        (int)salen, wanted_size);
 746
 747        if (ss->ss_family != pf)
 748                fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
 749                        (int)ss->ss_family, pf);
 750}
 751
 752static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
 753{
 754        struct sockaddr_storage peerss;
 755        socklen_t peersalen = sizeof(peerss);
 756
 757        if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
 758                perror("getpeername");
 759                return;
 760        }
 761
 762        if (peersalen != salen) {
 763                fprintf(stderr, "%s: %d vs %d\n", __func__, peersalen, salen);
 764                return;
 765        }
 766
 767        if (memcmp(ss, &peerss, peersalen)) {
 768                char a[INET6_ADDRSTRLEN];
 769                char b[INET6_ADDRSTRLEN];
 770                char c[INET6_ADDRSTRLEN];
 771                char d[INET6_ADDRSTRLEN];
 772
 773                xgetnameinfo((struct sockaddr *)ss, salen,
 774                             a, sizeof(a), b, sizeof(b));
 775
 776                xgetnameinfo((struct sockaddr *)&peerss, peersalen,
 777                             c, sizeof(c), d, sizeof(d));
 778
 779                fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
 780                        __func__, a, c, b, d, peersalen, salen);
 781        }
 782}
 783
 784static void check_getpeername_connect(int fd)
 785{
 786        struct sockaddr_storage ss;
 787        socklen_t salen = sizeof(ss);
 788        char a[INET6_ADDRSTRLEN];
 789        char b[INET6_ADDRSTRLEN];
 790
 791        if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
 792                perror("getpeername");
 793                return;
 794        }
 795
 796        xgetnameinfo((struct sockaddr *)&ss, salen,
 797                     a, sizeof(a), b, sizeof(b));
 798
 799        if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
 800                fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
 801                        cfg_host, a, cfg_port, b);
 802}
 803
 804static void maybe_close(int fd)
 805{
 806        unsigned int r = rand();
 807
 808        if (!(cfg_join || cfg_remove) && (r & 1))
 809                close(fd);
 810}
 811
 812int main_loop_s(int listensock)
 813{
 814        struct sockaddr_storage ss;
 815        struct pollfd polls;
 816        socklen_t salen;
 817        int remotesock;
 818
 819        polls.fd = listensock;
 820        polls.events = POLLIN;
 821
 822        switch (poll(&polls, 1, poll_timeout)) {
 823        case -1:
 824                perror("poll");
 825                return 1;
 826        case 0:
 827                fprintf(stderr, "%s: timed out\n", __func__);
 828                close(listensock);
 829                return 2;
 830        }
 831
 832        salen = sizeof(ss);
 833        remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
 834        if (remotesock >= 0) {
 835                maybe_close(listensock);
 836                check_sockaddr(pf, &ss, salen);
 837                check_getpeername(remotesock, &ss, salen);
 838
 839                return copyfd_io(0, remotesock, 1);
 840        }
 841
 842        perror("accept");
 843
 844        return 1;
 845}
 846
 847static void init_rng(void)
 848{
 849        int fd = open("/dev/urandom", O_RDONLY);
 850        unsigned int foo;
 851
 852        if (fd > 0) {
 853                int ret = read(fd, &foo, sizeof(foo));
 854
 855                if (ret < 0)
 856                        srand(fd + foo);
 857                close(fd);
 858        }
 859
 860        srand(foo);
 861}
 862
 863static void xsetsockopt(int fd, int level, int optname, const void *optval, socklen_t optlen)
 864{
 865        int err;
 866
 867        err = setsockopt(fd, level, optname, optval, optlen);
 868        if (err) {
 869                perror("setsockopt");
 870                exit(1);
 871        }
 872}
 873
 874static void apply_cmsg_types(int fd, const struct cfg_cmsg_types *cmsg)
 875{
 876        static const unsigned int on = 1;
 877
 878        if (cmsg->timestampns)
 879                xsetsockopt(fd, SOL_SOCKET, SO_TIMESTAMPNS_NEW, &on, sizeof(on));
 880}
 881
 882static void parse_cmsg_types(const char *type)
 883{
 884        char *next = strchr(type, ',');
 885        unsigned int len = 0;
 886
 887        cfg_cmsg_types.cmsg_enabled = 1;
 888
 889        if (next) {
 890                parse_cmsg_types(next + 1);
 891                len = next - type;
 892        } else {
 893                len = strlen(type);
 894        }
 895
 896        if (strncmp(type, "TIMESTAMPNS", len) == 0) {
 897                cfg_cmsg_types.timestampns = 1;
 898                return;
 899        }
 900
 901        fprintf(stderr, "Unrecognized cmsg option %s\n", type);
 902        exit(1);
 903}
 904
 905int main_loop(void)
 906{
 907        int fd;
 908
 909        /* listener is ready. */
 910        fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
 911        if (fd < 0)
 912                return 2;
 913
 914        check_getpeername_connect(fd);
 915
 916        if (cfg_rcvbuf)
 917                set_rcvbuf(fd, cfg_rcvbuf);
 918        if (cfg_sndbuf)
 919                set_sndbuf(fd, cfg_sndbuf);
 920        if (cfg_cmsg_types.cmsg_enabled)
 921                apply_cmsg_types(fd, &cfg_cmsg_types);
 922
 923        return copyfd_io(0, fd, 1);
 924}
 925
 926int parse_proto(const char *proto)
 927{
 928        if (!strcasecmp(proto, "MPTCP"))
 929                return IPPROTO_MPTCP;
 930        if (!strcasecmp(proto, "TCP"))
 931                return IPPROTO_TCP;
 932
 933        fprintf(stderr, "Unknown protocol: %s\n.", proto);
 934        die_usage();
 935
 936        /* silence compiler warning */
 937        return 0;
 938}
 939
 940int parse_mode(const char *mode)
 941{
 942        if (!strcasecmp(mode, "poll"))
 943                return CFG_MODE_POLL;
 944        if (!strcasecmp(mode, "mmap"))
 945                return CFG_MODE_MMAP;
 946        if (!strcasecmp(mode, "sendfile"))
 947                return CFG_MODE_SENDFILE;
 948
 949        fprintf(stderr, "Unknown test mode: %s\n", mode);
 950        fprintf(stderr, "Supported modes are:\n");
 951        fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
 952        fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
 953        fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
 954
 955        die_usage();
 956
 957        /* silence compiler warning */
 958        return 0;
 959}
 960
 961int parse_peek(const char *mode)
 962{
 963        if (!strcasecmp(mode, "saveWithPeek"))
 964                return CFG_WITH_PEEK;
 965        if (!strcasecmp(mode, "saveAfterPeek"))
 966                return CFG_AFTER_PEEK;
 967
 968        fprintf(stderr, "Unknown: %s\n", mode);
 969        fprintf(stderr, "Supported MSG_PEEK mode are:\n");
 970        fprintf(stderr,
 971                "\t\t\"saveWithPeek\" - recv data with flags 'MSG_PEEK' and save the peek data into file\n");
 972        fprintf(stderr,
 973                "\t\t\"saveAfterPeek\" - read and save data into file after recv with flags 'MSG_PEEK'\n");
 974
 975        die_usage();
 976
 977        /* silence compiler warning */
 978        return 0;
 979}
 980
 981static int parse_int(const char *size)
 982{
 983        unsigned long s;
 984
 985        errno = 0;
 986
 987        s = strtoul(size, NULL, 0);
 988
 989        if (errno) {
 990                fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
 991                        size, strerror(errno));
 992                die_usage();
 993        }
 994
 995        if (s > INT_MAX) {
 996                fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
 997                        size, strerror(ERANGE));
 998                die_usage();
 999        }
1000
1001        return (int)s;
1002}
1003
1004static void parse_opts(int argc, char **argv)
1005{
1006        int c;
1007
1008        while ((c = getopt(argc, argv, "6jr:lp:s:hut:m:S:R:w:M:P:c:")) != -1) {
1009                switch (c) {
1010                case 'j':
1011                        cfg_join = true;
1012                        cfg_mode = CFG_MODE_POLL;
1013                        cfg_wait = 400000;
1014                        break;
1015                case 'r':
1016                        cfg_remove = true;
1017                        cfg_mode = CFG_MODE_POLL;
1018                        cfg_wait = 400000;
1019                        cfg_do_w = atoi(optarg);
1020                        if (cfg_do_w <= 0)
1021                                cfg_do_w = 50;
1022                        break;
1023                case 'l':
1024                        listen_mode = true;
1025                        break;
1026                case 'p':
1027                        cfg_port = optarg;
1028                        break;
1029                case 's':
1030                        cfg_sock_proto = parse_proto(optarg);
1031                        break;
1032                case 'h':
1033                        die_usage();
1034                        break;
1035                case 'u':
1036                        tcpulp_audit = true;
1037                        break;
1038                case '6':
1039                        pf = AF_INET6;
1040                        break;
1041                case 't':
1042                        poll_timeout = atoi(optarg) * 1000;
1043                        if (poll_timeout <= 0)
1044                                poll_timeout = -1;
1045                        break;
1046                case 'm':
1047                        cfg_mode = parse_mode(optarg);
1048                        break;
1049                case 'S':
1050                        cfg_sndbuf = parse_int(optarg);
1051                        break;
1052                case 'R':
1053                        cfg_rcvbuf = parse_int(optarg);
1054                        break;
1055                case 'w':
1056                        cfg_wait = atoi(optarg)*1000000;
1057                        break;
1058                case 'M':
1059                        cfg_mark = strtol(optarg, NULL, 0);
1060                        break;
1061                case 'P':
1062                        cfg_peek = parse_peek(optarg);
1063                        break;
1064                case 'c':
1065                        parse_cmsg_types(optarg);
1066                        break;
1067                }
1068        }
1069
1070        if (optind + 1 != argc)
1071                die_usage();
1072        cfg_host = argv[optind];
1073
1074        if (strchr(cfg_host, ':'))
1075                pf = AF_INET6;
1076}
1077
1078int main(int argc, char *argv[])
1079{
1080        init_rng();
1081
1082        signal(SIGUSR1, handle_signal);
1083        parse_opts(argc, argv);
1084
1085        if (tcpulp_audit)
1086                return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;
1087
1088        if (listen_mode) {
1089                int fd = sock_listen_mptcp(cfg_host, cfg_port);
1090
1091                if (fd < 0)
1092                        return 1;
1093
1094                if (cfg_rcvbuf)
1095                        set_rcvbuf(fd, cfg_rcvbuf);
1096                if (cfg_sndbuf)
1097                        set_sndbuf(fd, cfg_sndbuf);
1098                if (cfg_mark)
1099                        set_mark(fd, cfg_mark);
1100                if (cfg_cmsg_types.cmsg_enabled)
1101                        apply_cmsg_types(fd, &cfg_cmsg_types);
1102
1103                return main_loop_s(fd);
1104        }
1105
1106        return main_loop();
1107}
1108