linux/tools/testing/vsock/vsock_diag_test.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * vsock_diag_test - vsock_diag.ko test suite
   4 *
   5 * Copyright (C) 2017 Red Hat, Inc.
   6 *
   7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
   8 */
   9
  10#include <getopt.h>
  11#include <stdio.h>
  12#include <stdlib.h>
  13#include <string.h>
  14#include <errno.h>
  15#include <unistd.h>
  16#include <sys/stat.h>
  17#include <sys/types.h>
  18#include <linux/list.h>
  19#include <linux/net.h>
  20#include <linux/netlink.h>
  21#include <linux/sock_diag.h>
  22#include <linux/vm_sockets_diag.h>
  23#include <netinet/tcp.h>
  24
  25#include "timeout.h"
  26#include "control.h"
  27#include "util.h"
  28
/*
 * Per-socket status: one entry of the list filled by read_vsock_stat(),
 * holding a copy of the vsock_diag_msg reported for a single socket.
 */
struct vsock_stat {
	struct list_head list;		/* link in the socket list */
	struct vsock_diag_msg msg;	/* copy of the netlink diag payload */
};
  34
  35static const char *sock_type_str(int type)
  36{
  37        switch (type) {
  38        case SOCK_DGRAM:
  39                return "DGRAM";
  40        case SOCK_STREAM:
  41                return "STREAM";
  42        default:
  43                return "INVALID TYPE";
  44        }
  45}
  46
  47static const char *sock_state_str(int state)
  48{
  49        switch (state) {
  50        case TCP_CLOSE:
  51                return "UNCONNECTED";
  52        case TCP_SYN_SENT:
  53                return "CONNECTING";
  54        case TCP_ESTABLISHED:
  55                return "CONNECTED";
  56        case TCP_CLOSING:
  57                return "DISCONNECTING";
  58        case TCP_LISTEN:
  59                return "LISTEN";
  60        default:
  61                return "INVALID STATE";
  62        }
  63}
  64
  65static const char *sock_shutdown_str(int shutdown)
  66{
  67        switch (shutdown) {
  68        case 1:
  69                return "RCV_SHUTDOWN";
  70        case 2:
  71                return "SEND_SHUTDOWN";
  72        case 3:
  73                return "RCV_SHUTDOWN | SEND_SHUTDOWN";
  74        default:
  75                return "0";
  76        }
  77}
  78
  79static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
  80{
  81        if (cid == VMADDR_CID_ANY)
  82                fprintf(fp, "*:");
  83        else
  84                fprintf(fp, "%u:", cid);
  85
  86        if (port == VMADDR_PORT_ANY)
  87                fprintf(fp, "*");
  88        else
  89                fprintf(fp, "%u", port);
  90}
  91
  92static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
  93{
  94        print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
  95        fprintf(fp, " ");
  96        print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
  97        fprintf(fp, " %s %s %s %u\n",
  98                sock_type_str(st->msg.vdiag_type),
  99                sock_state_str(st->msg.vdiag_state),
 100                sock_shutdown_str(st->msg.vdiag_shutdown),
 101                st->msg.vdiag_ino);
 102}
 103
 104static void print_vsock_stats(FILE *fp, struct list_head *head)
 105{
 106        struct vsock_stat *st;
 107
 108        list_for_each_entry(st, head, list)
 109                print_vsock_stat(fp, st);
 110}
 111
 112static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
 113{
 114        struct vsock_stat *st;
 115        struct stat stat;
 116
 117        if (fstat(fd, &stat) < 0) {
 118                perror("fstat");
 119                exit(EXIT_FAILURE);
 120        }
 121
 122        list_for_each_entry(st, head, list)
 123                if (st->msg.vdiag_ino == stat.st_ino)
 124                        return st;
 125
 126        fprintf(stderr, "cannot find fd %d\n", fd);
 127        exit(EXIT_FAILURE);
 128}
 129
 130static void check_no_sockets(struct list_head *head)
 131{
 132        if (!list_empty(head)) {
 133                fprintf(stderr, "expected no sockets\n");
 134                print_vsock_stats(stderr, head);
 135                exit(1);
 136        }
 137}
 138
 139static void check_num_sockets(struct list_head *head, int expected)
 140{
 141        struct list_head *node;
 142        int n = 0;
 143
 144        list_for_each(node, head)
 145                n++;
 146
 147        if (n != expected) {
 148                fprintf(stderr, "expected %d sockets, found %d\n",
 149                        expected, n);
 150                print_vsock_stats(stderr, head);
 151                exit(EXIT_FAILURE);
 152        }
 153}
 154
 155static void check_socket_state(struct vsock_stat *st, __u8 state)
 156{
 157        if (st->msg.vdiag_state != state) {
 158                fprintf(stderr, "expected socket state %#x, got %#x\n",
 159                        state, st->msg.vdiag_state);
 160                exit(EXIT_FAILURE);
 161        }
 162}
 163
 164static void send_req(int fd)
 165{
 166        struct sockaddr_nl nladdr = {
 167                .nl_family = AF_NETLINK,
 168        };
 169        struct {
 170                struct nlmsghdr nlh;
 171                struct vsock_diag_req vreq;
 172        } req = {
 173                .nlh = {
 174                        .nlmsg_len = sizeof(req),
 175                        .nlmsg_type = SOCK_DIAG_BY_FAMILY,
 176                        .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
 177                },
 178                .vreq = {
 179                        .sdiag_family = AF_VSOCK,
 180                        .vdiag_states = ~(__u32)0,
 181                },
 182        };
 183        struct iovec iov = {
 184                .iov_base = &req,
 185                .iov_len = sizeof(req),
 186        };
 187        struct msghdr msg = {
 188                .msg_name = &nladdr,
 189                .msg_namelen = sizeof(nladdr),
 190                .msg_iov = &iov,
 191                .msg_iovlen = 1,
 192        };
 193
 194        for (;;) {
 195                if (sendmsg(fd, &msg, 0) < 0) {
 196                        if (errno == EINTR)
 197                                continue;
 198
 199                        perror("sendmsg");
 200                        exit(EXIT_FAILURE);
 201                }
 202
 203                return;
 204        }
 205}
 206
 207static ssize_t recv_resp(int fd, void *buf, size_t len)
 208{
 209        struct sockaddr_nl nladdr = {
 210                .nl_family = AF_NETLINK,
 211        };
 212        struct iovec iov = {
 213                .iov_base = buf,
 214                .iov_len = len,
 215        };
 216        struct msghdr msg = {
 217                .msg_name = &nladdr,
 218                .msg_namelen = sizeof(nladdr),
 219                .msg_iov = &iov,
 220                .msg_iovlen = 1,
 221        };
 222        ssize_t ret;
 223
 224        do {
 225                ret = recvmsg(fd, &msg, 0);
 226        } while (ret < 0 && errno == EINTR);
 227
 228        if (ret < 0) {
 229                perror("recvmsg");
 230                exit(EXIT_FAILURE);
 231        }
 232
 233        return ret;
 234}
 235
 236static void add_vsock_stat(struct list_head *sockets,
 237                           const struct vsock_diag_msg *resp)
 238{
 239        struct vsock_stat *st;
 240
 241        st = malloc(sizeof(*st));
 242        if (!st) {
 243                perror("malloc");
 244                exit(EXIT_FAILURE);
 245        }
 246
 247        st->msg = *resp;
 248        list_add_tail(&st->list, sockets);
 249}
 250
 251/*
 252 * Read vsock stats into a list.
 253 */
 254static void read_vsock_stat(struct list_head *sockets)
 255{
 256        long buf[8192 / sizeof(long)];
 257        int fd;
 258
 259        fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
 260        if (fd < 0) {
 261                perror("socket");
 262                exit(EXIT_FAILURE);
 263        }
 264
 265        send_req(fd);
 266
 267        for (;;) {
 268                const struct nlmsghdr *h;
 269                ssize_t ret;
 270
 271                ret = recv_resp(fd, buf, sizeof(buf));
 272                if (ret == 0)
 273                        goto done;
 274                if (ret < sizeof(*h)) {
 275                        fprintf(stderr, "short read of %zd bytes\n", ret);
 276                        exit(EXIT_FAILURE);
 277                }
 278
 279                h = (struct nlmsghdr *)buf;
 280
 281                while (NLMSG_OK(h, ret)) {
 282                        if (h->nlmsg_type == NLMSG_DONE)
 283                                goto done;
 284
 285                        if (h->nlmsg_type == NLMSG_ERROR) {
 286                                const struct nlmsgerr *err = NLMSG_DATA(h);
 287
 288                                if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
 289                                        fprintf(stderr, "NLMSG_ERROR\n");
 290                                else {
 291                                        errno = -err->error;
 292                                        perror("NLMSG_ERROR");
 293                                }
 294
 295                                exit(EXIT_FAILURE);
 296                        }
 297
 298                        if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
 299                                fprintf(stderr, "unexpected nlmsg_type %#x\n",
 300                                        h->nlmsg_type);
 301                                exit(EXIT_FAILURE);
 302                        }
 303                        if (h->nlmsg_len <
 304                            NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
 305                                fprintf(stderr, "short vsock_diag_msg\n");
 306                                exit(EXIT_FAILURE);
 307                        }
 308
 309                        add_vsock_stat(sockets, NLMSG_DATA(h));
 310
 311                        h = NLMSG_NEXT(h, ret);
 312                }
 313        }
 314
 315done:
 316        close(fd);
 317}
 318
 319static void free_sock_stat(struct list_head *sockets)
 320{
 321        struct vsock_stat *st;
 322        struct vsock_stat *next;
 323
 324        list_for_each_entry_safe(st, next, sockets, list)
 325                free(st);
 326}
 327
 328static void test_no_sockets(const struct test_opts *opts)
 329{
 330        LIST_HEAD(sockets);
 331
 332        read_vsock_stat(&sockets);
 333
 334        check_no_sockets(&sockets);
 335}
 336
 337static void test_listen_socket_server(const struct test_opts *opts)
 338{
 339        union {
 340                struct sockaddr sa;
 341                struct sockaddr_vm svm;
 342        } addr = {
 343                .svm = {
 344                        .svm_family = AF_VSOCK,
 345                        .svm_port = 1234,
 346                        .svm_cid = VMADDR_CID_ANY,
 347                },
 348        };
 349        LIST_HEAD(sockets);
 350        struct vsock_stat *st;
 351        int fd;
 352
 353        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 354
 355        if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 356                perror("bind");
 357                exit(EXIT_FAILURE);
 358        }
 359
 360        if (listen(fd, 1) < 0) {
 361                perror("listen");
 362                exit(EXIT_FAILURE);
 363        }
 364
 365        read_vsock_stat(&sockets);
 366
 367        check_num_sockets(&sockets, 1);
 368        st = find_vsock_stat(&sockets, fd);
 369        check_socket_state(st, TCP_LISTEN);
 370
 371        close(fd);
 372        free_sock_stat(&sockets);
 373}
 374
 375static void test_connect_client(const struct test_opts *opts)
 376{
 377        int fd;
 378        LIST_HEAD(sockets);
 379        struct vsock_stat *st;
 380
 381        fd = vsock_stream_connect(opts->peer_cid, 1234);
 382        if (fd < 0) {
 383                perror("connect");
 384                exit(EXIT_FAILURE);
 385        }
 386
 387        read_vsock_stat(&sockets);
 388
 389        check_num_sockets(&sockets, 1);
 390        st = find_vsock_stat(&sockets, fd);
 391        check_socket_state(st, TCP_ESTABLISHED);
 392
 393        control_expectln("DONE");
 394        control_writeln("DONE");
 395
 396        close(fd);
 397        free_sock_stat(&sockets);
 398}
 399
 400static void test_connect_server(const struct test_opts *opts)
 401{
 402        struct vsock_stat *st;
 403        LIST_HEAD(sockets);
 404        int client_fd;
 405
 406        client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
 407        if (client_fd < 0) {
 408                perror("accept");
 409                exit(EXIT_FAILURE);
 410        }
 411
 412        read_vsock_stat(&sockets);
 413
 414        check_num_sockets(&sockets, 1);
 415        st = find_vsock_stat(&sockets, client_fd);
 416        check_socket_state(st, TCP_ESTABLISHED);
 417
 418        control_writeln("DONE");
 419        control_expectln("DONE");
 420
 421        close(client_fd);
 422        free_sock_stat(&sockets);
 423}
 424
 425static struct test_case test_cases[] = {
 426        {
 427                .name = "No sockets",
 428                .run_server = test_no_sockets,
 429        },
 430        {
 431                .name = "Listen socket",
 432                .run_server = test_listen_socket_server,
 433        },
 434        {
 435                .name = "Connect",
 436                .run_client = test_connect_client,
 437                .run_server = test_connect_server,
 438        },
 439        {},
 440};
 441
 442static const char optstring[] = "";
 443static const struct option longopts[] = {
 444        {
 445                .name = "control-host",
 446                .has_arg = required_argument,
 447                .val = 'H',
 448        },
 449        {
 450                .name = "control-port",
 451                .has_arg = required_argument,
 452                .val = 'P',
 453        },
 454        {
 455                .name = "mode",
 456                .has_arg = required_argument,
 457                .val = 'm',
 458        },
 459        {
 460                .name = "peer-cid",
 461                .has_arg = required_argument,
 462                .val = 'p',
 463        },
 464        {
 465                .name = "list",
 466                .has_arg = no_argument,
 467                .val = 'l',
 468        },
 469        {
 470                .name = "skip",
 471                .has_arg = required_argument,
 472                .val = 's',
 473        },
 474        {
 475                .name = "help",
 476                .has_arg = no_argument,
 477                .val = '?',
 478        },
 479        {},
 480};
 481
 482static void usage(void)
 483{
 484        fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
 485                "\n"
 486                "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
 487                "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
 488                "\n"
 489                "Run vsock_diag.ko tests.  Must be launched in both\n"
 490                "guest and host.  One side must use --mode=client and\n"
 491                "the other side must use --mode=server.\n"
 492                "\n"
 493                "A TCP control socket connection is used to coordinate tests\n"
 494                "between the client and the server.  The server requires a\n"
 495                "listen address and the client requires an address to\n"
 496                "connect to.\n"
 497                "\n"
 498                "The CID of the other side must be given with --peer-cid=<cid>.\n"
 499                "\n"
 500                "Options:\n"
 501                "  --help                 This help message\n"
 502                "  --control-host <host>  Server IP address to connect to\n"
 503                "  --control-port <port>  Server port to listen on/connect to\n"
 504                "  --mode client|server   Server or client mode\n"
 505                "  --peer-cid <cid>       CID of the other side\n"
 506                "  --list                 List of tests that will be executed\n"
 507                "  --skip <test_id>       Test ID to skip;\n"
 508                "                         use multiple --skip options to skip more tests\n"
 509                );
 510        exit(EXIT_FAILURE);
 511}
 512
 513int main(int argc, char **argv)
 514{
 515        const char *control_host = NULL;
 516        const char *control_port = NULL;
 517        struct test_opts opts = {
 518                .mode = TEST_MODE_UNSET,
 519                .peer_cid = VMADDR_CID_ANY,
 520        };
 521
 522        init_signals();
 523
 524        for (;;) {
 525                int opt = getopt_long(argc, argv, optstring, longopts, NULL);
 526
 527                if (opt == -1)
 528                        break;
 529
 530                switch (opt) {
 531                case 'H':
 532                        control_host = optarg;
 533                        break;
 534                case 'm':
 535                        if (strcmp(optarg, "client") == 0)
 536                                opts.mode = TEST_MODE_CLIENT;
 537                        else if (strcmp(optarg, "server") == 0)
 538                                opts.mode = TEST_MODE_SERVER;
 539                        else {
 540                                fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
 541                                return EXIT_FAILURE;
 542                        }
 543                        break;
 544                case 'p':
 545                        opts.peer_cid = parse_cid(optarg);
 546                        break;
 547                case 'P':
 548                        control_port = optarg;
 549                        break;
 550                case 'l':
 551                        list_tests(test_cases);
 552                        break;
 553                case 's':
 554                        skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
 555                                  optarg);
 556                        break;
 557                case '?':
 558                default:
 559                        usage();
 560                }
 561        }
 562
 563        if (!control_port)
 564                usage();
 565        if (opts.mode == TEST_MODE_UNSET)
 566                usage();
 567        if (opts.peer_cid == VMADDR_CID_ANY)
 568                usage();
 569
 570        if (!control_host) {
 571                if (opts.mode != TEST_MODE_SERVER)
 572                        usage();
 573                control_host = "0.0.0.0";
 574        }
 575
 576        control_init(control_host, control_port,
 577                     opts.mode == TEST_MODE_SERVER);
 578
 579        run_tests(test_cases, &opts);
 580
 581        control_cleanup();
 582        return EXIT_SUCCESS;
 583}
 584