1
2
3#define _GNU_SOURCE
4
5#include <errno.h>
6#include <limits.h>
7#include <fcntl.h>
8#include <string.h>
9#include <stdbool.h>
10#include <stdint.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <strings.h>
14#include <signal.h>
15#include <unistd.h>
16
17#include <sys/poll.h>
18#include <sys/sendfile.h>
19#include <sys/stat.h>
20#include <sys/socket.h>
21#include <sys/types.h>
22#include <sys/mman.h>
23
24#include <netdb.h>
25#include <netinet/in.h>
26
27#include <linux/tcp.h>
28
/* getopt(3) index of the first non-option argument. */
extern int optind;

/* Fallback values for libc/kernel headers that predate MPTCP support. */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef TCP_ULP
#define TCP_ULP 31
#endif

/* poll() timeout in milliseconds; -t sets it, -1 means no timeout. */
static int poll_timeout = 10 * 1000;
/* -l: act as the listening side. */
static bool listen_mode;
/*
 * Written by the SIGUSR1 handler. An object written from a signal handler
 * must be volatile sig_atomic_t for the access to be well defined.
 */
static volatile sig_atomic_t quit;

/* Transfer strategy selected with -m. */
enum cfg_mode {
	CFG_MODE_POLL,
	CFG_MODE_MMAP,
	CFG_MODE_SENDFILE,
};

static enum cfg_mode cfg_mode = CFG_MODE_POLL;
static const char *cfg_host;
static const char *cfg_port = "12000";
/* Socket protocol (-s): MPTCP by default, or plain TCP. */
static int cfg_sock_proto = IPPROTO_MPTCP;
/* -u: only run the TCP_ULP audit. */
static bool tcpulp_audit;
/* Address family; -6 or a ':' in the address switches to AF_INET6. */
static int pf = AF_INET;
static int cfg_sndbuf;
static int cfg_rcvbuf;
/* -j: join test mode. */
static bool cfg_join;
/* Delay in microseconds (usleep() argument) used before shutdown/close. */
static int cfg_wait;
58
/*
 * Print the command-line synopsis and per-option help to stderr, then
 * exit(1). Keep this in sync with the getopt() string in parse_opts():
 * -s selects the socket protocol and -m selects the transfer mode.
 */
static void die_usage(void)
{
	fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode] "
		"[-l] [-j] [-t sec] [-w sec] [-S num] [-R num] connect_address\n");
	fprintf(stderr, "\t-6 use ipv6\n");
	fprintf(stderr, "\t-t num -- set poll timeout to num seconds\n");
	fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
	fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
	fprintf(stderr, "\t-p num -- use port num\n");
	fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp (default) or tcp sockets\n");
	fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll (default), mmap or sendfile\n");
	fprintf(stderr, "\t-l -- listen mode, accept one incoming connection\n");
	fprintf(stderr, "\t-j -- join test: force poll mode, slow first write, delay close\n");
	fprintf(stderr, "\t-u -- check mptcp ulp\n");
	fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
	fprintf(stderr, "\t-h -- print this help\n");
	exit(1);
}
74
75static void handle_signal(int nr)
76{
77 quit = true;
78}
79
80static const char *getxinfo_strerr(int err)
81{
82 if (err == EAI_SYSTEM)
83 return strerror(errno);
84
85 return gai_strerror(err);
86}
87
88static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
89 char *host, socklen_t hostlen,
90 char *serv, socklen_t servlen)
91{
92 int flags = NI_NUMERICHOST | NI_NUMERICSERV;
93 int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
94 flags);
95
96 if (err) {
97 const char *errstr = getxinfo_strerr(err);
98
99 fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
100 exit(1);
101 }
102}
103
/*
 * getaddrinfo() wrapper: on success *res holds the result list (caller
 * frees it with freeaddrinfo()); on failure the process exits with a message
 * naming the node/service that could not be resolved.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int rc = getaddrinfo(node, service, hints, res);

	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(rc));
	exit(1);
}
118
/* Set SO_RCVBUF on @fd to @size bytes; exits if the kernel rejects it. */
static void set_rcvbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_RCVBUF");
	exit(1);
}
129
/* Set SO_SNDBUF on @fd to @size bytes; exits if the kernel rejects it. */
static void set_sndbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_SNDBUF");
	exit(1);
}
140
/*
 * Create a socket of protocol cfg_sock_proto, bind it to the numeric
 * address @listenaddr:@port (family from the global pf) and start listening.
 * SO_REUSEADDR is set best-effort before bind().
 *
 * Returns the listening fd, or -1 on failure (after printing a message).
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	/* -1 until a socket is successfully bound; also covers an empty list */
	int sock = -1;
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int one = 1;

	hints.ai_family = pf;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
		if (sock < 0) {
			perror("socket");
			continue;
		}

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break;

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0) {
		fprintf(stderr, "Could not create listen socket\n");
		return sock;
	}

	if (listen(sock, 20)) {
		perror("listen");
		close(sock);
		return -1;
	}

	return sock;
}
191
/*
 * -u audit: attaching the "mptcp" ULP to a plain TCP socket must be refused.
 * Tries each resolved IPv4 address for @remoteaddr:@port and returns true as
 * soon as setsockopt(TCP_ULP, "mptcp") fails with EOPNOTSUPP. Any other
 * outcome, including success, is reported on stderr.
 */
