linux/tools/testing/vsock/vsock_diag_test.c
<<
>>
Prefs
   1/*
   2 * vsock_diag_test - vsock_diag.ko test suite
   3 *
   4 * Copyright (C) 2017 Red Hat, Inc.
   5 *
   6 * Author: Stefan Hajnoczi <stefanha@redhat.com>
   7 *
   8 * This program is free software; you can redistribute it and/or
   9 * modify it under the terms of the GNU General Public License
  10 * as published by the Free Software Foundation; version 2
  11 * of the License.
  12 */
  13
  14#include <getopt.h>
  15#include <stdio.h>
  16#include <stdbool.h>
  17#include <stdlib.h>
  18#include <string.h>
  19#include <errno.h>
  20#include <unistd.h>
  21#include <signal.h>
  22#include <sys/socket.h>
  23#include <sys/stat.h>
  24#include <sys/types.h>
  25#include <linux/list.h>
  26#include <linux/net.h>
  27#include <linux/netlink.h>
  28#include <linux/sock_diag.h>
  29#include <netinet/tcp.h>
  30
  31#include "../../../include/uapi/linux/vm_sockets.h"
  32#include "../../../include/uapi/linux/vm_sockets_diag.h"
  33
  34#include "timeout.h"
  35#include "control.h"
  36
  37enum test_mode {
  38        TEST_MODE_UNSET,
  39        TEST_MODE_CLIENT,
  40        TEST_MODE_SERVER
  41};
  42
  43/* Per-socket status */
  44struct vsock_stat {
  45        struct list_head list;
  46        struct vsock_diag_msg msg;
  47};
  48
  49static const char *sock_type_str(int type)
  50{
  51        switch (type) {
  52        case SOCK_DGRAM:
  53                return "DGRAM";
  54        case SOCK_STREAM:
  55                return "STREAM";
  56        default:
  57                return "INVALID TYPE";
  58        }
  59}
  60
  61static const char *sock_state_str(int state)
  62{
  63        switch (state) {
  64        case TCP_CLOSE:
  65                return "UNCONNECTED";
  66        case TCP_SYN_SENT:
  67                return "CONNECTING";
  68        case TCP_ESTABLISHED:
  69                return "CONNECTED";
  70        case TCP_CLOSING:
  71                return "DISCONNECTING";
  72        case TCP_LISTEN:
  73                return "LISTEN";
  74        default:
  75                return "INVALID STATE";
  76        }
  77}
  78
  79static const char *sock_shutdown_str(int shutdown)
  80{
  81        switch (shutdown) {
  82        case 1:
  83                return "RCV_SHUTDOWN";
  84        case 2:
  85                return "SEND_SHUTDOWN";
  86        case 3:
  87                return "RCV_SHUTDOWN | SEND_SHUTDOWN";
  88        default:
  89                return "0";
  90        }
  91}
  92
  93static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
  94{
  95        if (cid == VMADDR_CID_ANY)
  96                fprintf(fp, "*:");
  97        else
  98                fprintf(fp, "%u:", cid);
  99
 100        if (port == VMADDR_PORT_ANY)
 101                fprintf(fp, "*");
 102        else
 103                fprintf(fp, "%u", port);
 104}
 105
 106static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
 107{
 108        print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
 109        fprintf(fp, " ");
 110        print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
 111        fprintf(fp, " %s %s %s %u\n",
 112                sock_type_str(st->msg.vdiag_type),
 113                sock_state_str(st->msg.vdiag_state),
 114                sock_shutdown_str(st->msg.vdiag_shutdown),
 115                st->msg.vdiag_ino);
 116}
 117
 118static void print_vsock_stats(FILE *fp, struct list_head *head)
 119{
 120        struct vsock_stat *st;
 121
 122        list_for_each_entry(st, head, list)
 123                print_vsock_stat(fp, st);
 124}
 125
 126static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
 127{
 128        struct vsock_stat *st;
 129        struct stat stat;
 130
 131        if (fstat(fd, &stat) < 0) {
 132                perror("fstat");
 133                exit(EXIT_FAILURE);
 134        }
 135
 136        list_for_each_entry(st, head, list)
 137                if (st->msg.vdiag_ino == stat.st_ino)
 138                        return st;
 139
 140        fprintf(stderr, "cannot find fd %d\n", fd);
 141        exit(EXIT_FAILURE);
 142}
 143
 144static void check_no_sockets(struct list_head *head)
 145{
 146        if (!list_empty(head)) {
 147                fprintf(stderr, "expected no sockets\n");
 148                print_vsock_stats(stderr, head);
 149                exit(1);
 150        }
 151}
 152
 153static void check_num_sockets(struct list_head *head, int expected)
 154{
 155        struct list_head *node;
 156        int n = 0;
 157
 158        list_for_each(node, head)
 159                n++;
 160
 161        if (n != expected) {
 162                fprintf(stderr, "expected %d sockets, found %d\n",
 163                        expected, n);
 164                print_vsock_stats(stderr, head);
 165                exit(EXIT_FAILURE);
 166        }
 167}
 168
 169static void check_socket_state(struct vsock_stat *st, __u8 state)
 170{
 171        if (st->msg.vdiag_state != state) {
 172                fprintf(stderr, "expected socket state %#x, got %#x\n",
 173                        state, st->msg.vdiag_state);
 174                exit(EXIT_FAILURE);
 175        }
 176}
 177
 178static void send_req(int fd)
 179{
 180        struct sockaddr_nl nladdr = {
 181                .nl_family = AF_NETLINK,
 182        };
 183        struct {
 184                struct nlmsghdr nlh;
 185                struct vsock_diag_req vreq;
 186        } req = {
 187                .nlh = {
 188                        .nlmsg_len = sizeof(req),
 189                        .nlmsg_type = SOCK_DIAG_BY_FAMILY,
 190                        .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
 191                },
 192                .vreq = {
 193                        .sdiag_family = AF_VSOCK,
 194                        .vdiag_states = ~(__u32)0,
 195                },
 196        };
 197        struct iovec iov = {
 198                .iov_base = &req,
 199                .iov_len = sizeof(req),
 200        };
 201        struct msghdr msg = {
 202                .msg_name = &nladdr,
 203                .msg_namelen = sizeof(nladdr),
 204                .msg_iov = &iov,
 205                .msg_iovlen = 1,
 206        };
 207
 208        for (;;) {
 209                if (sendmsg(fd, &msg, 0) < 0) {
 210                        if (errno == EINTR)
 211                                continue;
 212
 213                        perror("sendmsg");
 214                        exit(EXIT_FAILURE);
 215                }
 216
 217                return;
 218        }
 219}
 220
 221static ssize_t recv_resp(int fd, void *buf, size_t len)
 222{
 223        struct sockaddr_nl nladdr = {
 224                .nl_family = AF_NETLINK,
 225        };
 226        struct iovec iov = {
 227                .iov_base = buf,
 228                .iov_len = len,
 229        };
 230        struct msghdr msg = {
 231                .msg_name = &nladdr,
 232                .msg_namelen = sizeof(nladdr),
 233                .msg_iov = &iov,
 234                .msg_iovlen = 1,
 235        };
 236        ssize_t ret;
 237
 238        do {
 239                ret = recvmsg(fd, &msg, 0);
 240        } while (ret < 0 && errno == EINTR);
 241
 242        if (ret < 0) {
 243                perror("recvmsg");
 244                exit(EXIT_FAILURE);
 245        }
 246
 247        return ret;
 248}
 249
 250static void add_vsock_stat(struct list_head *sockets,
 251                           const struct vsock_diag_msg *resp)
 252{
 253        struct vsock_stat *st;
 254
 255        st = malloc(sizeof(*st));
 256        if (!st) {
 257                perror("malloc");
 258                exit(EXIT_FAILURE);
 259        }
 260
 261        st->msg = *resp;
 262        list_add_tail(&st->list, sockets);
 263}
 264
 265/*
 266 * Read vsock stats into a list.
 267 */
 268static void read_vsock_stat(struct list_head *sockets)
 269{
 270        long buf[8192 / sizeof(long)];
 271        int fd;
 272
 273        fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
 274        if (fd < 0) {
 275                perror("socket");
 276                exit(EXIT_FAILURE);
 277        }
 278
 279        send_req(fd);
 280
 281        for (;;) {
 282                const struct nlmsghdr *h;
 283                ssize_t ret;
 284
 285                ret = recv_resp(fd, buf, sizeof(buf));
 286                if (ret == 0)
 287                        goto done;
 288                if (ret < sizeof(*h)) {
 289                        fprintf(stderr, "short read of %zd bytes\n", ret);
 290                        exit(EXIT_FAILURE);
 291                }
 292
 293                h = (struct nlmsghdr *)buf;
 294
 295                while (NLMSG_OK(h, ret)) {
 296                        if (h->nlmsg_type == NLMSG_DONE)
 297                                goto done;
 298
 299                        if (h->nlmsg_type == NLMSG_ERROR) {
 300                                const struct nlmsgerr *err = NLMSG_DATA(h);
 301
 302                                if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
 303                                        fprintf(stderr, "NLMSG_ERROR\n");
 304                                else {
 305                                        errno = -err->error;
 306                                        perror("NLMSG_ERROR");
 307                                }
 308
 309                                exit(EXIT_FAILURE);
 310                        }
 311
 312                        if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
 313                                fprintf(stderr, "unexpected nlmsg_type %#x\n",
 314                                        h->nlmsg_type);
 315                                exit(EXIT_FAILURE);
 316                        }
 317                        if (h->nlmsg_len <
 318                            NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
 319                                fprintf(stderr, "short vsock_diag_msg\n");
 320                                exit(EXIT_FAILURE);
 321                        }
 322
 323                        add_vsock_stat(sockets, NLMSG_DATA(h));
 324
 325                        h = NLMSG_NEXT(h, ret);
 326                }
 327        }
 328
 329done:
 330        close(fd);
 331}
 332
 333static void free_sock_stat(struct list_head *sockets)
 334{
 335        struct vsock_stat *st;
 336        struct vsock_stat *next;
 337
 338        list_for_each_entry_safe(st, next, sockets, list)
 339                free(st);
 340}
 341
 342static void test_no_sockets(unsigned int peer_cid)
 343{
 344        LIST_HEAD(sockets);
 345
 346        read_vsock_stat(&sockets);
 347
 348        check_no_sockets(&sockets);
 349
 350        free_sock_stat(&sockets);
 351}
 352
 353static void test_listen_socket_server(unsigned int peer_cid)
 354{
 355        union {
 356                struct sockaddr sa;
 357                struct sockaddr_vm svm;
 358        } addr = {
 359                .svm = {
 360                        .svm_family = AF_VSOCK,
 361                        .svm_port = 1234,
 362                        .svm_cid = VMADDR_CID_ANY,
 363                },
 364        };
 365        LIST_HEAD(sockets);
 366        struct vsock_stat *st;
 367        int fd;
 368
 369        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 370
 371        if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 372                perror("bind");
 373                exit(EXIT_FAILURE);
 374        }
 375
 376        if (listen(fd, 1) < 0) {
 377                perror("listen");
 378                exit(EXIT_FAILURE);
 379        }
 380
 381        read_vsock_stat(&sockets);
 382
 383        check_num_sockets(&sockets, 1);
 384        st = find_vsock_stat(&sockets, fd);
 385        check_socket_state(st, TCP_LISTEN);
 386
 387        close(fd);
 388        free_sock_stat(&sockets);
 389}
 390
 391static void test_connect_client(unsigned int peer_cid)
 392{
 393        union {
 394                struct sockaddr sa;
 395                struct sockaddr_vm svm;
 396        } addr = {
 397                .svm = {
 398                        .svm_family = AF_VSOCK,
 399                        .svm_port = 1234,
 400                        .svm_cid = peer_cid,
 401                },
 402        };
 403        int fd;
 404        int ret;
 405        LIST_HEAD(sockets);
 406        struct vsock_stat *st;
 407
 408        control_expectln("LISTENING");
 409
 410        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 411
 412        timeout_begin(TIMEOUT);
 413        do {
 414                ret = connect(fd, &addr.sa, sizeof(addr.svm));
 415                timeout_check("connect");
 416        } while (ret < 0 && errno == EINTR);
 417        timeout_end();
 418
 419        if (ret < 0) {
 420                perror("connect");
 421                exit(EXIT_FAILURE);
 422        }
 423
 424        read_vsock_stat(&sockets);
 425
 426        check_num_sockets(&sockets, 1);
 427        st = find_vsock_stat(&sockets, fd);
 428        check_socket_state(st, TCP_ESTABLISHED);
 429
 430        control_expectln("DONE");
 431        control_writeln("DONE");
 432
 433        close(fd);
 434        free_sock_stat(&sockets);
 435}
 436
 437static void test_connect_server(unsigned int peer_cid)
 438{
 439        union {
 440                struct sockaddr sa;
 441                struct sockaddr_vm svm;
 442        } addr = {
 443                .svm = {
 444                        .svm_family = AF_VSOCK,
 445                        .svm_port = 1234,
 446                        .svm_cid = VMADDR_CID_ANY,
 447                },
 448        };
 449        union {
 450                struct sockaddr sa;
 451                struct sockaddr_vm svm;
 452        } clientaddr;
 453        socklen_t clientaddr_len = sizeof(clientaddr.svm);
 454        LIST_HEAD(sockets);
 455        struct vsock_stat *st;
 456        int fd;
 457        int client_fd;
 458
 459        fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 460
 461        if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 462                perror("bind");
 463                exit(EXIT_FAILURE);
 464        }
 465
 466        if (listen(fd, 1) < 0) {
 467                perror("listen");
 468                exit(EXIT_FAILURE);
 469        }
 470
 471        control_writeln("LISTENING");
 472
 473        timeout_begin(TIMEOUT);
 474        do {
 475                client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
 476                timeout_check("accept");
 477        } while (client_fd < 0 && errno == EINTR);
 478        timeout_end();
 479
 480        if (client_fd < 0) {
 481                perror("accept");
 482                exit(EXIT_FAILURE);
 483        }
 484        if (clientaddr.sa.sa_family != AF_VSOCK) {
 485                fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
 486                        clientaddr.sa.sa_family);
 487                exit(EXIT_FAILURE);
 488        }
 489        if (clientaddr.svm.svm_cid != peer_cid) {
 490                fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
 491                        peer_cid, clientaddr.svm.svm_cid);
 492                exit(EXIT_FAILURE);
 493        }
 494
 495        read_vsock_stat(&sockets);
 496
 497        check_num_sockets(&sockets, 2);
 498        find_vsock_stat(&sockets, fd);
 499        st = find_vsock_stat(&sockets, client_fd);
 500        check_socket_state(st, TCP_ESTABLISHED);
 501
 502        control_writeln("DONE");
 503        control_expectln("DONE");
 504
 505        close(client_fd);
 506        close(fd);
 507        free_sock_stat(&sockets);
 508}
 509
 510static struct {
 511        const char *name;
 512        void (*run_client)(unsigned int peer_cid);
 513        void (*run_server)(unsigned int peer_cid);
 514} test_cases[] = {
 515        {
 516                .name = "No sockets",
 517                .run_server = test_no_sockets,
 518        },
 519        {
 520                .name = "Listen socket",
 521                .run_server = test_listen_socket_server,
 522        },
 523        {
 524                .name = "Connect",
 525                .run_client = test_connect_client,
 526                .run_server = test_connect_server,
 527        },
 528        {},
 529};
 530
 531static void init_signals(void)
 532{
 533        struct sigaction act = {
 534                .sa_handler = sigalrm,
 535        };
 536
 537        sigaction(SIGALRM, &act, NULL);
 538        signal(SIGPIPE, SIG_IGN);
 539}
 540
 541static unsigned int parse_cid(const char *str)
 542{
 543        char *endptr = NULL;
 544        unsigned long int n;
 545
 546        errno = 0;
 547        n = strtoul(str, &endptr, 10);
 548        if (errno || *endptr != '\0') {
 549                fprintf(stderr, "malformed CID \"%s\"\n", str);
 550                exit(EXIT_FAILURE);
 551        }
 552        return n;
 553}
 554
 555static const char optstring[] = "";
 556static const struct option longopts[] = {
 557        {
 558                .name = "control-host",
 559                .has_arg = required_argument,
 560                .val = 'H',
 561        },
 562        {
 563                .name = "control-port",
 564                .has_arg = required_argument,
 565                .val = 'P',
 566        },
 567        {
 568                .name = "mode",
 569                .has_arg = required_argument,
 570                .val = 'm',
 571        },
 572        {
 573                .name = "peer-cid",
 574                .has_arg = required_argument,
 575                .val = 'p',
 576        },
 577        {
 578                .name = "help",
 579                .has_arg = no_argument,
 580                .val = '?',
 581        },
 582        {},
 583};
 584
 585static void usage(void)
 586{
 587        fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
 588                "\n"
 589                "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
 590                "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
 591                "\n"
 592                "Run vsock_diag.ko tests.  Must be launched in both\n"
 593                "guest and host.  One side must use --mode=client and\n"
 594                "the other side must use --mode=server.\n"
 595                "\n"
 596                "A TCP control socket connection is used to coordinate tests\n"
 597                "between the client and the server.  The server requires a\n"
 598                "listen address and the client requires an address to\n"
 599                "connect to.\n"
 600                "\n"
 601                "The CID of the other side must be given with --peer-cid=<cid>.\n");
 602        exit(EXIT_FAILURE);
 603}
 604
 605int main(int argc, char **argv)
 606{
 607        const char *control_host = NULL;
 608        const char *control_port = NULL;
 609        int mode = TEST_MODE_UNSET;
 610        unsigned int peer_cid = VMADDR_CID_ANY;
 611        int i;
 612
 613        init_signals();
 614
 615        for (;;) {
 616                int opt = getopt_long(argc, argv, optstring, longopts, NULL);
 617
 618                if (opt == -1)
 619                        break;
 620
 621                switch (opt) {
 622                case 'H':
 623                        control_host = optarg;
 624                        break;
 625                case 'm':
 626                        if (strcmp(optarg, "client") == 0)
 627                                mode = TEST_MODE_CLIENT;
 628                        else if (strcmp(optarg, "server") == 0)
 629                                mode = TEST_MODE_SERVER;
 630                        else {
 631                                fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
 632                                return EXIT_FAILURE;
 633                        }
 634                        break;
 635                case 'p':
 636                        peer_cid = parse_cid(optarg);
 637                        break;
 638                case 'P':
 639                        control_port = optarg;
 640                        break;
 641                case '?':
 642                default:
 643                        usage();
 644                }
 645        }
 646
 647        if (!control_port)
 648                usage();
 649        if (mode == TEST_MODE_UNSET)
 650                usage();
 651        if (peer_cid == VMADDR_CID_ANY)
 652                usage();
 653
 654        if (!control_host) {
 655                if (mode != TEST_MODE_SERVER)
 656                        usage();
 657                control_host = "0.0.0.0";
 658        }
 659
 660        control_init(control_host, control_port, mode == TEST_MODE_SERVER);
 661
 662        for (i = 0; test_cases[i].name; i++) {
 663                void (*run)(unsigned int peer_cid);
 664
 665                printf("%s...", test_cases[i].name);
 666                fflush(stdout);
 667
 668                if (mode == TEST_MODE_CLIENT)
 669                        run = test_cases[i].run_client;
 670                else
 671                        run = test_cases[i].run_server;
 672
 673                if (run)
 674                        run(peer_cid);
 675
 676                printf("ok\n");
 677        }
 678
 679        control_cleanup();
 680        return EXIT_SUCCESS;
 681}
 682