1
2#include <error.h>
3#include <errno.h>
4#include <stdio.h>
5#include <unistd.h>
6#include <sys/types.h>
7#include <sys/socket.h>
8#include <netinet/in.h>
9#include <netinet/tcp.h>
10#include <pthread.h>
11
12#include <linux/filter.h>
13#include <bpf/bpf.h>
14#include <bpf/libbpf.h>
15
16#include "bpf_rlimit.h"
17#include "bpf_util.h"
18#include "cgroup_helpers.h"
19
/* Cgroup path that main() creates and joins before running the test. */
#define CG_PATH "/tcp_rtt"
21
/*
 * Value read back from the BPF map by verify_sk(), keyed by the client
 * socket fd. Each field is compared against an expected value there.
 *
 * NOTE(review): presumably this layout must stay in sync with the value
 * type used by the BPF program in tcp_rtt.o -- confirm before reordering.
 */
struct tcp_rtt_storage {
	__u32 invoked;
	__u32 dsack_dups;
	__u32 delivered;
	__u32 delivered_ce;
	__u32 icsk_retransmits;
};
29
/*
 * Write a single 0x55 byte to fd.
 *
 * Terminates the process with exit status 1 (via error()) unless exactly
 * one byte was written.
 */
static void send_byte(int fd)
{
	const char payload = 0x55;
	ssize_t written = write(fd, &payload, sizeof(payload));

	if (written != 1)
		error(1, errno, "Failed to send single byte");
}
37
/*
 * Poll TCP_INFO on fd until the socket reports no unacknowledged segments.
 *
 * @fd:      connected TCP socket to query
 * @retries: maximum number of polls; the function sleeps 10 microseconds
 *           between consecutive polls
 *
 * Returns 0 as soon as tcpi_unacked is zero, the negative getsockopt()
 * result if the query fails, or -1 once all retries are used up.
 */
static int wait_for_ack(int fd, int retries)
{
	int attempt = 0;

	while (attempt < retries) {
		struct tcp_info ti;
		socklen_t ti_len = sizeof(ti);
		int rc = getsockopt(fd, SOL_TCP, TCP_INFO, &ti, &ti_len);

		if (rc < 0) {
			log_err("Failed to lookup TCP stats");
			return rc;
		}

		if (!ti.tcpi_unacked)
			return 0;

		usleep(10);
		attempt++;
	}

	log_err("Did not receive ACK");
	return -1;
}
61
62static int verify_sk(int map_fd, int client_fd, const char *msg, __u32 invoked,
63 __u32 dsack_dups, __u32 delivered, __u32 delivered_ce,
64 __u32 icsk_retransmits)
65{
66 int err = 0;
67 struct tcp_rtt_storage val;
68
69 if (bpf_map_lookup_elem(map_fd, &client_fd, &val) < 0)
70 error(1, errno, "Failed to read socket storage");
71
72 if (val.invoked != invoked) {
73 log_err("%s: unexpected bpf_tcp_sock.invoked %d != %d",
74 msg, val.invoked, invoked);
75 err++;
76 }
77
78 if (val.dsack_dups != dsack_dups) {
79 log_err("%s: unexpected bpf_tcp_sock.dsack_dups %d != %d",
80 msg, val.dsack_dups, dsack_dups);
81 err++;
82 }
83
84 if (val.delivered != delivered) {
85 log_err("%s: unexpected bpf_tcp_sock.delivered %d != %d",
86 msg, val.delivered, delivered);
87 err++;
88 }
89
90 if (val.delivered_ce != delivered_ce) {
91 log_err("%s: unexpected bpf_tcp_sock.delivered_ce %d != %d",
92 msg, val.delivered_ce, delivered_ce);
93 err++;
94 }
95
96 if (val.icsk_retransmits != icsk_retransmits) {
97 log_err("%s: unexpected bpf_tcp_sock.icsk_retransmits %d != %d",
98 msg, val.icsk_retransmits, icsk_retransmits);
99 err++;
100 }
101
102 return err;
103}
104
/*
 * Open a TCP client socket and connect it to the address server_fd is
 * bound to (obtained with getsockname()).
 *
 * Returns the connected socket fd, or -1 after logging the failure. The
 * client socket is closed again on every failure path after its creation.
 */