static bool sock_test_tcpulp(const char * const remoteaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1, ret = 0;
	bool test_pass = false;

	hints.ai_family = AF_INET;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		int saved_errno;

		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
		if (sock < 0) {
			perror("socket");
			continue;
		}
		ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
				 sizeof("mptcp"));
		/* close() below may overwrite errno before it is reported */
		saved_errno = errno;
		if (ret == -1 && saved_errno == EOPNOTSUPP)
			test_pass = true;
		close(sock);

		if (test_pass)
			break;
		if (!ret) {
			fprintf(stderr,
				"setsockopt(TCP_ULP) returned 0\n");
		} else {
			errno = saved_errno;
			perror("setsockopt(TCP_ULP)");
		}
	}

	freeaddrinfo(addr);
	return test_pass;
}
228
/*
 * Resolve @remoteaddr:@port (family from the global pf) and connect a
 * stream socket of protocol @proto to the first address that accepts.
 * Returns the connected fd, or -1 if every candidate failed.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *res, *cur;
	int fd = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &res);

	for (cur = res; cur != NULL; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0) {
			perror("socket");
			continue;
		}

		if (connect(fd, cur->ai_addr, cur->ai_addrlen) == 0)
			break;

		perror("connect()");
		close(fd);
		fd = -1;
	}

	freeaddrinfo(res);
	return fd;
}
260
261static size_t do_rnd_write(const int fd, char *buf, const size_t len)
262{
263 static bool first = true;
264 unsigned int do_w;
265 ssize_t bw;
266
267 do_w = rand() & 0xffff;
268 if (do_w == 0 || do_w > len)
269 do_w = len;
270
271 if (cfg_join && first && do_w > 100)
272 do_w = 100;
273
274 bw = write(fd, buf, do_w);
275 if (bw < 0)
276 perror("write");
277
278
279 if (cfg_join && first) {
280 usleep(200000);
281 first = false;
282 }
283
284 return bw;
285}
286
/*
 * Write all @len bytes of @buf to @fd, retrying after short writes.
 * Returns @len on success, or 0 as soon as write() fails (after perror()).
 */
static size_t do_write(const int fd, char *buf, const size_t len)
{
	size_t done;

	for (done = 0; done < len; ) {
		ssize_t n = write(fd, buf + done, len - done);

		if (n < 0) {
			perror("write");
			return 0;
		}
		done += (size_t)n;
	}

	return done;
}
307
/*
 * read() a random amount from @fd into @buf: between 1 and 65535 bytes,
 * capped at @len. Returns read()'s result.
 */
