1
#include <error.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <pthread.h>
10
11#include <linux/filter.h>
12#include <bpf/bpf.h>
13#include <bpf/libbpf.h>
14
15#include "bpf_rlimit.h"
16#include "bpf_util.h"
17#include "cgroup_helpers.h"
18
/* cgroup path passed to create_and_get_cgroup() and join_cgroup() in main(). */
#define CG_PATH "/tcp_rtt"

/*
 * Per-socket counters read back from the BPF map by verify_sk().
 * NOTE(review): this layout presumably must match the map value type used by
 * tcp_rtt.o — keep field order and types in sync with the BPF program.
 */
struct tcp_rtt_storage {
	__u32 invoked;		/* presumably a count of program invocations; run_test() expects 1 after the handshake, 2 after one byte */
	__u32 dsack_dups;	/* run_test() expects 0 at both checkpoints */
	__u32 delivered;	/* run_test() expects 1 after the handshake, 2 after one byte */
	__u32 delivered_ce;	/* run_test() expects 0 at both checkpoints */
	__u32 icsk_retransmits;	/* run_test() expects 0 at both checkpoints */
};
28
/*
 * Write a single 0x55 byte to @fd. A failed or short write terminates the
 * process via error(3).
 */
static void send_byte(int fd)
{
	const char payload = 0x55;
	ssize_t written = write(fd, &payload, sizeof(payload));

	if (written == 1)
		return;

	error(1, errno, "Failed to send single byte");
}
36
37static int verify_sk(int map_fd, int client_fd, const char *msg, __u32 invoked,
38 __u32 dsack_dups, __u32 delivered, __u32 delivered_ce,
39 __u32 icsk_retransmits)
40{
41 int err = 0;
42 struct tcp_rtt_storage val;
43
44 if (bpf_map_lookup_elem(map_fd, &client_fd, &val) < 0)
45 error(1, errno, "Failed to read socket storage");
46
47 if (val.invoked != invoked) {
48 log_err("%s: unexpected bpf_tcp_sock.invoked %d != %d",
49 msg, val.invoked, invoked);
50 err++;
51 }
52
53 if (val.dsack_dups != dsack_dups) {
54 log_err("%s: unexpected bpf_tcp_sock.dsack_dups %d != %d",
55 msg, val.dsack_dups, dsack_dups);
56 err++;
57 }
58
59 if (val.delivered != delivered) {
60 log_err("%s: unexpected bpf_tcp_sock.delivered %d != %d",
61 msg, val.delivered, delivered);
62 err++;
63 }
64
65 if (val.delivered_ce != delivered_ce) {
66 log_err("%s: unexpected bpf_tcp_sock.delivered_ce %d != %d",
67 msg, val.delivered_ce, delivered_ce);
68 err++;
69 }
70
71 if (val.icsk_retransmits != icsk_retransmits) {
72 log_err("%s: unexpected bpf_tcp_sock.icsk_retransmits %d != %d",
73 msg, val.icsk_retransmits, icsk_retransmits);
74 err++;
75 }
76
77 return err;
78}
79
/*
 * Open an IPv4 TCP socket and connect it to the address that @server_fd is
 * bound to. Returns the connected fd. On any failure, logs the problem,
 * closes the client socket and returns -1.
 */