static int connect_to_server(int server_fd)
{
	struct sockaddr_storage srv_addr;
	socklen_t srv_len = sizeof(srv_addr);
	int sock = socket(AF_INET, SOCK_STREAM, 0);

	if (sock < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (getsockname(server_fd, (struct sockaddr *)&srv_addr, &srv_len) != 0) {
		log_err("Failed to get server addr");
		close(sock);
		return -1;
	}

	if (connect(sock, (const struct sockaddr *)&srv_addr, srv_len) < 0) {
		log_err("Fail to connect to server");
		close(sock);
		return -1;
	}

	return sock;
}
133
134static int run_test(int cgroup_fd, int server_fd)
135{
136 struct bpf_prog_load_attr attr = {
137 .prog_type = BPF_PROG_TYPE_SOCK_OPS,
138 .file = "./tcp_rtt.o",
139 .expected_attach_type = BPF_CGROUP_SOCK_OPS,
140 };
141 struct bpf_object *obj;
142 struct bpf_map *map;
143 int client_fd;
144 int prog_fd;
145 int map_fd;
146 int err;
147
148 err = bpf_prog_load_xattr(&attr, &obj, &prog_fd);
149 if (err) {
150 log_err("Failed to load BPF object");
151 return -1;
152 }
153
154 map = bpf_map__next(NULL, obj);
155 map_fd = bpf_map__fd(map);
156
157 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
158 if (err) {
159 log_err("Failed to attach BPF program");
160 goto close_bpf_object;
161 }
162
163 client_fd = connect_to_server(server_fd);
164 if (client_fd < 0) {
165 err = -1;
166 goto close_bpf_object;
167 }
168
169 err += verify_sk(map_fd, client_fd, "syn-ack",
170 1,
171 0,
172 1,
173 0,
174 0);
175
176 send_byte(client_fd);
177 if (wait_for_ack(client_fd, 100) < 0) {
178 err = -1;
179 goto close_client_fd;
180 }
181
182
183 err += verify_sk(map_fd, client_fd, "first payload byte",
184 2,
185 0,
186 2,
187 0,
188 0);
189
190close_client_fd:
191 close(client_fd);
192
193close_bpf_object:
194 bpf_object__close(obj);
195 return err;
196}
197
/*
 * Create a TCP socket bound to the IPv4 loopback address. sin_port is left
 * at zero in the bind address; connect_to_server() later recovers the
 * actual address with getsockname(). listen() is not called here.
 *
 * Returns the bound socket fd, or -1 after logging the failure.
 */
static int start_server(void)
{
	struct sockaddr_in bind_addr = {
		.sin_family = AF_INET,
		.sin_addr.s_addr = htonl(INADDR_LOOPBACK),
	};
	int sock = socket(AF_INET, SOCK_STREAM, 0);

	if (sock < 0) {
		log_err("Failed to create server socket");
		return -1;
	}

	if (bind(sock, (const struct sockaddr *)&bind_addr, sizeof(bind_addr)) >= 0)
		return sock;

	log_err("Failed to bind socket");
	close(sock);
	return -1;
}
220
/*
 * Condition variable (and the mutex protecting it) that server_thread()
 * signals once it has called listen(); main() waits on it before starting
 * the test.
 */
static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER;
223
224static void *server_thread(void *arg)
225{
226 struct sockaddr_storage addr;
227 socklen_t len = sizeof(addr);
228 int fd = *(int *)arg;
229 int client_fd;
230 int err;
231
232 err = listen(fd, 1);
233
234 pthread_mutex_lock(&server_started_mtx);
235 pthread_cond_signal(&server_started);
236 pthread_mutex_unlock(&server_started_mtx);
237
238 if (err < 0)
239 error(1, errno, "Failed to listed on socket");
240
241 client_fd = accept(fd, (struct sockaddr *)&addr, &len);
242 if (client_fd < 0)
243 error(1, errno, "Failed to accept client");
244
245
246
247
248
249 if (accept(fd, (struct sockaddr *)&addr, &len) >= 0)
250 error(1, errno, "Unexpected success in second accept");
251
252 close(client_fd);
253
254 return NULL;
255}
256
257int main(int args, char **argv)
258{
259 int server_fd, cgroup_fd;
260 int err = EXIT_SUCCESS;
261 pthread_t tid;
262
263 if (setup_cgroup_environment())
264 goto cleanup_obj;
265
266 cgroup_fd = create_and_get_cgroup(CG_PATH);
267 if (cgroup_fd < 0)
268 goto cleanup_cgroup_env;
269
270 if (join_cgroup(CG_PATH))
271 goto cleanup_cgroup;
272
273 server_fd = start_server();
274 if (server_fd < 0) {
275 err = EXIT_FAILURE;
276 goto cleanup_cgroup;
277 }
278
279 if ((pthread_create(&tid, NULL, server_thread,
280 (void *)&server_fd))) {
281 err = EXIT_FAILURE;
282 goto close_server_fd;
283 }
284
285 pthread_mutex_lock(&server_started_mtx);
286 pthread_cond_wait(&server_started, &server_started_mtx);
287 pthread_mutex_unlock(&server_started_mtx);
288
289 if (run_test(cgroup_fd, server_fd))
290 err = EXIT_FAILURE;
291
292 printf("test_sockopt_sk: %s\n",
293 err == EXIT_SUCCESS ? "PASSED" : "FAILED");
294
295close_server_fd:
296 close(server_fd);
297cleanup_cgroup:
298 close(cgroup_fd);
299cleanup_cgroup_env:
300 cleanup_cgroup_environment();
301cleanup_obj:
302 return err;
303}
304