static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
{
	size_t cap = rand();

	cap &= 0xffff;

	if (cap == 0)
		cap = 1;
	/*
	 * Clamp last: when @len is 0 the "at least one byte" bump above must
	 * not turn into a 1-byte read past the end of @buf.
	 */
	if (cap > len)
		cap = len;

	return read(fd, buf, cap);
}
321
/* Best-effort: add O_NONBLOCK to @fd's status flags; errors are ignored. */
static void set_nonblock(int fd)
{
	int fl = fcntl(fd, F_GETFL);

	if (fl != -1)
		fcntl(fd, F_SETFL, fl | O_NONBLOCK);
}
331
332static int copyfd_io_poll(int infd, int peerfd, int outfd)
333{
334 struct pollfd fds = {
335 .fd = peerfd,
336 .events = POLLIN | POLLOUT,
337 };
338 unsigned int woff = 0, wlen = 0;
339 char wbuf[8192];
340
341 set_nonblock(peerfd);
342
343 for (;;) {
344 char rbuf[8192];
345 ssize_t len;
346
347 if (fds.events == 0)
348 break;
349
350 switch (poll(&fds, 1, poll_timeout)) {
351 case -1:
352 if (errno == EINTR)
353 continue;
354 perror("poll");
355 return 1;
356 case 0:
357 fprintf(stderr, "%s: poll timed out (events: "
358 "POLLIN %u, POLLOUT %u)\n", __func__,
359 fds.events & POLLIN, fds.events & POLLOUT);
360 return 2;
361 }
362
363 if (fds.revents & POLLIN) {
364 len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
365 if (len == 0) {
366
367
368
369 fds.events &= ~POLLIN;
370
371 if ((fds.events & POLLOUT) == 0)
372
373 break;
374
375
376 } else if (len < 0) {
377 perror("read");
378 return 3;
379 }
380
381 do_write(outfd, rbuf, len);
382 }
383
384 if (fds.revents & POLLOUT) {
385 if (wlen == 0) {
386 woff = 0;
387 wlen = read(infd, wbuf, sizeof(wbuf));
388 }
389
390 if (wlen > 0) {
391 ssize_t bw;
392
393 bw = do_rnd_write(peerfd, wbuf + woff, wlen);
394 if (bw < 0)
395 return 111;
396
397 woff += bw;
398 wlen -= bw;
399 } else if (wlen == 0) {
400
401 fds.events &= ~POLLOUT;
402
403 if ((fds.events & POLLIN) == 0)
404
405 break;
406
407
408
409
410
411
412 if (cfg_wait)
413 usleep(cfg_wait);
414 shutdown(peerfd, SHUT_WR);
415 } else {
416 if (errno == EINTR)
417 continue;
418 perror("read");
419 return 4;
420 }
421 }
422
423 if (fds.revents & (POLLERR | POLLNVAL)) {
424 fprintf(stderr, "Unexpected revents: "
425 "POLLERR/POLLNVAL(%x)\n", fds.revents);
426 return 5;
427 }
428 }
429
430
431 if (cfg_join)
432 usleep(cfg_wait);
433
434 close(peerfd);
435 return 0;
436}
437
/*
 * Copy everything from @infd to @outfd in random-sized chunks until EOF,
 * stopping early if a write to @outfd is short.
 *
 * Returns 0 at EOF, -1 if read() failed (reported with perror()), or the
 * positive size of the chunk whose write came up short.
 */
static int do_recvfile(int infd, int outfd)
{
	char chunk[16384];
	ssize_t got;

	for (;;) {
		got = do_rnd_read(infd, chunk, sizeof(chunk));
		if (got < 0) {
			perror("read");
			break;
		}
		if (got == 0)
			break;
		if (write(outfd, chunk, got) != got)
			break;
	}

	return (int)got;
}
456
/*
 * Map the first @size bytes of @infd read-only and write all of them to
 * @outfd.
 *
 * Returns 0 when everything was written, or 1 if mmap() or a write()
 * failed. An empty input (@size == 0) is trivially complete and is never
 * mapped, because mmap() rejects a zero length. Failures are always reported
 * as 1 rather than the leftover byte count: that count ends up as the
 * process exit status, where e.g. 256 would truncate to 0 ("success").
 */
