#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/types.h>
#include <linux/bpf.h>
#include <net/lwtunnel.h>
#include <net/gre.h>
#include <net/ip6_route.h>
#include <net/ipv6_stubs.h>

#include <linux/rh_flags.h>

struct bpf_lwt_prog {
	struct bpf_prog *prog;
	char *name;
};

struct bpf_lwt {
	struct bpf_lwt_prog in;
	struct bpf_lwt_prog out;
	struct bpf_lwt_prog xmit;
	int family;
};

#define MAX_PROG_NAME 256

static inline struct bpf_lwt *bpf_lwt_lwtunnel(struct lwtunnel_state *lwt)
{
	return (struct bpf_lwt *)lwt->data;
}

#define NO_REDIRECT false
#define CAN_REDIRECT true

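/* Run the given LWT BPF program on @skb and map its verdict: BPF_OK and
 * BPF_LWT_REROUTE pass through, BPF_REDIRECT is honoured only when
 * @can_redirect is set, and BPF_DROP or an unknown code frees the skb and
 * returns a negative errno.
 */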
static int run_lwt_bpf(struct sk_buff *skb, struct bpf_lwt_prog *lwt,
		       struct dst_entry *dst, bool can_redirect)
{
	int ret;

	/* Migration and bottom halves are disabled to protect the per-CPU
	 * bpf_redirect_info shared between the BPF program and
	 * skb_do_redirect().
	 */
	migrate_disable();
	local_bh_disable();
	bpf_compute_data_pointers(skb);
	ret = bpf_prog_run_save_cb(lwt->prog, skb);

	switch (ret) {
	case BPF_OK:
	case BPF_LWT_REROUTE:
		break;

	case BPF_REDIRECT:
		if (unlikely(!can_redirect)) {
			pr_warn_once("Illegal redirect return code in prog %s\n",
				     lwt->name ? : "<unknown>");
			ret = BPF_OK;
		} else {
			skb_reset_mac_header(skb);
			ret = skb_do_redirect(skb);
			if (ret == 0)
				ret = BPF_REDIRECT;
		}
		break;

	case BPF_DROP:
		kfree_skb(skb);
		ret = -EPERM;
		break;

	default:
		pr_warn_once("bpf-lwt: Illegal return value %u, expect packet loss\n", ret);
		kfree_skb(skb);
		ret = -EINVAL;
		break;
	}

	local_bh_enable();
	migrate_enable();

	return ret;
}

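/* Handle a BPF_LWT_REROUTE verdict on input: drop the current dst and redo
 * the IPv4/IPv6 route input lookup so the (possibly rewritten) packet is
 * routed again before dst_input() is called.
 */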
static int bpf_lwt_input_reroute(struct sk_buff *skb)
{
	int err = -EINVAL;

	if (skb->protocol == htons(ETH_P_IP)) {
		struct net_device *dev = skb_dst(skb)->dev;
		struct iphdr *iph = ip_hdr(skb);

		dev_hold(dev);
		skb_dst_drop(skb);
		err = ip_route_input_noref(skb, iph->daddr, iph->saddr,
					   iph->tos, dev);
		dev_put(dev);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		skb_dst_drop(skb);
		err = ipv6_stub->ipv6_route_input(skb);
	} else {
		err = -EAFNOSUPPORT;
	}

	if (err)
		goto err;
	return dst_input(skb);

err:
	kfree_skb(skb);
	return err;
}

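/* lwtunnel input hook: run the LWT_IN program (redirects not allowed), then
 * hand the skb back to the original dst input handler.
 */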
static int bpf_input(struct sk_buff *skb)
{
	struct dst_entry *dst = skb_dst(skb);
	struct bpf_lwt *bpf;
	int ret;

	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
	if (bpf->in.prog) {
		ret = run_lwt_bpf(skb, &bpf->in, dst, NO_REDIRECT);
		if (ret < 0)
			return ret;
		if (ret == BPF_LWT_REROUTE)
			return bpf_lwt_input_reroute(skb);
	}

	if (unlikely(!dst->lwtstate->orig_input)) {
		kfree_skb(skb);
		return -EINVAL;
	}

	return dst->lwtstate->orig_input(skb);
}

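/* lwtunnel output hook: run the LWT_OUT program (redirects not allowed),
 * then continue with the original dst output handler.
 */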
static int bpf_output(struct net *net, struct sock *sk, struct sk_buff *skb)
{
	struct dst_entry *dst = skb_dst(skb);
	struct bpf_lwt *bpf;
	int ret;

	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
	if (bpf->out.prog) {
		ret = run_lwt_bpf(skb, &bpf->out, dst, NO_REDIRECT);
		if (ret < 0)
			return ret;
	}

	if (unlikely(!dst->lwtstate->orig_output)) {
		pr_warn_once("orig_output not set on dst for prog %s\n",
			     bpf->out.name);
		kfree_skb(skb);
		return -EINVAL;
	}

	return dst->lwtstate->orig_output(net, sk, skb);
}

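/* Make sure the skb has enough headroom for the outgoing device's
 * link-layer header, expanding the head if necessary.
 */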
static int xmit_check_hhlen(struct sk_buff *skb)
{
	int hh_len = skb_dst(skb)->dev->hard_header_len;

	if (skb_headroom(skb) < hh_len) {
		int nhead = HH_DATA_ALIGN(hh_len - skb_headroom(skb));

		if (pskb_expand_head(skb, nhead, 0, GFP_ATOMIC))
			return -ENOMEM;
	}

	return 0;
}

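/* Handle a BPF_LWT_REROUTE verdict on xmit: the program may have pushed a
 * new IP header (see bpf_lwt_push_ip_encap()), so look up a fresh route
 * based on the outer header and transmit via dst_output().
 */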
static int bpf_lwt_xmit_reroute(struct sk_buff *skb)
{
	struct net_device *l3mdev = l3mdev_master_dev_rcu(skb_dst(skb)->dev);
	int oif = l3mdev ? l3mdev->ifindex : 0;
	struct dst_entry *dst = NULL;
	int err = -EAFNOSUPPORT;
	struct sock *sk;
	struct net *net;
	bool ipv4;

	if (skb->protocol == htons(ETH_P_IP))
		ipv4 = true;
	else if (skb->protocol == htons(ETH_P_IPV6))
		ipv4 = false;
	else
		goto err;

	sk = sk_to_full_sk(skb->sk);
	if (sk) {
		if (sk->sk_bound_dev_if)
			oif = sk->sk_bound_dev_if;
		net = sock_net(sk);
	} else {
		net = dev_net(skb_dst(skb)->dev);
	}

	if (ipv4) {
		struct iphdr *iph = ip_hdr(skb);
		struct flowi4 fl4 = {};
		struct rtable *rt;

		fl4.flowi4_oif = oif;
		fl4.flowi4_mark = skb->mark;
		fl4.flowi4_uid = sock_net_uid(net, sk);
		fl4.flowi4_tos = RT_TOS(iph->tos);
		fl4.flowi4_flags = FLOWI_FLAG_ANYSRC;
		fl4.flowi4_proto = iph->protocol;
		fl4.daddr = iph->daddr;
		fl4.saddr = iph->saddr;

		rt = ip_route_output_key(net, &fl4);
		if (IS_ERR(rt)) {
			err = PTR_ERR(rt);
			goto err;
		}
		dst = &rt->dst;
	} else {
		struct ipv6hdr *iph6 = ipv6_hdr(skb);
		struct flowi6 fl6 = {};

		fl6.flowi6_oif = oif;
		fl6.flowi6_mark = skb->mark;
		fl6.flowi6_uid = sock_net_uid(net, sk);
		fl6.flowlabel = ip6_flowinfo(iph6);
		fl6.flowi6_proto = iph6->nexthdr;
		fl6.daddr = iph6->daddr;
		fl6.saddr = iph6->saddr;

		dst = ipv6_stub->ipv6_dst_lookup_flow(net, skb->sk, &fl6, NULL);
		if (IS_ERR(dst)) {
			err = PTR_ERR(dst);
			goto err;
		}
	}
	if (unlikely(dst->error)) {
		err = dst->error;
		dst_release(dst);
		goto err;
	}

	/* Although skb headroom was reserved in bpf_lwt_push_ip_encap(), that
	 * was done for the previous dst; the new dst may need more space.
	 * skb_cow_head() below is a no-op if the headroom is already
	 * sufficient.
	 */
	err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev));
	if (unlikely(err))
		goto err;

	skb_dst_drop(skb);
	skb_dst_set(skb, dst);

	err = dst_output(dev_net(skb_dst(skb)->dev), skb->sk, skb);
	if (unlikely(err))
		return err;

	/* ip[6]_finish_output2 understands LWTUNNEL_XMIT_DONE */
	return LWTUNNEL_XMIT_DONE;

err:
	kfree_skb(skb);
	return err;
}

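/* lwtunnel xmit hook: run the LWT_XMIT program (redirects allowed) and map
 * its verdict onto the LWTUNNEL_XMIT_* return codes.
 */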
static int bpf_xmit(struct sk_buff *skb)
{
	struct dst_entry *dst = skb_dst(skb);
	struct bpf_lwt *bpf;

	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
	if (bpf->xmit.prog) {
		__be16 proto = skb->protocol;
		int ret;

		ret = run_lwt_bpf(skb, &bpf->xmit, dst, CAN_REDIRECT);
		switch (ret) {
		case BPF_OK:
			/* If the header was changed, e.g. via
			 * bpf_lwt_push_encap(), the protocol must not have
			 * changed with it; BPF_LWT_REROUTE must be used when
			 * the protocol changes.
			 */
			if (skb->protocol != proto) {
				kfree_skb(skb);
				return -EINVAL;
			}

			/* If the header was expanded, headroom might be
			 * insufficient for the L2 header to come, expand as
			 * needed.
			 */
			ret = xmit_check_hhlen(skb);
			if (unlikely(ret))
				return ret;

			return LWTUNNEL_XMIT_CONTINUE;
		case BPF_REDIRECT:
			return LWTUNNEL_XMIT_DONE;
		case BPF_LWT_REROUTE:
			return bpf_lwt_xmit_reroute(skb);
		default:
			return ret;
		}
	}

	return LWTUNNEL_XMIT_CONTINUE;
}

static void bpf_lwt_prog_destroy(struct bpf_lwt_prog *prog)
{
	if (prog->prog)
		bpf_prog_put(prog->prog);

	kfree(prog->name);
}

static void bpf_destroy_state(struct lwtunnel_state *lwt)
{
	struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);

	bpf_lwt_prog_destroy(&bpf->in);
	bpf_lwt_prog_destroy(&bpf->out);
	bpf_lwt_prog_destroy(&bpf->xmit);
}

static const struct nla_policy bpf_prog_policy[LWT_BPF_PROG_MAX + 1] = {
	[LWT_BPF_PROG_FD]	= { .type = NLA_U32, },
	[LWT_BPF_PROG_NAME]	= { .type = NLA_NUL_STRING,
				    .len = MAX_PROG_NAME },
};

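/* Parse one nested LWT_BPF_{IN,OUT,XMIT} attribute: duplicate the program
 * name and take a reference on the BPF program identified by the fd.
 */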
static int bpf_parse_prog(struct nlattr *attr, struct bpf_lwt_prog *prog,
			  enum bpf_prog_type type)
{
	struct nlattr *tb[LWT_BPF_PROG_MAX + 1];
	struct bpf_prog *p;
	int ret;
	u32 fd;

	ret = nla_parse_nested_deprecated(tb, LWT_BPF_PROG_MAX, attr,
					  bpf_prog_policy, NULL);
	if (ret < 0)
		return ret;

	if (!tb[LWT_BPF_PROG_FD] || !tb[LWT_BPF_PROG_NAME])
		return -EINVAL;

	prog->name = nla_memdup(tb[LWT_BPF_PROG_NAME], GFP_ATOMIC);
	if (!prog->name)
		return -ENOMEM;

	fd = nla_get_u32(tb[LWT_BPF_PROG_FD]);
	p = bpf_prog_get_type(fd, type);
	if (IS_ERR(p))
		return PTR_ERR(p);

	rh_add_flag("eBPF/lwt");

	prog->prog = p;

	return 0;
}

static const struct nla_policy bpf_nl_policy[LWT_BPF_MAX + 1] = {
	[LWT_BPF_IN]		= { .type = NLA_NESTED, },
	[LWT_BPF_OUT]		= { .type = NLA_NESTED, },
	[LWT_BPF_XMIT]		= { .type = NLA_NESTED, },
	[LWT_BPF_XMIT_HEADROOM]	= { .type = NLA_U32 },
};

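/* Build the lwtunnel state from the netlink configuration: attach the
 * requested in/out/xmit programs and an optional xmit headroom.
 */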
static int bpf_build_state(struct nlattr *nla,
			   unsigned int family, const void *cfg,
			   struct lwtunnel_state **ts,
			   struct netlink_ext_ack *extack)
{
	struct nlattr *tb[LWT_BPF_MAX + 1];
	struct lwtunnel_state *newts;
	struct bpf_lwt *bpf;
	int ret;

	if (family != AF_INET && family != AF_INET6)
		return -EAFNOSUPPORT;

	ret = nla_parse_nested_deprecated(tb, LWT_BPF_MAX, nla, bpf_nl_policy,
					  extack);
	if (ret < 0)
		return ret;

	if (!tb[LWT_BPF_IN] && !tb[LWT_BPF_OUT] && !tb[LWT_BPF_XMIT])
		return -EINVAL;

	newts = lwtunnel_state_alloc(sizeof(*bpf));
	if (!newts)
		return -ENOMEM;

	newts->type = LWTUNNEL_ENCAP_BPF;
	bpf = bpf_lwt_lwtunnel(newts);

	if (tb[LWT_BPF_IN]) {
		newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_IN], &bpf->in,
				     BPF_PROG_TYPE_LWT_IN);
		if (ret < 0)
			goto errout;
	}

	if (tb[LWT_BPF_OUT]) {
		newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_OUT], &bpf->out,
				     BPF_PROG_TYPE_LWT_OUT);
		if (ret < 0)
			goto errout;
	}

	if (tb[LWT_BPF_XMIT]) {
		newts->flags |= LWTUNNEL_STATE_XMIT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_XMIT], &bpf->xmit,
				     BPF_PROG_TYPE_LWT_XMIT);
		if (ret < 0)
			goto errout;
	}

	if (tb[LWT_BPF_XMIT_HEADROOM]) {
		u32 headroom = nla_get_u32(tb[LWT_BPF_XMIT_HEADROOM]);

		if (headroom > LWT_BPF_MAX_HEADROOM) {
			ret = -ERANGE;
			goto errout;
		}

		newts->headroom = headroom;
	}

	bpf->family = family;
	*ts = newts;

	return 0;

errout:
	bpf_destroy_state(newts);
	kfree(newts);
	return ret;
}

static int bpf_fill_lwt_prog(struct sk_buff *skb, int attr,
			     struct bpf_lwt_prog *prog)
{
	struct nlattr *nest;

	if (!prog->prog)
		return 0;

	nest = nla_nest_start_noflag(skb, attr);
	if (!nest)
		return -EMSGSIZE;

	if (prog->name &&
	    nla_put_string(skb, LWT_BPF_PROG_NAME, prog->name))
		return -EMSGSIZE;

	return nla_nest_end(skb, nest);
}

static int bpf_fill_encap_info(struct sk_buff *skb, struct lwtunnel_state *lwt)
{
	struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);

	if (bpf_fill_lwt_prog(skb, LWT_BPF_IN, &bpf->in) < 0 ||
	    bpf_fill_lwt_prog(skb, LWT_BPF_OUT, &bpf->out) < 0 ||
	    bpf_fill_lwt_prog(skb, LWT_BPF_XMIT, &bpf->xmit) < 0)
		return -EMSGSIZE;

	return 0;
}

static int bpf_encap_nlsize(struct lwtunnel_state *lwtstate)
{
	int nest_len = nla_total_size(sizeof(struct nlattr)) +
		       nla_total_size(MAX_PROG_NAME) + /* LWT_BPF_PROG_NAME */
		       0;

	return nest_len + /* LWT_BPF_IN */
	       nest_len + /* LWT_BPF_OUT */
	       nest_len + /* LWT_BPF_XMIT */
	       0;
}

static int bpf_lwt_prog_cmp(struct bpf_lwt_prog *a, struct bpf_lwt_prog *b)
{
	/* FIXME:
	 * The LWT state is currently rebuilt for delete requests, which
	 * results in a new bpf_prog instance, so the programs cannot be
	 * compared directly. Compare by name for now.
	 */
	if (!a->name && !b->name)
		return 0;

	if (!a->name || !b->name)
		return 1;

	return strcmp(a->name, b->name);
}

static int bpf_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b)
{
	struct bpf_lwt *a_bpf = bpf_lwt_lwtunnel(a);
	struct bpf_lwt *b_bpf = bpf_lwt_lwtunnel(b);

	return bpf_lwt_prog_cmp(&a_bpf->in, &b_bpf->in) ||
	       bpf_lwt_prog_cmp(&a_bpf->out, &b_bpf->out) ||
	       bpf_lwt_prog_cmp(&a_bpf->xmit, &b_bpf->xmit);
}

static const struct lwtunnel_encap_ops bpf_encap_ops = {
	.build_state	= bpf_build_state,
	.destroy_state	= bpf_destroy_state,
	.input		= bpf_input,
	.output		= bpf_output,
	.xmit		= bpf_xmit,
	.fill_encap	= bpf_fill_encap_info,
	.get_encap_size	= bpf_encap_nlsize,
	.cmp_encap	= bpf_encap_cmp,
	.owner		= THIS_MODULE,
};

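/* Adjust GSO metadata after an encap header was pushed: mark the skb with
 * the tunnel GSO type (plus SKB_GSO_DODGY), shrink the GSO size by the encap
 * length and force the segment count to be recomputed.
 */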
static int handle_gso_type(struct sk_buff *skb, unsigned int gso_type,
			   int encap_len)
{
	struct skb_shared_info *shinfo = skb_shinfo(skb);

	gso_type |= SKB_GSO_DODGY;
	shinfo->gso_type |= gso_type;
	skb_decrease_gso_size(shinfo, encap_len);
	shinfo->gso_segs = 0;
	return 0;
}

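/* Fix up GSO handling for a TCP GSO skb that just had an outer GRE, UDP or
 * IP-in-IP header pushed in front of it.
 */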
static int handle_gso_encap(struct sk_buff *skb, bool ipv4, int encap_len)
{
	int next_hdr_offset;
	void *next_hdr;
	__u8 protocol;

	/* SCTP and UDP_L4 GSO need more nuanced handling than what
	 * handle_gso_type() does: skb_decrease_gso_size() alone is not
	 * enough. Only TCP GSO packets are let through for now.
	 */
	if (!(skb_shinfo(skb)->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6)))
		return -ENOTSUPP;

	if (ipv4) {
		protocol = ip_hdr(skb)->protocol;
		next_hdr_offset = sizeof(struct iphdr);
		next_hdr = skb_network_header(skb) + next_hdr_offset;
	} else {
		protocol = ipv6_hdr(skb)->nexthdr;
		next_hdr_offset = sizeof(struct ipv6hdr);
		next_hdr = skb_network_header(skb) + next_hdr_offset;
	}

	switch (protocol) {
	case IPPROTO_GRE:
		next_hdr_offset += sizeof(struct gre_base_hdr);
		if (next_hdr_offset > encap_len)
			return -EINVAL;

		if (((struct gre_base_hdr *)next_hdr)->flags & GRE_CSUM)
			return handle_gso_type(skb, SKB_GSO_GRE_CSUM,
					       encap_len);
		return handle_gso_type(skb, SKB_GSO_GRE, encap_len);

	case IPPROTO_UDP:
		next_hdr_offset += sizeof(struct udphdr);
		if (next_hdr_offset > encap_len)
			return -EINVAL;

		if (((struct udphdr *)next_hdr)->check)
			return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL_CSUM,
					       encap_len);
		return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL, encap_len);

	case IPPROTO_IP:
	case IPPROTO_IPV6:
		if (ipv4)
			return handle_gso_type(skb, SKB_GSO_IPXIP4, encap_len);
		else
			return handle_gso_type(skb, SKB_GSO_IPXIP6, encap_len);

	default:
		return -EPROTONOSUPPORT;
	}
}

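/* Push an outer IPv4/IPv6 header in front of the packet on behalf of an LWT
 * BPF program (the bpf_lwt_push_encap() helper): validate the header,
 * reserve headroom, update the skb metadata and, for GSO skbs, adjust the
 * GSO type and size.
 */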
int bpf_lwt_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len, bool ingress)
{
	struct iphdr *iph;
	bool ipv4;
	int err;

	if (unlikely(len < sizeof(struct iphdr) || len > LWT_BPF_MAX_HEADROOM))
		return -EINVAL;

	/* Validate the protocol version and length of the header to push. */
	iph = (struct iphdr *)hdr;
	if (iph->version == 4) {
		ipv4 = true;
		if (unlikely(len < iph->ihl * 4))
			return -EINVAL;
	} else if (iph->version == 6) {
		ipv4 = false;
		if (unlikely(len < sizeof(struct ipv6hdr)))
			return -EINVAL;
	} else {
		return -EINVAL;
	}

	if (ingress)
		err = skb_cow_head(skb, len + skb->mac_len);
	else
		err = skb_cow_head(skb,
				   len + LL_RESERVED_SPACE(skb_dst(skb)->dev));
	if (unlikely(err))
		return err;

	/* Push the encap header and fix the header pointers. */
	skb_reset_inner_headers(skb);
	skb_reset_inner_mac_header(skb);
	skb_set_inner_protocol(skb, skb->protocol);
	skb->encapsulation = 1;
	skb_push(skb, len);
	if (ingress)
		skb_postpush_rcsum(skb, iph, len);
	skb_reset_network_header(skb);
	memcpy(skb_network_header(skb), hdr, len);
	bpf_compute_data_pointers(skb);
	skb_clear_hash(skb);

	if (ipv4) {
		skb->protocol = htons(ETH_P_IP);
		iph = ip_hdr(skb);

		if (!iph->check)
			iph->check = ip_fast_csum((unsigned char *)iph,
						  iph->ihl);
	} else {
		skb->protocol = htons(ETH_P_IPV6);
	}

	if (skb_is_gso(skb))
		return handle_gso_encap(skb, ipv4, len);

	return 0;
}

static int __init bpf_lwt_init(void)
{
	return lwtunnel_encap_add_ops(&bpf_encap_ops, LWTUNNEL_ENCAP_BPF);
}

subsys_initcall(bpf_lwt_init)