1
2
3
4
5
6
7
8
9
10#include <getopt.h>
11#include <stdio.h>
12#include <stdbool.h>
13#include <stdlib.h>
14#include <string.h>
15#include <errno.h>
16#include <unistd.h>
17#include <signal.h>
18#include <sys/socket.h>
19#include <sys/stat.h>
20#include <sys/types.h>
21#include <linux/list.h>
22#include <linux/net.h>
23#include <linux/netlink.h>
24#include <linux/sock_diag.h>
25#include <netinet/tcp.h>
26
27#include "../../../include/uapi/linux/vm_sockets.h"
28#include "../../../include/uapi/linux/vm_sockets_diag.h"
29
30#include "timeout.h"
31#include "control.h"
32
/* Role of this process, chosen with --mode. */
enum test_mode {
	TEST_MODE_UNSET = 0,	/* --mode has not been given */
	TEST_MODE_CLIENT = 1,	/* run each test case's run_client callback */
	TEST_MODE_SERVER = 2,	/* run each test case's run_server callback */
};
38
39
/*
 * One socket entry collected from a sock_diag dump. add_vsock_stat()
 * heap-allocates one of these per SOCK_DIAG_BY_FAMILY reply and appends it
 * to a caller-owned list; free_sock_stat() releases them.
 */
struct vsock_stat {
	struct list_head list;		/* link in the caller's socket list */
	struct vsock_diag_msg msg;	/* copy of the reply payload */
};
44
/*
 * Human-readable name of a vsock socket type, as printed in dumps.
 * Only stream and datagram sockets are recognised.
 */