static int connect_to_server(int server_fd)
{
	struct sockaddr_storage srv_addr;
	socklen_t srv_len = sizeof(srv_addr);
	int sock = socket(AF_INET, SOCK_STREAM, 0);

	if (sock < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (getsockname(server_fd, (struct sockaddr *)&srv_addr, &srv_len) != 0)
		log_err("Failed to get server addr");
	else if (connect(sock, (const struct sockaddr *)&srv_addr, srv_len) < 0)
		log_err("Fail to connect to server");
	else
		return sock;

	close(sock);
	return -1;
}
108
109static int run_test(int cgroup_fd, int server_fd)
110{
111 struct bpf_prog_load_attr attr = {
112 .prog_type = BPF_PROG_TYPE_SOCK_OPS,
113 .file = "./tcp_rtt.o",
114 .expected_attach_type = BPF_CGROUP_SOCK_OPS,
115 };
116 struct bpf_object *obj;
117 struct bpf_map *map;
118 int client_fd;
119 int prog_fd;
120 int map_fd;
121 int err;
122
123 err = bpf_prog_load_xattr(&attr, &obj, &prog_fd);
124 if (err) {
125 log_err("Failed to load BPF object");
126 return -1;
127 }
128
129 map = bpf_map__next(NULL, obj);
130 map_fd = bpf_map__fd(map);
131
132 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
133 if (err) {
134 log_err("Failed to attach BPF program");
135 goto close_bpf_object;
136 }
137
138 client_fd = connect_to_server(server_fd);
139 if (client_fd < 0) {
140 err = -1;
141 goto close_bpf_object;
142 }
143
144 err += verify_sk(map_fd, client_fd, "syn-ack",
145 1,
146 0,
147 1,
148 0,
149 0);
150
151 send_byte(client_fd);
152
153 err += verify_sk(map_fd, client_fd, "first payload byte",
154 2,
155 0,
156 2,
157 0,
158 0);
159
160 close(client_fd);
161
162close_bpf_object:
163 bpf_object__close(obj);
164 return err;
165}
166
/*
 * Create a TCP socket bound to 127.0.0.1 and put it into listening state.
 * sin_port is left zero, so the kernel picks a port.
 *
 * listen() is called here, before the fd is handed to the server thread.
 * Otherwise the client's connect() in the main thread can race the
 * thread's own listen() and fail with ECONNREFUSED. Calling listen() again
 * on an already-listening socket is harmless.
 *
 * Returns the listening fd, or -1 (after logging) on failure.
 */
static int start_server(void)
{
	struct sockaddr_in addr = {
		.sin_family = AF_INET,
		.sin_addr.s_addr = htonl(INADDR_LOOPBACK),
	};
	int fd;

	fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0) {
		log_err("Failed to create server socket");
		return -1;
	}

	if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) {
		log_err("Failed to bind socket");
		goto close_fd;
	}

	if (listen(fd, 1) < 0) {
		log_err("Failed to listen on socket");
		goto close_fd;
	}

	return fd;

close_fd:
	close(fd);
	return -1;
}
189
/*
 * Server side of the test, run as a pthread start routine.
 * @arg points to the bound server socket (an int).
 *
 * Accepts one client, then blocks in a second accept() that is not
 * expected to succeed. This keeps the accepted connection open for the
 * duration of the test. main() does not join this thread.
 *
 * Any failure terminates the process via error(3). Returns NULL once the
 * second accept() fails.
 */
static void *server_thread(void *arg)
{
	struct sockaddr_storage addr;
	socklen_t len = sizeof(addr);
	int fd = *(int *)arg;
	int client_fd;

	if (listen(fd, 1) < 0)
		error(1, errno, "Failed to listen on socket");

	client_fd = accept(fd, (struct sockaddr *)&addr, &len);
	if (client_fd < 0)
		error(1, errno, "Failed to accept client");

	/*
	 * Wait for a next connection that never arrives. This keeps the
	 * thread alive so client_fd is not closed while the test still
	 * inspects the connection.
	 */
	if (accept(fd, (struct sockaddr *)&addr, &len) >= 0)
		/* accept() succeeded, so errno is stale: don't report it. */
		error(1, 0, "Unexpected success in second accept");

	close(client_fd);

	return NULL;
}
215
216int main(int args, char **argv)
217{
218 int server_fd, cgroup_fd;
219 int err = EXIT_SUCCESS;
220 pthread_t tid;
221
222 if (setup_cgroup_environment())
223 goto cleanup_obj;
224
225 cgroup_fd = create_and_get_cgroup(CG_PATH);
226 if (cgroup_fd < 0)
227 goto cleanup_cgroup_env;
228
229 if (join_cgroup(CG_PATH))
230 goto cleanup_cgroup;
231
232 server_fd = start_server();
233 if (server_fd < 0) {
234 err = EXIT_FAILURE;
235 goto cleanup_cgroup;
236 }
237
238 pthread_create(&tid, NULL, server_thread, (void *)&server_fd);
239
240 if (run_test(cgroup_fd, server_fd))
241 err = EXIT_FAILURE;
242
243 close(server_fd);
244
245 printf("test_sockopt_sk: %s\n",
246 err == EXIT_SUCCESS ? "PASSED" : "FAILED");
247
248cleanup_cgroup:
249 close(cgroup_fd);
250cleanup_cgroup_env:
251 cleanup_cgroup_environment();
252cleanup_obj:
253 return err;
254}
255