static int do_mmap(int infd, int outfd, unsigned int size)
{
	char *inbuf;
	size_t off = 0;
	size_t rem = size;

	if (size == 0)
		return 0;

	inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
	if (inbuf == MAP_FAILED) {
		perror("mmap");
		return 1;
	}

	while (rem > 0) {
		ssize_t ret = write(outfd, inbuf + off, rem);

		if (ret < 0) {
			perror("write");
			break;
		}

		off += ret;
		rem -= ret;
	}

	munmap(inbuf, size);
	return rem ? 1 : 0;
}
485
/*
 * Return the size of the regular file open on @fd.
 *
 * Returns -1 if fstat() fails, -2 if @fd is not a regular file, and -3 if
 * the size does not fit in an int.
 */
static int get_infd_size(int fd)
{
	struct stat sb;

	if (fstat(fd, &sb) < 0) {
		perror("fstat");
		return -1;
	}

	if (!S_ISREG(sb.st_mode)) {
		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
		return -2;
	}

	/* st_size is a signed off_t: compare and print it without narrowing */
	if (sb.st_size > INT_MAX) {
		fprintf(stderr, "File too large: %lld\n", (long long)sb.st_size);
		return -3;
	}

	return (int)sb.st_size;
}
511
/*
 * Send @count bytes from @infd's current file offset to @outfd with
 * sendfile(), looping over partial transfers.
 *
 * Returns 0 on success, or 3 if sendfile() fails or @infd reaches EOF
 * before @count bytes were sent.
 */
static int do_sendfile(int infd, int outfd, unsigned int count)
{
	while (count > 0) {
		ssize_t r;

		r = sendfile(outfd, infd, NULL, count);
		if (r < 0) {
			perror("sendfile");
			return 3;
		}

		/* sendfile() keeps returning 0 at EOF: stop instead of spinning */
		if (r == 0) {
			fprintf(stderr, "sendfile: unexpected end of input, %u bytes left\n",
				count);
			return 3;
		}

		count -= r;
	}

	return 0;
}
528
529static int copyfd_io_mmap(int infd, int peerfd, int outfd,
530 unsigned int size)
531{
532 int err;
533
534 if (listen_mode) {
535 err = do_recvfile(peerfd, outfd);
536 if (err)
537 return err;
538
539 err = do_mmap(infd, peerfd, size);
540 } else {
541 err = do_mmap(infd, peerfd, size);
542 if (err)
543 return err;
544
545 shutdown(peerfd, SHUT_WR);
546
547 err = do_recvfile(peerfd, outfd);
548 }
549
550 return err;
551}
552
553static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
554 unsigned int size)
555{
556 int err;
557
558 if (listen_mode) {
559 err = do_recvfile(peerfd, outfd);
560 if (err)
561 return err;
562
563 err = do_sendfile(infd, peerfd, size);
564 } else {
565 err = do_sendfile(infd, peerfd, size);
566 if (err)
567 return err;
568 err = do_recvfile(peerfd, outfd);
569 }
570
571 return err;
572}
573
574static int copyfd_io(int infd, int peerfd, int outfd)
575{
576 int file_size;
577
578 switch (cfg_mode) {
579 case CFG_MODE_POLL:
580 return copyfd_io_poll(infd, peerfd, outfd);
581 case CFG_MODE_MMAP:
582 file_size = get_infd_size(infd);
583 if (file_size < 0)
584 return file_size;
585 return copyfd_io_mmap(infd, peerfd, outfd, file_size);
586 case CFG_MODE_SENDFILE:
587 file_size = get_infd_size(infd);
588 if (file_size < 0)
589 return file_size;
590 return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
591 }
592
593 fprintf(stderr, "Invalid mode %d\n", cfg_mode);
594
595 die_usage();
596 return 1;
597}
598
/*
 * Sanity-check the peer address filled in by accept(). It must be of family
 * @pf, have the matching sockaddr size and carry a non-zero port.
 * Problems are only reported on stderr; nothing is returned.
 */
