linux/tools/testing/vsock/vsock_diag_test.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * vsock_diag_test - vsock_diag.ko test suite
   4 *
   5 * Copyright (C) 2017 Red Hat, Inc.
   6 *
   7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
   8 */
   9
  10#include <getopt.h>
  11#include <stdio.h>
  12#include <stdbool.h>
  13#include <stdlib.h>
  14#include <string.h>
  15#include <errno.h>
  16#include <unistd.h>
  17#include <signal.h>
  18#include <sys/socket.h>
  19#include <sys/stat.h>
  20#include <sys/types.h>
  21#include <linux/list.h>
  22#include <linux/net.h>
  23#include <linux/netlink.h>
  24#include <linux/sock_diag.h>
  25#include <netinet/tcp.h>
  26
  27#include "../../../include/uapi/linux/vm_sockets.h"
  28#include "../../../include/uapi/linux/vm_sockets_diag.h"
  29
  30#include "timeout.h"
  31#include "control.h"
  32
  33enum test_mode {
  34        TEST_MODE_UNSET,
  35        TEST_MODE_CLIENT,
  36        TEST_MODE_SERVER
  37};
  38
/*
 * Per-socket status: one node per vsock_diag_msg received from the
 * sock_diag dump (see add_vsock_stat()), chained through @list.
 */
