linux/tools/testing/vsock/vsock_diag_test.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * vsock_diag_test - vsock_diag.ko test suite
   4 *
   5 * Copyright (C) 2017 Red Hat, Inc.
   6 *
   7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
   8 */
   9
   #include <getopt.h>
   #include <stdio.h>
   #include <stdlib.h>
   #include <string.h>
   #include <errno.h>
   #include <unistd.h>
   #include <sys/socket.h>
   #include <sys/stat.h>
   #include <sys/types.h>
   #include <linux/list.h>
   #include <linux/net.h>
   #include <linux/netlink.h>
   #include <linux/sock_diag.h>
   #include <linux/vm_sockets_diag.h>
   #include <netinet/tcp.h>

   #include "timeout.h"
   #include "control.h"
   #include "util.h"
  28
  29/* Per-socket status */
  30struct vsock_stat {
  31        struct list_head list;
  32        struct vsock_diag_msg msg;
  33};
  34
  35static const char *sock_type_str(int type)
  36{
  37        switch (type) {
  38        case SOCK_DGRAM:
  39                return "DGRAM";
  40        case SOCK_STREAM:
  41                return "STREAM";
  42        default:
  43                return "INVALID TYPE";
  44        }
  45}
  46
  47static const char *sock_state_str(int state)
  48{
  49        switch (state) {
  50        case TCP_CLOSE:
  51                return "UNCONNECTED";
  52        case TCP_SYN_SENT:
  53                return "CONNECTING";
  54        case TCP_ESTABLISHED:
  55                return "CONNECTED";
  56        case TCP_CLOSING:
  57                return "DISCONNECTING";
  58        case TCP_LISTEN:
  59                return "LISTEN";
  60        default:
  61                return "INVALID STATE";
  62        }
  63}
  64
  65static const char *sock_shutdown_str(int shutdown)
  66{
  67        switch (shutdown) {
  68        case 1:
  69                return "RCV_SHUTDOWN";
  70        case 2:
  71                return "SEND_SHUTDOWN";
  72        case 3:
  73                return "RCV_SHUTDOWN | SEND_SHUTDOWN";
  74        default:
  75                return "0";
  76        }
  77}
  78
  79static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
  80{
  81        if (cid == VMADDR_CID_ANY)
  82                fprintf(fp, "*:");
  83        else
  84                fprintf(fp, "%u:", cid);
  85
  86        if (port == VMADDR_PORT_ANY)
  87                fprintf(fp, "*");
  88        else
  89                fprintf(fp, "%u", port);
  90}
  91
  92static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
  93{
  94        print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
  95        fprintf(fp, " ");
  96        print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
  97        fprintf(fp, " %s %s %s %u\n",
  98                sock_type_str(st->msg.vdiag_type),
  99                sock_state_str(st->msg.vdiag_state),
 100                sock_shutdown_str(st->msg.vdiag_shutdown),
 101                st->msg.vdiag_ino);
 102}
 103
 104static void print_vsock_stats(FILE *fp, struct list_head *head)
 105{
 106        struct vsock_stat *st;
 107
 108        list_for_each_entry(st, head, list)
 109                print_vsock_stat(fp, st);
 110}
 111
 112static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
 113{
 114        struct vsock_stat *st;
 115        struct stat stat;
 116
 117        if (fstat(fd, &stat) < 0) {
 118                perror("fstat");
 119                exit(EXIT_FAILURE);
 120        }
 121
 122        list_for_each_entry(st, head, list)
 123                if (st->msg.vdiag_ino == stat.st_ino)
 124                        return st;
 125
 126        fprintf(stderr, "cannot find fd %d\n", fd);
 127        exit(EXIT_FAILURE);
 128}
 129
 130static void check_no_sockets(struct list_head *head)
 131{
 132        if (!list_empty(head)) {
 133                fprintf(stderr, "expected no sockets\n");
 134                print_vsock_stats(stderr, head);
 135                exit(1);
 136        }
 137}
 138
 139static void check_num_sockets(struct list_head *head, int expected)
 140{
 141        struct list_head *node;
 142        int n = 0;
 143
 144        list_for_each(node, head)
 145                n++;
 146
 147        if (n != expected) {
 148                fprintf(stderr, "expected %d sockets, found %d\n",
 149                        expected, n);
 150                print_vsock_stats(stderr, head);
 151                exit(EXIT_FAILURE);
 152        }
 153}
 154
 155static void check_socket_state(struct vsock_stat *st, __u8 state)
 156{
 157        if (st->msg.vdiag_state != state) {
 158                fprintf(stderr, "expected socket state %#x, got %#x\n",
 159                        state, st->msg.vdiag_state);
 160                exit(EXIT_FAILURE);
 161        }
 162}
 163
 164static void send_req(int fd)
 165{
 166        struct sockaddr_nl nladdr = {
 167                .nl_family = AF_NETLINK,
 168        };
 169        struct {
 170                struct nlmsghdr nlh;
 171                struct vsock_diag_req vreq;
 172        } req = {
 173                .nlh = {
 174                        .nlmsg_len = sizeof(req),
 175                        .nlmsg_type = SOCK_DIAG_BY_FAMILY,
 176                        .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
 177                },
 178                .vreq = {
 179                        .sdiag_family = AF_VSOCK,
 180                        .vdiag_states = ~(__u32)0,
 181                },
 182        };
 183        struct iovec iov = {
 184                .iov_base = &req,
 185                .iov_len = sizeof(req),
 186        };
 187        struct msghdr msg = {
 188                .msg_name = &nladdr,
 189                .msg_namelen = sizeof(nladdr),
 190                .msg_iov = &iov,
 191                .msg_iovlen = 1,
 192        };
 193
 194        for (;;) {
 195                if (sendmsg(fd, &msg, 0) < 0) {
 196                        if (errno == EINTR)
 197                                continue;
 198
 199                        perror("sendmsg");
 200                        exit(EXIT_FAILURE);
 201                }
 202
 203                return;
 204        }
 205}
 206
 207static ssize_t recv_resp(int fd, void *buf, size_t len)
 208{
 209        struct sockaddr_nl nladdr = {
 210                .nl_family = AF_NETLINK,
 211        };
 212        struct iovec iov = {
 213                .iov_base = buf,
 214                .iov_len = len,
 215        };
 216        struct msghdr msg = {
 217                .msg_name = &nladdr,
 218                .msg_namelen = sizeof(nladdr),
 219                .msg_iov = &iov,
 220                .msg_iovlen = 1,
 221        };
 222        ssize_t ret;
 223
 224        do {
 225                ret = recvmsg(fd, &msg, 0);
 226        } while (ret < 0 && errno == EINTR);
 227
 228        if (ret < 0) {
 229                perror("recvmsg");
 230                exit(EXIT_FAILURE);
 231        }
 232
 233        return ret;
 234}
 235
 236static void add_vsock_stat(struct list_head *sockets,
 237                           const struct vsock_diag_msg *resp)
 238{
 239        struct vsock_stat *st;
 240
 241        st = malloc(sizeof(*st));
 242        if (!st) {
 243                perror("malloc");
 244                exit(EXIT_FAILURE);
 245        }
 246
 247        st->msg = *resp;
 248        list_add_tail(&st->list, sockets);
 249}
 250
 251/*
 252 * Read vsock stats into a list.
 253 */
 254static void read_vsock_stat(struct list_head *sockets)
 255{
 256        long buf[8192 / sizeof(long)];
 257        int fd;
 258
 259        fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
 260        if (fd < 0) {
 261                perror("socket");
 262                exit(EXIT_FAILURE);
 263        }
 264
 265        send_req(fd);
 266
 267        for (;;) {
 268                const struct nlmsghdr *h;
 269                ssize_t ret;
 270
 271                ret = recv_resp(fd, buf, sizeof(buf));
 272                if (ret == 0)
 273                        goto done;
 274                if (ret < sizeof(*h)) {
 275                        fprintf(stderr, "short read of %zd bytes\n", ret);
 276                        exit(EXIT_FAILURE);
 277                }
 278
 279                h = (struct nlmsghdr *)buf;
 280
 281                while (NLMSG_OK(h, ret)) {
 282                        if (h->nlmsg_type == NLMSG_DONE)
 283                                goto done;
 284
 285                        if (h->nlmsg_type == NLMSG_ERROR) {
 286                                const struct nlmsgerr *err = NLMSG_DATA(h);
 287
 288                                if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
 289                                        fprintf(stderr, "NLMSG_ERROR\n");
 290                                else {
 291                                        errno = -err->error;
 292                                        perror("NLMSG_ERROR");
 293                                }
 294
 295                                exit(EXIT_FAILURE);
 296                        }
 297
 298                        if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
 299                                fprintf(stderr, "unexpected nlmsg_type %#x\n",
 300                                        h->nlmsg_type);
 301                                exit(EXIT_FAILURE);
 302                        }
 303                        if (h->nlmsg_len <
 304                            NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
 305                                fprintf(stderr, "short vsock_diag_msg\n");
 306                                exit(EXIT_FAILURE);
 307                        }
 308
 309                        add_vsock_stat(sockets, NLMSG_DATA(h));
 310
 311                        h = NLMSG_NEXT(h, ret);
 312                }
 313        }
 314
 315done:
 316        close(fd);
 317}
 318
 319static void free_sock_stat(struct list_head *sockets)
 320{
 321        struct vsock_stat *st;
 322        struct vsock_stat *next;
 323
 324        list_for_each_entry_safe(st, next, sockets, list)
 325                free(st);
 326}
 327
 328static void test_no_sockets(const struct test_opts *opts)
 329{
 330        LIST_HEAD(sockets);
 331
 332        read_vsock_stat(&sockets);
 333
 334        check_no_sockets(&sockets);
 335
 336        free_sock_stat(&sockets);
 337}
 338
 339static void test_listen_socket_server(const struct test_opts *opts)
 340{
 341        union {
 342                struct sockaddr sa;
 343                struct sockaddr_vm svm;
 344        } addr = {
 345                .svm = {
 346                        .svm_family = AF_VSOCK,
 347                        .svm_port = 1234,
 348                        .svm_cid = VMADDR_CID_ANY,
 349                },
 350        };
 351        LIST_HEAD(sockets);
 352        struct vsock_stat *st;
 353        int fd;
 354
 355        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 356
 357        if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 358                perror("bind");
 359                exit(EXIT_FAILURE);
 360        }
 361
 362        if (listen(fd, 1) < 0) {
 363                perror("listen");
 364                exit(EXIT_FAILURE);
 365        }
 366
 367        read_vsock_stat(&sockets);
 368
 369        check_num_sockets(&sockets, 1);
 370        st = find_vsock_stat(&sockets, fd);
 371        check_socket_state(st, TCP_LISTEN);
 372
 373        close(fd);
 374        free_sock_stat(&sockets);
 375}
 376
 377static void test_connect_client(const struct test_opts *opts)
 378{
 379        int fd;
 380        LIST_HEAD(sockets);
 381        struct vsock_stat *st;
 382
 383        fd = vsock_stream_connect(opts->peer_cid, 1234);
 384        if (fd < 0) {
 385                perror("connect");
 386                exit(EXIT_FAILURE);
 387        }
 388
 389        read_vsock_stat(&sockets);
 390
 391        check_num_sockets(&sockets, 1);
 392        st = find_vsock_stat(&sockets, fd);
 393        check_socket_state(st, TCP_ESTABLISHED);
 394
 395        control_expectln("DONE");
 396        control_writeln("DONE");
 397
 398        close(fd);
 399        free_sock_stat(&sockets);
 400}
 401
 402static void test_connect_server(const struct test_opts *opts)
 403{
 404        struct vsock_stat *st;
 405        LIST_HEAD(sockets);
 406        int client_fd;
 407
 408        client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
 409        if (client_fd < 0) {
 410                perror("accept");
 411                exit(EXIT_FAILURE);
 412        }
 413
 414        read_vsock_stat(&sockets);
 415
 416        check_num_sockets(&sockets, 1);
 417        st = find_vsock_stat(&sockets, client_fd);
 418        check_socket_state(st, TCP_ESTABLISHED);
 419
 420        control_writeln("DONE");
 421        control_expectln("DONE");
 422
 423        close(client_fd);
 424        free_sock_stat(&sockets);
 425}
 426
 427static struct test_case test_cases[] = {
 428        {
 429                .name = "No sockets",
 430                .run_server = test_no_sockets,
 431        },
 432        {
 433                .name = "Listen socket",
 434                .run_server = test_listen_socket_server,
 435        },
 436        {
 437                .name = "Connect",
 438                .run_client = test_connect_client,
 439                .run_server = test_connect_server,
 440        },
 441        {},
 442};
 443
 444static const char optstring[] = "";
 445static const struct option longopts[] = {
 446        {
 447                .name = "control-host",
 448                .has_arg = required_argument,
 449                .val = 'H',
 450        },
 451        {
 452                .name = "control-port",
 453                .has_arg = required_argument,
 454                .val = 'P',
 455        },
 456        {
 457                .name = "mode",
 458                .has_arg = required_argument,
 459                .val = 'm',
 460        },
 461        {
 462                .name = "peer-cid",
 463                .has_arg = required_argument,
 464                .val = 'p',
 465        },
 466        {
 467                .name = "list",
 468                .has_arg = no_argument,
 469                .val = 'l',
 470        },
 471        {
 472                .name = "skip",
 473                .has_arg = required_argument,
 474                .val = 's',
 475        },
 476        {
 477                .name = "help",
 478                .has_arg = no_argument,
 479                .val = '?',
 480        },
 481        {},
 482};
 483
 484static void usage(void)
 485{
 486        fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
 487                "\n"
 488                "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
 489                "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
 490                "\n"
 491                "Run vsock_diag.ko tests.  Must be launched in both\n"
 492                "guest and host.  One side must use --mode=client and\n"
 493                "the other side must use --mode=server.\n"
 494                "\n"
 495                "A TCP control socket connection is used to coordinate tests\n"
 496                "between the client and the server.  The server requires a\n"
 497                "listen address and the client requires an address to\n"
 498                "connect to.\n"
 499                "\n"
 500                "The CID of the other side must be given with --peer-cid=<cid>.\n"
 501                "\n"
 502                "Options:\n"
 503                "  --help                 This help message\n"
 504                "  --control-host <host>  Server IP address to connect to\n"
 505                "  --control-port <port>  Server port to listen on/connect to\n"
 506                "  --mode client|server   Server or client mode\n"
 507                "  --peer-cid <cid>       CID of the other side\n"
 508                "  --list                 List of tests that will be executed\n"
 509                "  --skip <test_id>       Test ID to skip;\n"
 510                "                         use multiple --skip options to skip more tests\n"
 511                );
 512        exit(EXIT_FAILURE);
 513}
 514
 515int main(int argc, char **argv)
 516{
 517        const char *control_host = NULL;
 518        const char *control_port = NULL;
 519        struct test_opts opts = {
 520                .mode = TEST_MODE_UNSET,
 521                .peer_cid = VMADDR_CID_ANY,
 522        };
 523
 524        init_signals();
 525
 526        for (;;) {
 527                int opt = getopt_long(argc, argv, optstring, longopts, NULL);
 528
 529                if (opt == -1)
 530                        break;
 531
 532                switch (opt) {
 533                case 'H':
 534                        control_host = optarg;
 535                        break;
 536                case 'm':
 537                        if (strcmp(optarg, "client") == 0)
 538                                opts.mode = TEST_MODE_CLIENT;
 539                        else if (strcmp(optarg, "server") == 0)
 540                                opts.mode = TEST_MODE_SERVER;
 541                        else {
 542                                fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
 543                                return EXIT_FAILURE;
 544                        }
 545                        break;
 546                case 'p':
 547                        opts.peer_cid = parse_cid(optarg);
 548                        break;
 549                case 'P':
 550                        control_port = optarg;
 551                        break;
 552                case 'l':
 553                        list_tests(test_cases);
 554                        break;
 555                case 's':
 556                        skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
 557                                  optarg);
 558                        break;
 559                case '?':
 560                default:
 561                        usage();
 562                }
 563        }
 564
 565        if (!control_port)
 566                usage();
 567        if (opts.mode == TEST_MODE_UNSET)
 568                usage();
 569        if (opts.peer_cid == VMADDR_CID_ANY)
 570                usage();
 571
 572        if (!control_host) {
 573                if (opts.mode != TEST_MODE_SERVER)
 574                        usage();
 575                control_host = "0.0.0.0";
 576        }
 577
 578        control_init(control_host, control_port,
 579                     opts.mode == TEST_MODE_SERVER);
 580
 581        run_tests(test_cases, &opts);
 582
 583        control_cleanup();
 584        return EXIT_SUCCESS;
 585}
 586