static void check_sockaddr(int pf, struct sockaddr_storage *ss,
			   socklen_t salen)
{
	struct sockaddr_in6 *sin6;
	struct sockaddr_in *sin;
	socklen_t wanted_size = 0;

	switch (pf) {
	case AF_INET:
		wanted_size = sizeof(*sin);
		sin = (void *)ss;
		if (!sin->sin_port)
			fprintf(stderr, "accept: something wrong: ip connection from port 0\n");
		break;
	case AF_INET6:
		wanted_size = sizeof(*sin6);
		sin6 = (void *)ss;
		if (!sin6->sin6_port)
			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0\n");
		break;
	default:
		fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf,
			(unsigned int)salen);
		return;
	}

	if (salen != wanted_size)
		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
			(int)salen, (int)wanted_size);

	/* expected family first, then the one actually received */
	if (ss->ss_family != pf)
		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
			pf, (int)ss->ss_family);
}
632
/*
 * Verify that getpeername() on the accepted socket @fd reports the same
 * address (@ss, @salen) that accept() returned. Mismatches are printed
 * with numeric host/port strings.
 */
static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
{
	struct sockaddr_storage peer;
	socklen_t peerlen = sizeof(peer);
	char acc_host[INET6_ADDRSTRLEN], acc_serv[INET6_ADDRSTRLEN];
	char peer_host[INET6_ADDRSTRLEN], peer_serv[INET6_ADDRSTRLEN];

	if (getpeername(fd, (struct sockaddr *)&peer, &peerlen) < 0) {
		perror("getpeername");
		return;
	}

	if (peerlen != salen) {
		fprintf(stderr, "%s: %d vs %d\n", __func__, peerlen, salen);
		return;
	}

	if (memcmp(ss, &peer, peerlen) == 0)
		return;

	xgetnameinfo((struct sockaddr *)ss, salen,
		     acc_host, sizeof(acc_host), acc_serv, sizeof(acc_serv));
	xgetnameinfo((struct sockaddr *)&peer, peerlen,
		     peer_host, sizeof(peer_host), peer_serv, sizeof(peer_serv));

	fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
		__func__, acc_host, peer_host, acc_serv, peer_serv, peerlen, salen);
}
664
665static void check_getpeername_connect(int fd)
666{
667 struct sockaddr_storage ss;
668 socklen_t salen = sizeof(ss);
669 char a[INET6_ADDRSTRLEN];
670 char b[INET6_ADDRSTRLEN];
671
672 if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
673 perror("getpeername");
674 return;
675 }
676
677 xgetnameinfo((struct sockaddr *)&ss, salen,
678 a, sizeof(a), b, sizeof(b));
679
680 if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
681 fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
682 cfg_host, a, cfg_port, b);
683}
684
685static void maybe_close(int fd)
686{
687 unsigned int r = rand();
688
689 if (!cfg_join && (r & 1))
690 close(fd);
691}
692
693int main_loop_s(int listensock)
694{
695 struct sockaddr_storage ss;
696 struct pollfd polls;
697 socklen_t salen;
698 int remotesock;
699
700 polls.fd = listensock;
701 polls.events = POLLIN;
702
703 switch (poll(&polls, 1, poll_timeout)) {
704 case -1:
705 perror("poll");
706 return 1;
707 case 0:
708 fprintf(stderr, "%s: timed out\n", __func__);
709 close(listensock);
710 return 2;
711 }
712
713 salen = sizeof(ss);
714 remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
715 if (remotesock >= 0) {
716 maybe_close(listensock);
717 check_sockaddr(pf, &ss, salen);
718 check_getpeername(remotesock, &ss, salen);
719
720 return copyfd_io(0, remotesock, 1);
721 }
722
723 perror("accept");
724
725 return 1;
726}
727
/*
 * Seed rand() from /dev/urandom. If the device cannot be opened or fully
 * read, fall back to the process id, so the seed is always a defined value
 * rather than an uninitialized variable.
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
743
744int main_loop(void)
745{
746 int fd;
747
748
749 fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
750 if (fd < 0)
751 return 2;
752
753 check_getpeername_connect(fd);
754
755 if (cfg_rcvbuf)
756 set_rcvbuf(fd, cfg_rcvbuf);
757 if (cfg_sndbuf)
758 set_sndbuf(fd, cfg_sndbuf);
759
760 return copyfd_io(0, fd, 1);
761}
762
/*
 * Translate the -s argument ("MPTCP" or "TCP", case-insensitive) into a
 * socket protocol number. Anything else prints an error and exits through
 * die_usage().
 */