struct vsock_stat {
	struct list_head list;		/* linkage into the caller's list */
	struct vsock_diag_msg msg;	/* copy of the kernel's reply payload */
};
  44
  45static const char *sock_type_str(int type)
  46{
  47        switch (type) {
  48        case SOCK_DGRAM:
  49                return "DGRAM";
  50        case SOCK_STREAM:
  51                return "STREAM";
  52        default:
  53                return "INVALID TYPE";
  54        }
  55}
  56
  57static const char *sock_state_str(int state)
  58{
  59        switch (state) {
  60        case TCP_CLOSE:
  61                return "UNCONNECTED";
  62        case TCP_SYN_SENT:
  63                return "CONNECTING";
  64        case TCP_ESTABLISHED:
  65                return "CONNECTED";
  66        case TCP_CLOSING:
  67                return "DISCONNECTING";
  68        case TCP_LISTEN:
  69                return "LISTEN";
  70        default:
  71                return "INVALID STATE";
  72        }
  73}
  74
  75static const char *sock_shutdown_str(int shutdown)
  76{
  77        switch (shutdown) {
  78        case 1:
  79                return "RCV_SHUTDOWN";
  80        case 2:
  81                return "SEND_SHUTDOWN";
  82        case 3:
  83                return "RCV_SHUTDOWN | SEND_SHUTDOWN";
  84        default:
  85                return "0";
  86        }
  87}
  88
  89static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
  90{
  91        if (cid == VMADDR_CID_ANY)
  92                fprintf(fp, "*:");
  93        else
  94                fprintf(fp, "%u:", cid);
  95
  96        if (port == VMADDR_PORT_ANY)
  97                fprintf(fp, "*");
  98        else
  99                fprintf(fp, "%u", port);
 100}
 101
 102static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
 103{
 104        print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
 105        fprintf(fp, " ");
 106        print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
 107        fprintf(fp, " %s %s %s %u\n",
 108                sock_type_str(st->msg.vdiag_type),
 109                sock_state_str(st->msg.vdiag_state),
 110                sock_shutdown_str(st->msg.vdiag_shutdown),
 111                st->msg.vdiag_ino);
 112}
 113
 114static void print_vsock_stats(FILE *fp, struct list_head *head)
 115{
 116        struct vsock_stat *st;
 117
 118        list_for_each_entry(st, head, list)
 119                print_vsock_stat(fp, st);
 120}
 121
 122static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
 123{
 124        struct vsock_stat *st;
 125        struct stat stat;
 126
 127        if (fstat(fd, &stat) < 0) {
 128                perror("fstat");
 129                exit(EXIT_FAILURE);
 130        }
 131
 132        list_for_each_entry(st, head, list)
 133                if (st->msg.vdiag_ino == stat.st_ino)
 134                        return st;
 135
 136        fprintf(stderr, "cannot find fd %d\n", fd);
 137        exit(EXIT_FAILURE);
 138}
 139
 140static void check_no_sockets(struct list_head *head)
 141{
 142        if (!list_empty(head)) {
 143                fprintf(stderr, "expected no sockets\n");
 144                print_vsock_stats(stderr, head);
 145                exit(1);
 146        }
 147}
 148
 149static void check_num_sockets(struct list_head *head, int expected)
 150{
 151        struct list_head *node;
 152        int n = 0;
 153
 154        list_for_each(node, head)
 155                n++;
 156
 157        if (n != expected) {
 158                fprintf(stderr, "expected %d sockets, found %d\n",
 159                        expected, n);
 160                print_vsock_stats(stderr, head);
 161                exit(EXIT_FAILURE);
 162        }
 163}
 164
 165static void check_socket_state(struct vsock_stat *st, __u8 state)
 166{
 167        if (st->msg.vdiag_state != state) {
 168                fprintf(stderr, "expected socket state %#x, got %#x\n",
 169                        state, st->msg.vdiag_state);
 170                exit(EXIT_FAILURE);
 171        }
 172}
 173
 174static void send_req(int fd)
 175{
 176        struct sockaddr_nl nladdr = {
 177                .nl_family = AF_NETLINK,
 178        };
 179        struct {
 180                struct nlmsghdr nlh;
 181                struct vsock_diag_req vreq;
 182        } req = {
 183                .nlh = {
 184                        .nlmsg_len = sizeof(req),
 185                        .nlmsg_type = SOCK_DIAG_BY_FAMILY,
 186                        .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
 187                },
 188                .vreq = {
 189                        .sdiag_family = AF_VSOCK,
 190                        .vdiag_states = ~(__u32)0,
 191                },
 192        };
 193        struct iovec iov = {
 194                .iov_base = &req,
 195                .iov_len = sizeof(req),
 196        };
 197        struct msghdr msg = {
 198                .msg_name = &nladdr,
 199                .msg_namelen = sizeof(nladdr),
 200                .msg_iov = &iov,
 201                .msg_iovlen = 1,
 202        };
 203
 204        for (;;) {
 205                if (sendmsg(fd, &msg, 0) < 0) {
 206                        if (errno == EINTR)
 207                                continue;
 208
 209                        perror("sendmsg");
 210                        exit(EXIT_FAILURE);
 211                }
 212
 213                return;
 214        }
 215}
 216
 217static ssize_t recv_resp(int fd, void *buf, size_t len)
 218{
 219        struct sockaddr_nl nladdr = {
 220                .nl_family = AF_NETLINK,
 221        };
 222        struct iovec iov = {
 223                .iov_base = buf,
 224                .iov_len = len,
 225        };
 226        struct msghdr msg = {
 227                .msg_name = &nladdr,
 228                .msg_namelen = sizeof(nladdr),
 229                .msg_iov = &iov,
 230                .msg_iovlen = 1,
 231        };
 232        ssize_t ret;
 233
 234        do {
 235                ret = recvmsg(fd, &msg, 0);
 236        } while (ret < 0 && errno == EINTR);
 237
 238        if (ret < 0) {
 239                perror("recvmsg");
 240                exit(EXIT_FAILURE);
 241        }
 242
 243        return ret;
 244}
 245
 246static void add_vsock_stat(struct list_head *sockets,
 247                           const struct vsock_diag_msg *resp)
 248{
 249        struct vsock_stat *st;
 250
 251        st = malloc(sizeof(*st));
 252        if (!st) {
 253                perror("malloc");
 254                exit(EXIT_FAILURE);
 255        }
 256
 257        st->msg = *resp;
 258        list_add_tail(&st->list, sockets);
 259}
 260
 261/*
 262 * Read vsock stats into a list.
 263 */
 264static void read_vsock_stat(struct list_head *sockets)
 265{
 266        long buf[8192 / sizeof(long)];
 267        int fd;
 268
 269        fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
 270        if (fd < 0) {
 271                perror("socket");
 272                exit(EXIT_FAILURE);
 273        }
 274
 275        send_req(fd);
 276
 277        for (;;) {
 278                const struct nlmsghdr *h;
 279                ssize_t ret;
 280
 281                ret = recv_resp(fd, buf, sizeof(buf));
 282                if (ret == 0)
 283                        goto done;
 284                if (ret < sizeof(*h)) {
 285                        fprintf(stderr, "short read of %zd bytes\n", ret);
 286                        exit(EXIT_FAILURE);
 287                }
 288
 289                h = (struct nlmsghdr *)buf;
 290
 291                while (NLMSG_OK(h, ret)) {
 292                        if (h->nlmsg_type == NLMSG_DONE)
 293                                goto done;
 294
 295                        if (h->nlmsg_type == NLMSG_ERROR) {
 296                                const struct nlmsgerr *err = NLMSG_DATA(h);
 297
 298                                if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
 299                                        fprintf(stderr, "NLMSG_ERROR\n");
 300                                else {
 301                                        errno = -err->error;
 302                                        perror("NLMSG_ERROR");
 303                                }
 304
 305                                exit(EXIT_FAILURE);
 306                        }
 307
 308                        if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
 309                                fprintf(stderr, "unexpected nlmsg_type %#x\n",
 310                                        h->nlmsg_type);
 311                                exit(EXIT_FAILURE);
 312                        }
 313                        if (h->nlmsg_len <
 314                            NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
 315                                fprintf(stderr, "short vsock_diag_msg\n");
 316                                exit(EXIT_FAILURE);
 317                        }
 318
 319                        add_vsock_stat(sockets, NLMSG_DATA(h));
 320
 321                        h = NLMSG_NEXT(h, ret);
 322                }
 323        }
 324
 325done:
 326        close(fd);
 327}
 328
 329static void free_sock_stat(struct list_head *sockets)
 330{
 331        struct vsock_stat *st;
 332        struct vsock_stat *next;
 333
 334        list_for_each_entry_safe(st, next, sockets, list)
 335                free(st);
 336}
 337
 338static void test_no_sockets(unsigned int peer_cid)
 339{
 340        LIST_HEAD(sockets);
 341
 342        read_vsock_stat(&sockets);
 343
 344        check_no_sockets(&sockets);
 345
 346        free_sock_stat(&sockets);
 347}
 348
 349static void test_listen_socket_server(unsigned int peer_cid)
 350{
 351        union {
 352                struct sockaddr sa;
 353                struct sockaddr_vm svm;
 354        } addr = {
 355                .svm = {
 356                        .svm_family = AF_VSOCK,
 357                        .svm_port = 1234,
 358                        .svm_cid = VMADDR_CID_ANY,
 359                },
 360        };
 361        LIST_HEAD(sockets);
 362        struct vsock_stat *st;
 363        int fd;
 364
 365        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 366
 367        if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 368                perror("bind");
 369                exit(EXIT_FAILURE);
 370        }
 371
 372        if (listen(fd, 1) < 0) {
 373                perror("listen");
 374                exit(EXIT_FAILURE);
 375        }
 376
 377        read_vsock_stat(&sockets);
 378
 379        check_num_sockets(&sockets, 1);
 380        st = find_vsock_stat(&sockets, fd);
 381        check_socket_state(st, TCP_LISTEN);
 382
 383        close(fd);
 384        free_sock_stat(&sockets);
 385}
 386
 387static void test_connect_client(unsigned int peer_cid)
 388{
 389        union {
 390                struct sockaddr sa;
 391                struct sockaddr_vm svm;
 392        } addr = {
 393                .svm = {
 394                        .svm_family = AF_VSOCK,
 395                        .svm_port = 1234,
 396                        .svm_cid = peer_cid,
 397                },
 398        };
 399        int fd;
 400        int ret;
 401        LIST_HEAD(sockets);
 402        struct vsock_stat *st;
 403
 404        control_expectln("LISTENING");
 405
 406        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 407
 408        timeout_begin(TIMEOUT);
 409        do {
 410                ret = connect(fd, &addr.sa, sizeof(addr.svm));
 411                timeout_check("connect");
 412        } while (ret < 0 && errno == EINTR);
 413        timeout_end();
 414
 415        if (ret < 0) {
 416                perror("connect");
 417                exit(EXIT_FAILURE);
 418        }
 419
 420        read_vsock_stat(&sockets);
 421
 422        check_num_sockets(&sockets, 1);
 423        st = find_vsock_stat(&sockets, fd);
 424        check_socket_state(st, TCP_ESTABLISHED);
 425
 426        control_expectln("DONE");
 427        control_writeln("DONE");
 428
 429        close(fd);
 430        free_sock_stat(&sockets);
 431}
 432
 433static void test_connect_server(unsigned int peer_cid)
 434{
 435        union {
 436                struct sockaddr sa;
 437                struct sockaddr_vm svm;
 438        } addr = {
 439                .svm = {
 440                        .svm_family = AF_VSOCK,
 441                        .svm_port = 1234,
 442                        .svm_cid = VMADDR_CID_ANY,
 443                },
 444        };
 445        union {
 446                struct sockaddr sa;
 447                struct sockaddr_vm svm;
 448        } clientaddr;
 449        socklen_t clientaddr_len = sizeof(clientaddr.svm);
 450        LIST_HEAD(sockets);
 451        struct vsock_stat *st;
 452        int fd;
 453        int client_fd;
 454
 455        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 456
 457        if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 458                perror("bind");
 459                exit(EXIT_FAILURE);
 460        }
 461
 462        if (listen(fd, 1) < 0) {
 463                perror("listen");
 464                exit(EXIT_FAILURE);
 465        }
 466
 467        control_writeln("LISTENING");
 468
 469        timeout_begin(TIMEOUT);
 470        do {
 471                client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
 472                timeout_check("accept");
 473        } while (client_fd < 0 && errno == EINTR);
 474        timeout_end();
 475
 476        if (client_fd < 0) {
 477                perror("accept");
 478                exit(EXIT_FAILURE);
 479        }
 480        if (clientaddr.sa.sa_family != AF_VSOCK) {
 481                fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
 482                        clientaddr.sa.sa_family);
 483                exit(EXIT_FAILURE);
 484        }
 485        if (clientaddr.svm.svm_cid != peer_cid) {
 486                fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
 487                        peer_cid, clientaddr.svm.svm_cid);
 488                exit(EXIT_FAILURE);
 489        }
 490
 491        read_vsock_stat(&sockets);
 492
 493        check_num_sockets(&sockets, 2);
 494        find_vsock_stat(&sockets, fd);
 495        st = find_vsock_stat(&sockets, client_fd);
 496        check_socket_state(st, TCP_ESTABLISHED);
 497
 498        control_writeln("DONE");
 499        control_expectln("DONE");
 500
 501        close(client_fd);
 502        close(fd);
 503        free_sock_stat(&sockets);
 504}
 505
 506static struct {
 507        const char *name;
 508        void (*run_client)(unsigned int peer_cid);
 509        void (*run_server)(unsigned int peer_cid);
 510} test_cases[] = {
 511        {
 512                .name = "No sockets",
 513                .run_server = test_no_sockets,
 514        },
 515        {
 516                .name = "Listen socket",
 517                .run_server = test_listen_socket_server,
 518        },
 519        {
 520                .name = "Connect",
 521                .run_client = test_connect_client,
 522                .run_server = test_connect_server,
 523        },
 524        {},
 525};
 526
 527static void init_signals(void)
 528{
 529        struct sigaction act = {
 530                .sa_handler = sigalrm,
 531        };
 532
 533        sigaction(SIGALRM, &act, NULL);
 534        signal(SIGPIPE, SIG_IGN);
 535}
 536
 537static unsigned int parse_cid(const char *str)
 538{
 539        char *endptr = NULL;
 540        unsigned long int n;
 541
 542        errno = 0;
 543        n = strtoul(str, &endptr, 10);
 544        if (errno || *endptr != '\0') {
 545                fprintf(stderr, "malformed CID \"%s\"\n", str);
 546                exit(EXIT_FAILURE);
 547        }
 548        return n;
 549}
 550
 551static const char optstring[] = "";
 552static const struct option longopts[] = {
 553        {
 554                .name = "control-host",
 555                .has_arg = required_argument,
 556                .val = 'H',
 557        },
 558        {
 559                .name = "control-port",
 560                .has_arg = required_argument,
 561                .val = 'P',
 562        },
 563        {
 564                .name = "mode",
 565                .has_arg = required_argument,
 566                .val = 'm',
 567        },
 568        {
 569                .name = "peer-cid",
 570                .has_arg = required_argument,
 571                .val = 'p',
 572        },
 573        {
 574                .name = "help",
 575                .has_arg = no_argument,
 576                .val = '?',
 577        },
 578        {},
 579};
 580
 581static void usage(void)
 582{
 583        fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
 584                "\n"
 585                "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
 586                "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
 587                "\n"
 588                "Run vsock_diag.ko tests.  Must be launched in both\n"
 589                "guest and host.  One side must use --mode=client and\n"
 590                "the other side must use --mode=server.\n"
 591                "\n"
 592                "A TCP control socket connection is used to coordinate tests\n"
 593                "between the client and the server.  The server requires a\n"
 594                "listen address and the client requires an address to\n"
 595                "connect to.\n"
 596                "\n"
 597                "The CID of the other side must be given with --peer-cid=<cid>.\n");
 598        exit(EXIT_FAILURE);
 599}
 600
 601int main(int argc, char **argv)
 602{
 603        const char *control_host = NULL;
 604        const char *control_port = NULL;
 605        int mode = TEST_MODE_UNSET;
 606        unsigned int peer_cid = VMADDR_CID_ANY;
 607        int i;
 608
 609        init_signals();
 610
 611        for (;;) {
 612                int opt = getopt_long(argc, argv, optstring, longopts, NULL);
 613
 614                if (opt == -1)
 615                        break;
 616
 617                switch (opt) {
 618                case 'H':
 619                        control_host = optarg;
 620                        break;
 621                case 'm':
 622                        if (strcmp(optarg, "client") == 0)
 623                                mode = TEST_MODE_CLIENT;
 624                        else if (strcmp(optarg, "server") == 0)
 625                                mode = TEST_MODE_SERVER;
 626                        else {
 627                                fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
 628                                return EXIT_FAILURE;
 629                        }
 630                        break;
 631                case 'p':
 632                        peer_cid = parse_cid(optarg);
 633                        break;
 634                case 'P':
 635                        control_port = optarg;
 636                        break;
 637                case '?':
 638                default:
 639                        usage();
 640                }
 641        }
 642
 643        if (!control_port)
 644                usage();
 645        if (mode == TEST_MODE_UNSET)
 646                usage();
 647        if (peer_cid == VMADDR_CID_ANY)
 648                usage();
 649
 650        if (!control_host) {
 651                if (mode != TEST_MODE_SERVER)
 652                        usage();
 653                control_host = "0.0.0.0";
 654        }
 655
 656        control_init(control_host, control_port, mode == TEST_MODE_SERVER);
 657
 658        for (i = 0; test_cases[i].name; i++) {
 659                void (*run)(unsigned int peer_cid);
 660
 661                printf("%s...", test_cases[i].name);
 662                fflush(stdout);
 663
 664                if (mode == TEST_MODE_CLIENT)
 665                        run = test_cases[i].run_client;
 666                else
 667                        run = test_cases[i].run_server;
 668
 669                if (run)
 670                        run(peer_cid);
 671
 672                printf("ok\n");
 673        }
 674
 675        control_cleanup();
 676        return EXIT_SUCCESS;
 677}
 678