static const char *sock_type_str(int type)
{
	if (type == SOCK_STREAM)
		return "STREAM";
	if (type == SOCK_DGRAM)
		return "DGRAM";

	return "INVALID TYPE";
}
56
57static const char *sock_state_str(int state)
58{
59 switch (state) {
60 case TCP_CLOSE:
61 return "UNCONNECTED";
62 case TCP_SYN_SENT:
63 return "CONNECTING";
64 case TCP_ESTABLISHED:
65 return "CONNECTED";
66 case TCP_CLOSING:
67 return "DISCONNECTING";
68 case TCP_LISTEN:
69 return "LISTEN";
70 default:
71 return "INVALID STATE";
72 }
73}
74
/*
 * Render the vdiag_shutdown bitmask: value 1 is RCV_SHUTDOWN, 2 is
 * SEND_SHUTDOWN and 3 is both. Anything outside 1..3 prints as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		[1] = "RCV_SHUTDOWN",
		[2] = "SEND_SHUTDOWN",
		[3] = "RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return "0";

	return names[shutdown];
}
88
89static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
90{
91 if (cid == VMADDR_CID_ANY)
92 fprintf(fp, "*:");
93 else
94 fprintf(fp, "%u:", cid);
95
96 if (port == VMADDR_PORT_ANY)
97 fprintf(fp, "*");
98 else
99 fprintf(fp, "%u", port);
100}
101
102static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
103{
104 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
105 fprintf(fp, " ");
106 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
107 fprintf(fp, " %s %s %s %u\n",
108 sock_type_str(st->msg.vdiag_type),
109 sock_state_str(st->msg.vdiag_state),
110 sock_shutdown_str(st->msg.vdiag_shutdown),
111 st->msg.vdiag_ino);
112}
113
114static void print_vsock_stats(FILE *fp, struct list_head *head)
115{
116 struct vsock_stat *st;
117
118 list_for_each_entry(st, head, list)
119 print_vsock_stat(fp, st);
120}
121
122static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
123{
124 struct vsock_stat *st;
125 struct stat stat;
126
127 if (fstat(fd, &stat) < 0) {
128 perror("fstat");
129 exit(EXIT_FAILURE);
130 }
131
132 list_for_each_entry(st, head, list)
133 if (st->msg.vdiag_ino == stat.st_ino)
134 return st;
135
136 fprintf(stderr, "cannot find fd %d\n", fd);
137 exit(EXIT_FAILURE);
138}
139
/*
 * Fail the test unless the socket list @head is empty. On failure the
 * unexpected entries are printed to stderr and the process exits with
 * EXIT_FAILURE, matching the other check_* helpers.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	exit(EXIT_FAILURE);
}
148
149static void check_num_sockets(struct list_head *head, int expected)
150{
151 struct list_head *node;
152 int n = 0;
153
154 list_for_each(node, head)
155 n++;
156
157 if (n != expected) {
158 fprintf(stderr, "expected %d sockets, found %d\n",
159 expected, n);
160 print_vsock_stats(stderr, head);
161 exit(EXIT_FAILURE);
162 }
163}
164
165static void check_socket_state(struct vsock_stat *st, __u8 state)
166{
167 if (st->msg.vdiag_state != state) {
168 fprintf(stderr, "expected socket state %#x, got %#x\n",
169 state, st->msg.vdiag_state);
170 exit(EXIT_FAILURE);
171 }
172}
173
174static void send_req(int fd)
175{
176 struct sockaddr_nl nladdr = {
177 .nl_family = AF_NETLINK,
178 };
179 struct {
180 struct nlmsghdr nlh;
181 struct vsock_diag_req vreq;
182 } req = {
183 .nlh = {
184 .nlmsg_len = sizeof(req),
185 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
186 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
187 },
188 .vreq = {
189 .sdiag_family = AF_VSOCK,
190 .vdiag_states = ~(__u32)0,
191 },
192 };
193 struct iovec iov = {
194 .iov_base = &req,
195 .iov_len = sizeof(req),
196 };
197 struct msghdr msg = {
198 .msg_name = &nladdr,
199 .msg_namelen = sizeof(nladdr),
200 .msg_iov = &iov,
201 .msg_iovlen = 1,
202 };
203
204 for (;;) {
205 if (sendmsg(fd, &msg, 0) < 0) {
206 if (errno == EINTR)
207 continue;
208
209 perror("sendmsg");
210 exit(EXIT_FAILURE);
211 }
212
213 return;
214 }
215}
216
217static ssize_t recv_resp(int fd, void *buf, size_t len)
218{
219 struct sockaddr_nl nladdr = {
220 .nl_family = AF_NETLINK,
221 };
222 struct iovec iov = {
223 .iov_base = buf,
224 .iov_len = len,
225 };
226 struct msghdr msg = {
227 .msg_name = &nladdr,
228 .msg_namelen = sizeof(nladdr),
229 .msg_iov = &iov,
230 .msg_iovlen = 1,
231 };
232 ssize_t ret;
233
234 do {
235 ret = recvmsg(fd, &msg, 0);
236 } while (ret < 0 && errno == EINTR);
237
238 if (ret < 0) {
239 perror("recvmsg");
240 exit(EXIT_FAILURE);
241 }
242
243 return ret;
244}
245
246static void add_vsock_stat(struct list_head *sockets,
247 const struct vsock_diag_msg *resp)
248{
249 struct vsock_stat *st;
250
251 st = malloc(sizeof(*st));
252 if (!st) {
253 perror("malloc");
254 exit(EXIT_FAILURE);
255 }
256
257 st->msg = *resp;
258 list_add_tail(&st->list, sockets);
259}
260
261
262
263
264static void read_vsock_stat(struct list_head *sockets)
265{
266 long buf[8192 / sizeof(long)];
267 int fd;
268
269 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
270 if (fd < 0) {
271 perror("socket");
272 exit(EXIT_FAILURE);
273 }
274
275 send_req(fd);
276
277 for (;;) {
278 const struct nlmsghdr *h;
279 ssize_t ret;
280
281 ret = recv_resp(fd, buf, sizeof(buf));
282 if (ret == 0)
283 goto done;
284 if (ret < sizeof(*h)) {
285 fprintf(stderr, "short read of %zd bytes\n", ret);
286 exit(EXIT_FAILURE);
287 }
288
289 h = (struct nlmsghdr *)buf;
290
291 while (NLMSG_OK(h, ret)) {
292 if (h->nlmsg_type == NLMSG_DONE)
293 goto done;
294
295 if (h->nlmsg_type == NLMSG_ERROR) {
296 const struct nlmsgerr *err = NLMSG_DATA(h);
297
298 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
299 fprintf(stderr, "NLMSG_ERROR\n");
300 else {
301 errno = -err->error;
302 perror("NLMSG_ERROR");
303 }
304
305 exit(EXIT_FAILURE);
306 }
307
308 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
309 fprintf(stderr, "unexpected nlmsg_type %#x\n",
310 h->nlmsg_type);
311 exit(EXIT_FAILURE);
312 }
313 if (h->nlmsg_len <
314 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
315 fprintf(stderr, "short vsock_diag_msg\n");
316 exit(EXIT_FAILURE);
317 }
318
319 add_vsock_stat(sockets, NLMSG_DATA(h));
320
321 h = NLMSG_NEXT(h, ret);
322 }
323 }
324
325done:
326 close(fd);
327}
328
329static void free_sock_stat(struct list_head *sockets)
330{
331 struct vsock_stat *st;
332 struct vsock_stat *next;
333
334 list_for_each_entry_safe(st, next, sockets, list)
335 free(st);
336}
337
338static void test_no_sockets(unsigned int peer_cid)
339{
340 LIST_HEAD(sockets);
341
342 read_vsock_stat(&sockets);
343
344 check_no_sockets(&sockets);
345
346 free_sock_stat(&sockets);
347}
348
349static void test_listen_socket_server(unsigned int peer_cid)
350{
351 union {
352 struct sockaddr sa;
353 struct sockaddr_vm svm;
354 } addr = {
355 .svm = {
356 .svm_family = AF_VSOCK,
357 .svm_port = 1234,
358 .svm_cid = VMADDR_CID_ANY,
359 },
360 };
361 LIST_HEAD(sockets);
362 struct vsock_stat *st;
363 int fd;
364
365 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
366
367 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
368 perror("bind");
369 exit(EXIT_FAILURE);
370 }
371
372 if (listen(fd, 1) < 0) {
373 perror("listen");
374 exit(EXIT_FAILURE);
375 }
376
377 read_vsock_stat(&sockets);
378
379 check_num_sockets(&sockets, 1);
380 st = find_vsock_stat(&sockets, fd);
381 check_socket_state(st, TCP_LISTEN);
382
383 close(fd);
384 free_sock_stat(&sockets);
385}
386
387static void test_connect_client(unsigned int peer_cid)
388{
389 union {
390 struct sockaddr sa;
391 struct sockaddr_vm svm;
392 } addr = {
393 .svm = {
394 .svm_family = AF_VSOCK,
395 .svm_port = 1234,
396 .svm_cid = peer_cid,
397 },
398 };
399 int fd;
400 int ret;
401 LIST_HEAD(sockets);
402 struct vsock_stat *st;
403
404 control_expectln("LISTENING");
405
406 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
407
408 timeout_begin(TIMEOUT);
409 do {
410 ret = connect(fd, &addr.sa, sizeof(addr.svm));
411 timeout_check("connect");
412 } while (ret < 0 && errno == EINTR);
413 timeout_end();
414
415 if (ret < 0) {
416 perror("connect");
417 exit(EXIT_FAILURE);
418 }
419
420 read_vsock_stat(&sockets);
421
422 check_num_sockets(&sockets, 1);
423 st = find_vsock_stat(&sockets, fd);
424 check_socket_state(st, TCP_ESTABLISHED);
425
426 control_expectln("DONE");
427 control_writeln("DONE");
428
429 close(fd);
430 free_sock_stat(&sockets);
431}
432
433static void test_connect_server(unsigned int peer_cid)
434{
435 union {
436 struct sockaddr sa;
437 struct sockaddr_vm svm;
438 } addr = {
439 .svm = {
440 .svm_family = AF_VSOCK,
441 .svm_port = 1234,
442 .svm_cid = VMADDR_CID_ANY,
443 },
444 };
445 union {
446 struct sockaddr sa;
447 struct sockaddr_vm svm;
448 } clientaddr;
449 socklen_t clientaddr_len = sizeof(clientaddr.svm);
450 LIST_HEAD(sockets);
451 struct vsock_stat *st;
452 int fd;
453 int client_fd;
454
455 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
456
457 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
458 perror("bind");
459 exit(EXIT_FAILURE);
460 }
461
462 if (listen(fd, 1) < 0) {
463 perror("listen");
464 exit(EXIT_FAILURE);
465 }
466
467 control_writeln("LISTENING");
468
469 timeout_begin(TIMEOUT);
470 do {
471 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
472 timeout_check("accept");
473 } while (client_fd < 0 && errno == EINTR);
474 timeout_end();
475
476 if (client_fd < 0) {
477 perror("accept");
478 exit(EXIT_FAILURE);
479 }
480 if (clientaddr.sa.sa_family != AF_VSOCK) {
481 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
482 clientaddr.sa.sa_family);
483 exit(EXIT_FAILURE);
484 }
485 if (clientaddr.svm.svm_cid != peer_cid) {
486 fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
487 peer_cid, clientaddr.svm.svm_cid);
488 exit(EXIT_FAILURE);
489 }
490
491 read_vsock_stat(&sockets);
492
493 check_num_sockets(&sockets, 2);
494 find_vsock_stat(&sockets, fd);
495 st = find_vsock_stat(&sockets, client_fd);
496 check_socket_state(st, TCP_ESTABLISHED);
497
498 control_writeln("DONE");
499 control_expectln("DONE");
500
501 close(client_fd);
502 close(fd);
503 free_sock_stat(&sockets);
504}
505
506static struct {
507 const char *name;
508 void (*run_client)(unsigned int peer_cid);
509 void (*run_server)(unsigned int peer_cid);
510} test_cases[] = {
511 {
512 .name = "No sockets",
513 .run_server = test_no_sockets,
514 },
515 {
516 .name = "Listen socket",
517 .run_server = test_listen_socket_server,
518 },
519 {
520 .name = "Connect",
521 .run_client = test_connect_client,
522 .run_server = test_connect_server,
523 },
524 {},
525};
526
527static void init_signals(void)
528{
529 struct sigaction act = {
530 .sa_handler = sigalrm,
531 };
532
533 sigaction(SIGALRM, &act, NULL);
534 signal(SIGPIPE, SIG_IGN);
535}
536
/*
 * Parse a decimal CID given on the command line.
 *
 * The process exits with EXIT_FAILURE, after printing a diagnostic, if
 * @str contains no digits, has trailing characters, overflows unsigned
 * long, or does not fit in an unsigned int (the CID type used throughout
 * this file).
 */
static unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long int n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	/*
	 * endptr == str means nothing was converted (e.g. ""), which strtoul()
	 * reports as a successful 0. The cast round-trip rejects values that
	 * fit in unsigned long but would be truncated to unsigned int.
	 */
	if (errno || endptr == str || *endptr != '\0' ||
	    n != (unsigned int)n) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
550
/* No short options: every option is long-only (see longopts). */
static const char optstring[] = "";

/*
 * Long options parsed by main(). Each .val is the code getopt_long()
 * returns and main() switches on. --help maps to '?', the same code
 * getopt_long() returns for unrecognised options, so both reach usage().
 */
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{ .name = NULL },
};
580
/* Print command-line help to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	fputs("Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
	      "\n"
	      " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
	      " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
	      "\n"
	      "Run vsock_diag.ko tests. Must be launched in both\n"
	      "guest and host. One side must use --mode=client and\n"
	      "the other side must use --mode=server.\n"
	      "\n"
	      "A TCP control socket connection is used to coordinate tests\n"
	      "between the client and the server. The server requires a\n"
	      "listen address and the client requires an address to\n"
	      "connect to.\n"
	      "\n"
	      "The CID of the other side must be given with --peer-cid=<cid>.\n",
	      stderr);
	exit(EXIT_FAILURE);
}
600
601int main(int argc, char **argv)
602{
603 const char *control_host = NULL;
604 const char *control_port = NULL;
605 int mode = TEST_MODE_UNSET;
606 unsigned int peer_cid = VMADDR_CID_ANY;
607 int i;
608
609 init_signals();
610
611 for (;;) {
612 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
613
614 if (opt == -1)
615 break;
616
617 switch (opt) {
618 case 'H':
619 control_host = optarg;
620 break;
621 case 'm':
622 if (strcmp(optarg, "client") == 0)
623 mode = TEST_MODE_CLIENT;
624 else if (strcmp(optarg, "server") == 0)
625 mode = TEST_MODE_SERVER;
626 else {
627 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
628 return EXIT_FAILURE;
629 }
630 break;
631 case 'p':
632 peer_cid = parse_cid(optarg);
633 break;
634 case 'P':
635 control_port = optarg;
636 break;
637 case '?':
638 default:
639 usage();
640 }
641 }
642
643 if (!control_port)
644 usage();
645 if (mode == TEST_MODE_UNSET)
646 usage();
647 if (peer_cid == VMADDR_CID_ANY)
648 usage();
649
650 if (!control_host) {
651 if (mode != TEST_MODE_SERVER)
652 usage();
653 control_host = "0.0.0.0";
654 }
655
656 control_init(control_host, control_port, mode == TEST_MODE_SERVER);
657
658 for (i = 0; test_cases[i].name; i++) {
659 void (*run)(unsigned int peer_cid);
660
661 printf("%s...", test_cases[i].name);
662 fflush(stdout);
663
664 if (mode == TEST_MODE_CLIENT)
665 run = test_cases[i].run_client;
666 else
667 run = test_cases[i].run_server;
668
669 if (run)
670 run(peer_cid);
671
672 printf("ok\n");
673 }
674
675 control_cleanup();
676 return EXIT_SUCCESS;
677}
678