int parse_proto(const char *proto)
{
	if (!strcasecmp(proto, "MPTCP"))
		return IPPROTO_MPTCP;
	if (!strcasecmp(proto, "TCP"))
		return IPPROTO_TCP;

	fprintf(stderr, "Unknown protocol: %s.\n", proto);
	die_usage();

	/* not reached: die_usage() exits */
	return 0;
}
776
777int parse_mode(const char *mode)
778{
779 if (!strcasecmp(mode, "poll"))
780 return CFG_MODE_POLL;
781 if (!strcasecmp(mode, "mmap"))
782 return CFG_MODE_MMAP;
783 if (!strcasecmp(mode, "sendfile"))
784 return CFG_MODE_SENDFILE;
785
786 fprintf(stderr, "Unknown test mode: %s\n", mode);
787 fprintf(stderr, "Supported modes are:\n");
788 fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
789 fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
790 fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
791
792 die_usage();
793
794
795 return 0;
796}
797
/*
 * Parse a -S/-R buffer size. Any base strtoul() accepts with base 0 is
 * allowed (decimal, 0x hex, leading-0 octal). Empty input, trailing
 * characters, and values above INT_MAX print an error and exit through
 * die_usage().
 */
static int parse_int(const char *size)
{
	unsigned long s;
	char *end;

	errno = 0;

	s = strtoul(size, &end, 0);

	if (errno) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(errno));
		die_usage();
	}

	/* strtoul() silently stops at the first non-digit; reject such input */
	if (end == size || *end != '\0') {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(EINVAL));
		die_usage();
	}

	if (s > INT_MAX) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(ERANGE));
		die_usage();
	}

	return (int)s;
}
820
821static void parse_opts(int argc, char **argv)
822{
823 int c;
824
825 while ((c = getopt(argc, argv, "6jlp:s:hut:m:S:R:w:")) != -1) {
826 switch (c) {
827 case 'j':
828 cfg_join = true;
829 cfg_mode = CFG_MODE_POLL;
830 cfg_wait = 400000;
831 break;
832 case 'l':
833 listen_mode = true;
834 break;
835 case 'p':
836 cfg_port = optarg;
837 break;
838 case 's':
839 cfg_sock_proto = parse_proto(optarg);
840 break;
841 case 'h':
842 die_usage();
843 break;
844 case 'u':
845 tcpulp_audit = true;
846 break;
847 case '6':
848 pf = AF_INET6;
849 break;
850 case 't':
851 poll_timeout = atoi(optarg) * 1000;
852 if (poll_timeout <= 0)
853 poll_timeout = -1;
854 break;
855 case 'm':
856 cfg_mode = parse_mode(optarg);
857 break;
858 case 'S':
859 cfg_sndbuf = parse_int(optarg);
860 break;
861 case 'R':
862 cfg_rcvbuf = parse_int(optarg);
863 break;
864 case 'w':
865 cfg_wait = atoi(optarg)*1000000;
866 break;
867 }
868 }
869
870 if (optind + 1 != argc)
871 die_usage();
872 cfg_host = argv[optind];
873
874 if (strchr(cfg_host, ':'))
875 pf = AF_INET6;
876}
877
878int main(int argc, char *argv[])
879{
880 init_rng();
881
882 signal(SIGUSR1, handle_signal);
883 parse_opts(argc, argv);
884
885 if (tcpulp_audit)
886 return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;
887
888 if (listen_mode) {
889 int fd = sock_listen_mptcp(cfg_host, cfg_port);
890
891 if (fd < 0)
892 return 1;
893
894 if (cfg_rcvbuf)
895 set_rcvbuf(fd, cfg_rcvbuf);
896 if (cfg_sndbuf)
897 set_sndbuf(fd, cfg_sndbuf);
898
899 return main_loop_s(fd);
900 }
901
902 return main_loop();
903}
904