1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31#include <linux/bpf.h>
32#include <net/sock.h>
33#include <linux/filter.h>
34#include <linux/errno.h>
35#include <linux/file.h>
36#include <linux/kernel.h>
37#include <linux/net.h>
38#include <linux/skbuff.h>
39#include <linux/workqueue.h>
40#include <linux/list.h>
41#include <linux/mm.h>
42#include <net/strparser.h>
43#include <net/tcp.h>
44#include <linux/ptr_ring.h>
45#include <net/inet_common.h>
46#include <linux/sched/signal.h>
47
/* Map-creation flags accepted by sock_map_alloc(); any other bit set in
 * attr->map_flags makes the allocation fail with -EINVAL.
 */
#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
/* Sock map instance: the generic bpf_map header followed by the array of
 * socket pointers and three BPF program pointers kept per map.
 */
struct bpf_stab {
	struct bpf_map map;
	/* map.max_entries slots, allocated in sock_map_alloc() */
	struct sock **sock_map;
	struct bpf_prog *bpf_tx_msg;
	struct bpf_prog *bpf_parse;
	struct bpf_prog *bpf_verdict;
};
58
/* Bit numbers used with test_bit()/clear_bit() on smap_psock::state. */
enum smap_psock_state {
	/* Checked before queueing skbs to a peer and before scheduling
	 * tx_work; cleared on a fatal tx error and when the psock is released.
	 */
	SMAP_TX_RUNNING,
};
62
/* Node on smap_psock::maps. 'entry' points at a map slot (a struct sock *
 * cell) that refers to the socket; bpf_tcp_close() clears that slot with
 * cmpxchg() when the socket goes away.
 */
struct smap_psock_map_entry {
	struct list_head list;
	struct sock **entry;
};
67
/* Per-socket sockmap state, stored in the socket's sk_user_data and
 * looked up with smap_psock_sk().
 */
struct smap_psock {
	/* used by call_rcu_sched() in smap_release_sock() */
	struct rcu_head rcu;
	/* starts at 1 in smap_init_psock(); last put triggers teardown */
	refcount_t refcnt;

	/* skbs queued by smap_do_verdict() and drained by smap_tx_work() */
	struct sk_buff_head rxqueue;
	/* set by smap_start_sock(), cleared by smap_stop_sock() */
	bool strp_enabled;

	/* skb and position saved by smap_tx_work() on -EAGAIN so the next
	 * invocation can resume where it stopped
	 */
	int save_rem;
	int save_off;
	struct sk_buff *save_skb;

	/* state for the sendmsg/sendpage (tx_msg) path */
	struct sock *sk_redir;		/* redirect target; holds a sock ref */
	int apply_bytes;		/* copied from md->apply_bytes after the program runs */
	int cork_bytes;			/* bytes still wanted before the verdict runs */
	int sg_size;			/* bytes currently buffered for the verdict */
	int eval;			/* cached enum __sk_action verdict */
	struct sk_msg_buff *cork;	/* heap buffer holding corked data */
	struct list_head ingress;	/* sk_msg_buff list consumed by bpf_tcp_recvmsg() */

	struct strparser strp;
	struct bpf_prog *bpf_tx_msg;
	struct bpf_prog *bpf_parse;
	struct bpf_prog *bpf_verdict;
	/* list of struct smap_psock_map_entry */
	struct list_head maps;

	/* back pointer to the socket; a reference is taken in smap_init_psock() */
	struct sock *sock;
	/* bitmask indexed by enum smap_psock_state */
	unsigned long state;

	struct work_struct tx_work;	/* runs smap_tx_work() */
	struct work_struct gc_work;	/* runs smap_gc_work() */

	/* original proto and callbacks, restored when sockmap detaches */
	struct proto *sk_proto;
	void (*save_close)(struct sock *sk, long timeout);
	void (*save_data_ready)(struct sock *sk);
	void (*save_write_space)(struct sock *sk);
};
108
/* Forward declarations for helpers and proto callbacks defined further
 * down in this file.
 */
static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len);
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
			    int offset, size_t size, int flags);
115
/* Return the psock stored in @sk's user data, or NULL if none is attached.
 * The callers visible in this file invoke it inside an RCU read-side
 * section.
 */
static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
{
	struct smap_psock *psock = rcu_dereference_sk_user_data(sk);

	return psock;
}
120
121static bool bpf_tcp_stream_read(const struct sock *sk)
122{
123 struct smap_psock *psock;
124 bool empty = true;
125
126 rcu_read_lock();
127 psock = smap_psock_sk(sk);
128 if (unlikely(!psock))
129 goto out;
130 empty = list_empty(&psock->ingress);
131out:
132 rcu_read_unlock();
133 return !empty;
134}
135
136static struct proto tcp_bpf_proto;
137static int bpf_tcp_init(struct sock *sk)
138{
139 struct smap_psock *psock;
140
141 rcu_read_lock();
142 psock = smap_psock_sk(sk);
143 if (unlikely(!psock)) {
144 rcu_read_unlock();
145 return -EINVAL;
146 }
147
148 if (unlikely(psock->sk_proto)) {
149 rcu_read_unlock();
150 return -EBUSY;
151 }
152
153 psock->save_close = sk->sk_prot->close;
154 psock->sk_proto = sk->sk_prot;
155
156 if (psock->bpf_tx_msg) {
157 tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
158 tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
159 tcp_bpf_proto.recvmsg = bpf_tcp_recvmsg;
160 tcp_bpf_proto.stream_memory_read = bpf_tcp_stream_read;
161 }
162
163 sk->sk_prot = &tcp_bpf_proto;
164 rcu_read_unlock();
165 return 0;
166}
167
168static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
169static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
170
171static void bpf_tcp_release(struct sock *sk)
172{
173 struct smap_psock *psock;
174
175 rcu_read_lock();
176 psock = smap_psock_sk(sk);
177 if (unlikely(!psock))
178 goto out;
179
180 if (psock->cork) {
181 free_start_sg(psock->sock, psock->cork);
182 kfree(psock->cork);
183 psock->cork = NULL;
184 }
185
186 if (psock->sk_proto) {
187 sk->sk_prot = psock->sk_proto;
188 psock->sk_proto = NULL;
189 }
190out:
191 rcu_read_unlock();
192}
193
194static void bpf_tcp_close(struct sock *sk, long timeout)
195{
196 void (*close_fun)(struct sock *sk, long timeout);
197 struct smap_psock_map_entry *e, *tmp;
198 struct sk_msg_buff *md, *mtmp;
199 struct smap_psock *psock;
200 struct sock *osk;
201
202 rcu_read_lock();
203 psock = smap_psock_sk(sk);
204 if (unlikely(!psock)) {
205 rcu_read_unlock();
206 return sk->sk_prot->close(sk, timeout);
207 }
208
209
210
211
212
213
214 close_fun = psock->save_close;
215
216 write_lock_bh(&sk->sk_callback_lock);
217 if (psock->cork) {
218 free_start_sg(psock->sock, psock->cork);
219 kfree(psock->cork);
220 psock->cork = NULL;
221 }
222
223 list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
224 list_del(&md->list);
225 free_start_sg(psock->sock, md);
226 kfree(md);
227 }
228
229 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
230 osk = cmpxchg(e->entry, sk, NULL);
231 if (osk == sk) {
232 list_del(&e->list);
233 smap_release_sock(psock, sk);
234 }
235 }
236 write_unlock_bh(&sk->sk_callback_lock);
237 rcu_read_unlock();
238 close_fun(sk, timeout);
239}
240
/* Internal verdict codes. bpf_map_msg_verdict() and smap_verdict_func()
 * translate a program's return value into one of these; __SK_NONE is the
 * value psock->eval holds when no verdict is cached.
 */
enum __sk_action {
	__SK_DROP = 0,
	__SK_PASS,
	__SK_REDIRECT,
	__SK_NONE,
};
247
/* ULP descriptor registered by bpf_tcp_ulp_register(). It is marked not
 * user visible and has no owning module.
 */
static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
	.name		= "bpf_tcp",
	.uid		= TCP_ULP_BPF,
	.user_visible	= false,
	.owner		= NULL,
	.init		= bpf_tcp_init,
	.release	= bpf_tcp_release,
};
256
/* Copy user data from @from into the scatterlist ring of @md, starting at
 * element md->sg_curr and byte offset md->sg_copybreak within it.
 *
 * The walk wraps at MAX_SKB_FRAGS and stops at md->sg_end or once @bytes
 * reaches zero. md->sg_curr is updated to the element where the walk
 * stopped.
 *
 * Returns the size of the last chunk copied on success, -EFAULT when a
 * copy comes up short, or -ENOSPC when no element could be written.
 * Callers in this file only test for a negative result.
 */
static int memcopy_from_iter(struct sock *sk,
			     struct sk_msg_buff *md,
			     struct iov_iter *from, int bytes)
{
	struct scatterlist *sg = md->sg_data;
	int i = md->sg_curr, rc = -ENOSPC;

	do {
		int copy;
		char *to;

		/* Current element already full: move on to the next one. */
		if (md->sg_copybreak >= sg[i].length) {
			md->sg_copybreak = 0;

			if (++i == MAX_SKB_FRAGS)
				i = 0;

			if (i == md->sg_end)
				break;
		}

		/* Fill the remainder of this element. */
		copy = sg[i].length - md->sg_copybreak;
		to = sg_virt(&sg[i]) + md->sg_copybreak;
		md->sg_copybreak += copy;

		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			rc = copy_from_iter_nocache(to, copy, from);
		else
			rc = copy_from_iter(to, copy, from);

		if (rc != copy) {
			rc = -EFAULT;
			goto out;
		}

		bytes -= copy;
		if (!bytes)
			break;

		md->sg_copybreak = 0;
		if (++i == MAX_SKB_FRAGS)
			i = 0;
	} while (i != md->sg_end);
out:
	md->sg_curr = i;
	return rc;
}
304
/* Transmit data from @md's scatterlist ring on @sk via do_tcp_sendpages().
 *
 * @apply_bytes: when non-zero, send at most this many bytes; when zero,
 *               send until the ring is empty.
 * @uncharge:    when true, uncharge each sent chunk from @sk's memory
 *               accounting.
 *
 * Elements are consumed from md->sg_start; a fully sent element has its
 * page reference dropped and is re-initialised.
 *
 * Returns 0 when the requested amount was sent, otherwise the
 * non-positive value returned by do_tcp_sendpages().
 */
static int bpf_tcp_push(struct sock *sk, int apply_bytes,
			struct sk_msg_buff *md,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sg;
	int offset, ret = 0;
	struct page *p;
	size_t size;

	while (1) {
		sg = md->sg_data + md->sg_start;
		size = (apply && apply_bytes < sg->length) ?
			apply_bytes : sg->length;
		offset = sg->offset;

		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret != size) {
			/* Partial send: account for what went out and retry
			 * with the remainder of this element.
			 */
			if (ret > 0) {
				if (apply)
					apply_bytes -= ret;

				sg->offset += ret;
				sg->length -= ret;
				size -= ret;
				offset += ret;
				if (uncharge)
					sk_mem_uncharge(sk, ret);
				goto retry;
			}

			return ret;
		}

		if (apply)
			apply_bytes -= ret;
		sg->offset += ret;
		sg->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);

		/* Element drained: release its page and advance the ring. */
		if (!sg->length) {
			put_page(p);
			md->sg_start++;
			if (md->sg_start == MAX_SKB_FRAGS)
				md->sg_start = 0;
			sg_init_table(sg, 1);

			if (md->sg_start == md->sg_end)
				break;
		}

		if (apply && !apply_bytes)
			break;
	}
	return 0;
}
365
366static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
367{
368 struct scatterlist *sg = md->sg_data + md->sg_start;
369
370 if (md->sg_copy[md->sg_start]) {
371 md->data = md->data_end = 0;
372 } else {
373 md->data = sg_virt(sg);
374 md->data_end = md->data + sg->length;
375 }
376}
377
378static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
379{
380 struct scatterlist *sg = md->sg_data;
381 int i = md->sg_start;
382
383 do {
384 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
385
386 sk_mem_uncharge(sk, uncharge);
387 bytes -= uncharge;
388 if (!bytes)
389 break;
390 i++;
391 if (i == MAX_SKB_FRAGS)
392 i = 0;
393 } while (i != md->sg_end);
394}
395
396static void free_bytes_sg(struct sock *sk, int bytes,
397 struct sk_msg_buff *md, bool charge)
398{
399 struct scatterlist *sg = md->sg_data;
400 int i = md->sg_start, free;
401
402 while (bytes && sg[i].length) {
403 free = sg[i].length;
404 if (bytes < free) {
405 sg[i].length -= bytes;
406 sg[i].offset += bytes;
407 if (charge)
408 sk_mem_uncharge(sk, bytes);
409 break;
410 }
411
412 if (charge)
413 sk_mem_uncharge(sk, sg[i].length);
414 put_page(sg_page(&sg[i]));
415 bytes -= sg[i].length;
416 sg[i].length = 0;
417 sg[i].page_link = 0;
418 sg[i].offset = 0;
419 i++;
420
421 if (i == MAX_SKB_FRAGS)
422 i = 0;
423 }
424 md->sg_start = i;
425}
426
427static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
428{
429 struct scatterlist *sg = md->sg_data;
430 int i = start, free = 0;
431
432 while (sg[i].length) {
433 free += sg[i].length;
434 sk_mem_uncharge(sk, sg[i].length);
435 put_page(sg_page(&sg[i]));
436 sg[i].length = 0;
437 sg[i].page_link = 0;
438 sg[i].offset = 0;
439 i++;
440
441 if (i == MAX_SKB_FRAGS)
442 i = 0;
443 }
444
445 return free;
446}
447
448static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
449{
450 int free = free_sg(sk, md->sg_start, md);
451
452 md->sg_start = md->sg_end;
453 return free;
454}
455
456static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
457{
458 return free_sg(sk, md->sg_curr, md);
459}
460
461static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
462{
463 return ((_rc == SK_PASS) ?
464 (md->map ? __SK_REDIRECT : __SK_PASS) :
465 __SK_DROP);
466}
467
/* Run the psock's tx_msg program on @md and return the resulting verdict.
 *
 * The program runs with preemption disabled inside an RCU read-side
 * section. On a redirect verdict the target socket is looked up and
 * stored, with a reference held, in psock->sk_redir; a failed lookup
 * turns the verdict into __SK_DROP.
 */
static unsigned int smap_do_tx_msg(struct sock *sk,
				   struct smap_psock *psock,
				   struct sk_msg_buff *md)
{
	struct bpf_prog *prog;
	unsigned int rc, _rc;

	preempt_disable();
	rcu_read_lock();

	/* No program attached: let the data through.
	 * NOTE(review): this path returns the raw SK_PASS value without going
	 * through bpf_map_msg_verdict(); presumably SK_PASS == __SK_PASS -
	 * confirm.
	 */
	prog = READ_ONCE(psock->bpf_tx_msg);
	if (unlikely(!prog)) {
		_rc = SK_PASS;
		goto verdict;
	}

	bpf_compute_data_pointers_sg(md);
	rc = (*prog->bpf_func)(md, prog->insnsi);
	psock->apply_bytes = md->apply_bytes;

	/* Map the program's return code onto the internal verdict codes. */
	_rc = bpf_map_msg_verdict(rc, md);

	/* On redirect, swap the cached target: drop the reference on the old
	 * one and take a reference on the new one.
	 */
	if (_rc == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = do_msg_redirect_map(md);
		if (!psock->sk_redir) {
			_rc = __SK_DROP;
			goto verdict;
		}
		sock_hold(psock->sk_redir);
	}
verdict:
	rcu_read_unlock();
	preempt_enable();

	return _rc;
}
513
/* Move data from @md onto @psock's ingress list for @sk.
 *
 * A new sk_msg_buff is allocated and scatterlist elements are transferred
 * from @md, starting at md->sg_start. When @apply_bytes is non-zero at
 * most that many bytes are moved. Each moved chunk is charged to @sk.
 * md->sg_start is advanced past what was moved.
 *
 * Takes and releases the socket lock on @sk.
 *
 * Returns 0 on success, or -ENOMEM if the buffer could not be allocated
 * or no memory could be scheduled before anything was moved.
 */
static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
			   struct smap_psock *psock,
			   struct sk_msg_buff *md, int flags)
{
	bool apply = apply_bytes;
	size_t size, copied = 0;
	struct sk_msg_buff *r;
	int err = 0, i;

	r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!r))
		return -ENOMEM;

	lock_sock(sk);
	r->sg_start = md->sg_start;
	i = md->sg_start;

	do {
		size = (apply && apply_bytes < md->sg_data[i].length) ?
			apply_bytes : md->sg_data[i].length;

		/* Running out of memory is only an error if nothing has been
		 * moved yet.
		 */
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				err = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		r->sg_data[i] = md->sg_data[i];
		r->sg_data[i].length = size;
		md->sg_data[i].length -= size;
		md->sg_data[i].offset += size;
		copied += size;

		if (md->sg_data[i].length) {
			/* Element is now shared between md and r: take an
			 * extra page reference.
			 */
			get_page(sg_page(&r->sg_data[i]));
			r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
		} else {
			i++;
			if (i == MAX_SKB_FRAGS)
				i = 0;
			r->sg_end = i;
		}

		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != md->sg_end);

	md->sg_start = i;

	if (!err) {
		list_add_tail(&r->list, &psock->ingress);
		sk->sk_data_ready(sk);
	} else {
		free_start_sg(sk, r);
		kfree(r);
	}

	release_sock(sk);
	return err;
}
578
579static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
580 struct sk_msg_buff *md,
581 int flags)
582{
583 bool ingress = !!(md->flags & BPF_F_INGRESS);
584 struct smap_psock *psock;
585 struct scatterlist *sg;
586 int err = 0;
587
588 sg = md->sg_data;
589
590 rcu_read_lock();
591 psock = smap_psock_sk(sk);
592 if (unlikely(!psock))
593 goto out_rcu;
594
595 if (!refcount_inc_not_zero(&psock->refcnt))
596 goto out_rcu;
597
598 rcu_read_unlock();
599
600 if (ingress) {
601 err = bpf_tcp_ingress(sk, send, psock, md, flags);
602 } else {
603 lock_sock(sk);
604 err = bpf_tcp_push(sk, send, md, flags, false);
605 release_sock(sk);
606 }
607 smap_release_sock(psock, sk);
608 if (unlikely(err))
609 goto out;
610 return 0;
611out_rcu:
612 rcu_read_unlock();
613out:
614 free_bytes_sg(NULL, send, md, false);
615 return err;
616}
617
618static inline void bpf_md_init(struct smap_psock *psock)
619{
620 if (!psock->apply_bytes) {
621 psock->eval = __SK_NONE;
622 if (psock->sk_redir) {
623 sock_put(psock->sk_redir);
624 psock->sk_redir = NULL;
625 }
626 }
627}
628
629static void apply_bytes_dec(struct smap_psock *psock, int i)
630{
631 if (psock->apply_bytes) {
632 if (psock->apply_bytes < i)
633 psock->apply_bytes = 0;
634 else
635 psock->apply_bytes -= i;
636 }
637}
638
/* Evaluate (or reuse) the tx_msg verdict for the data buffered in @m and
 * act on it: push it out on @sk, redirect it to another socket, or drop
 * it. Loops while @m still holds data and no error occurred.
 *
 * @copied is decremented by any bytes that ended up not being delivered.
 * Callers in this file hold the socket lock on @sk; it is dropped and
 * re-taken around a redirect.
 *
 * Returns 0 on success, -EACCES when data was dropped by verdict,
 * -ENOMEM if a cork buffer could not be allocated, or the error from the
 * push/redirect helper.
 */
static int bpf_exec_tx_verdict(struct smap_psock *psock,
			       struct sk_msg_buff *m,
			       struct sock *sk,
			       int *copied, int flags)
{
	/* sg_start == sg_end on entry is treated as "ring is full" */
	bool cork = false, enospc = (m->sg_start == m->sg_end);
	struct sock *redir;
	int err = 0;
	int send;

more_data:
	/* Only run the program when no verdict is cached. */
	if (psock->eval == __SK_NONE)
		psock->eval = smap_do_tx_msg(sk, psock, m);

	/* The program asked for more bytes than are buffered and there is
	 * still room: stash the message in psock->cork and return so the
	 * caller can add more data.
	 */
	if (m->cork_bytes &&
	    m->cork_bytes > psock->sg_size && !enospc) {
		psock->cork_bytes = m->cork_bytes - psock->sg_size;
		if (!psock->cork) {
			psock->cork = kcalloc(1,
					      sizeof(struct sk_msg_buff),
					      GFP_ATOMIC | __GFP_NOWARN);

			if (!psock->cork) {
				err = -ENOMEM;
				goto out_err;
			}
		}
		memcpy(psock->cork, m, sizeof(*m));
		goto out_err;
	}

	/* Act on everything buffered, capped by apply_bytes when set. */
	send = psock->sg_size;
	if (psock->apply_bytes && psock->apply_bytes < send)
		send = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		err = bpf_tcp_push(sk, send, m, flags, true);
		if (unlikely(err)) {
			*copied -= free_start_sg(sk, m);
			break;
		}

		apply_bytes_dec(psock, send);
		psock->sg_size -= send;
		break;
	case __SK_REDIRECT:
		redir = psock->sk_redir;
		apply_bytes_dec(psock, send);

		/* When a cork buffer is attached, m is treated as that heap
		 * buffer: it is detached here and kfree()d below.
		 */
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}

		/* Return the charge and drop our lock before operating on
		 * the target socket.
		 */
		return_mem_sg(sk, send, m);
		release_sock(sk);

		err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
		lock_sock(sk);

		if (unlikely(err < 0)) {
			free_start_sg(sk, m);
			psock->sg_size = 0;
			if (!cork)
				*copied -= send;
		} else {
			psock->sg_size -= send;
		}

		if (cork) {
			free_start_sg(sk, m);
			psock->sg_size = 0;
			kfree(m);
			m = NULL;
			err = 0;
		}
		break;
	case __SK_DROP:
	default:
		free_bytes_sg(sk, send, m, true);
		apply_bytes_dec(psock, send);
		*copied -= send;
		psock->sg_size -= send;
		err = -EACCES;
		break;
	}

	if (likely(!err)) {
		bpf_md_init(psock);
		/* More data left in the ring: go round again. */
		if (m &&
		    m->sg_data[m->sg_start].page_link &&
		    m->sg_data[m->sg_start].length)
			goto more_data;
	}

out_err:
	return err;
}
738
739static int bpf_wait_data(struct sock *sk,
740 struct smap_psock *psk, int flags,
741 long timeo, int *err)
742{
743 int rc;
744
745 DEFINE_WAIT_FUNC(wait, woken_wake_function);
746
747 add_wait_queue(sk_sleep(sk), &wait);
748 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
749 rc = sk_wait_event(sk, &timeo,
750 !list_empty(&psk->ingress) ||
751 !skb_queue_empty(&sk->sk_receive_queue),
752 &wait);
753 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
754 remove_wait_queue(sk_sleep(sk), &wait);
755
756 return rc;
757}
758
759static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
760 int nonblock, int flags, int *addr_len)
761{
762 struct iov_iter *iter = &msg->msg_iter;
763 struct smap_psock *psock;
764 int copied = 0;
765
766 if (unlikely(flags & MSG_ERRQUEUE))
767 return inet_recv_error(sk, msg, len, addr_len);
768
769 rcu_read_lock();
770 psock = smap_psock_sk(sk);
771 if (unlikely(!psock))
772 goto out;
773
774 if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
775 goto out;
776 rcu_read_unlock();
777
778 if (!skb_queue_empty(&sk->sk_receive_queue))
779 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
780
781 lock_sock(sk);
782bytes_ready:
783 while (copied != len) {
784 struct scatterlist *sg;
785 struct sk_msg_buff *md;
786 int i;
787
788 md = list_first_entry_or_null(&psock->ingress,
789 struct sk_msg_buff, list);
790 if (unlikely(!md))
791 break;
792 i = md->sg_start;
793 do {
794 struct page *page;
795 int n, copy;
796
797 sg = &md->sg_data[i];
798 copy = sg->length;
799 page = sg_page(sg);
800
801 if (copied + copy > len)
802 copy = len - copied;
803
804 n = copy_page_to_iter(page, sg->offset, copy, iter);
805 if (n != copy) {
806 md->sg_start = i;
807 release_sock(sk);
808 smap_release_sock(psock, sk);
809 return -EFAULT;
810 }
811
812 copied += copy;
813 sg->offset += copy;
814 sg->length -= copy;
815 sk_mem_uncharge(sk, copy);
816
817 if (!sg->length) {
818 i++;
819 if (i == MAX_SKB_FRAGS)
820 i = 0;
821 if (!md->skb)
822 put_page(page);
823 }
824 if (copied == len)
825 break;
826 } while (i != md->sg_end);
827 md->sg_start = i;
828
829 if (!sg->length && md->sg_start == md->sg_end) {
830 list_del(&md->list);
831 if (md->skb)
832 consume_skb(md->skb);
833 kfree(md);
834 }
835 }
836
837 if (!copied) {
838 long timeo;
839 int data;
840 int err = 0;
841
842 timeo = sock_rcvtimeo(sk, nonblock);
843 data = bpf_wait_data(sk, psock, flags, timeo, &err);
844
845 if (data) {
846 if (!skb_queue_empty(&sk->sk_receive_queue)) {
847 release_sock(sk);
848 smap_release_sock(psock, sk);
849 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
850 return copied;
851 }
852 goto bytes_ready;
853 }
854
855 if (err)
856 copied = err;
857 }
858
859 release_sock(sk);
860 smap_release_sock(psock, sk);
861 return copied;
862out:
863 rcu_read_unlock();
864 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
865}
866
867
868static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
869{
870 int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
871 struct sk_msg_buff md = {0};
872 unsigned int sg_copy = 0;
873 struct smap_psock *psock;
874 int copied = 0, err = 0;
875 struct scatterlist *sg;
876 long timeo;
877
878
879
880
881
882
883 rcu_read_lock();
884 psock = smap_psock_sk(sk);
885 if (unlikely(!psock)) {
886 rcu_read_unlock();
887 return tcp_sendmsg(sk, msg, size);
888 }
889
890
891
892
893
894
895 if (!refcount_inc_not_zero(&psock->refcnt)) {
896 rcu_read_unlock();
897 return tcp_sendmsg(sk, msg, size);
898 }
899
900 sg = md.sg_data;
901 sg_init_marker(sg, MAX_SKB_FRAGS);
902 rcu_read_unlock();
903
904 lock_sock(sk);
905 timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
906
907 while (msg_data_left(msg)) {
908 struct sk_msg_buff *m;
909 bool enospc = false;
910 int copy;
911
912 if (sk->sk_err) {
913 err = sk->sk_err;
914 goto out_err;
915 }
916
917 copy = msg_data_left(msg);
918 if (!sk_stream_memory_free(sk))
919 goto wait_for_sndbuf;
920
921 m = psock->cork_bytes ? psock->cork : &md;
922 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
923 err = sk_alloc_sg(sk, copy, m->sg_data,
924 m->sg_start, &m->sg_end, &sg_copy,
925 m->sg_end - 1);
926 if (err) {
927 if (err != -ENOSPC)
928 goto wait_for_memory;
929 enospc = true;
930 copy = sg_copy;
931 }
932
933 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
934 if (err < 0) {
935 free_curr_sg(sk, m);
936 goto out_err;
937 }
938
939 psock->sg_size += copy;
940 copied += copy;
941 sg_copy = 0;
942
943
944
945
946
947
948
949
950
951
952
953
954 if (psock->cork_bytes) {
955 if (copy > psock->cork_bytes)
956 psock->cork_bytes = 0;
957 else
958 psock->cork_bytes -= copy;
959
960 if (psock->cork_bytes && !enospc)
961 goto out_cork;
962
963
964 psock->eval = __SK_NONE;
965 psock->cork_bytes = 0;
966 }
967
968 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
969 if (unlikely(err < 0))
970 goto out_err;
971 continue;
972wait_for_sndbuf:
973 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
974wait_for_memory:
975 err = sk_stream_wait_memory(sk, &timeo);
976 if (err)
977 goto out_err;
978 }
979out_err:
980 if (err < 0)
981 err = sk_stream_error(sk, msg->msg_flags, err);
982out_cork:
983 release_sock(sk);
984 smap_release_sock(psock, sk);
985 return copied ? copied : err;
986}
987
/* sendpage hook installed through tcp_bpf_proto.
 *
 * The page is appended, by reference, to a scatterlist ring (an on-stack
 * one, or the psock's cork buffer while corking is in progress) and the
 * ring is handed to bpf_exec_tx_verdict(). If no live psock is attached
 * the call is handed to tcp_sendpage().
 *
 * Returns the number of bytes accepted, or the error from
 * bpf_exec_tx_verdict() if nothing was accepted.
 */
static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
			    int offset, size_t size, int flags)
{
	struct sk_msg_buff md = {0}, *m = NULL;
	int err = 0, copied = 0;
	struct smap_psock *psock;
	struct scatterlist *sg;
	bool enospc = false;

	rcu_read_lock();
	psock = smap_psock_sk(sk);
	if (unlikely(!psock))
		goto accept;

	if (!refcount_inc_not_zero(&psock->refcnt))
		goto accept;
	rcu_read_unlock();

	lock_sock(sk);

	if (psock->cork_bytes) {
		m = psock->cork;
		sg = &m->sg_data[m->sg_end];
	} else {
		m = &md;
		sg = m->sg_data;
		sg_init_marker(sg, MAX_SKB_FRAGS);
	}

	/* Ring is full (start == end and the slot is occupied): bail out. */
	if (unlikely(m->sg_end == m->sg_start &&
	    m->sg_data[m->sg_end].length))
		goto out_err;

	psock->sg_size += size;
	sg_set_page(sg, page, size, offset);
	get_page(page);
	/* Flag the element; bpf_compute_data_pointers_sg() zeroes the data
	 * pointers for flagged elements.
	 */
	m->sg_copy[m->sg_end] = true;
	sk_mem_charge(sk, size);
	m->sg_end++;
	copied = size;

	if (m->sg_end == MAX_SKB_FRAGS)
		m->sg_end = 0;

	if (m->sg_end == m->sg_start)
		enospc = true;

	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;

		/* Still corking and room left: wait for more data. */
		if (psock->cork_bytes && !enospc)
			goto out_err;

		/* All cork bytes accounted for: re-run the program. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
out_err:
	release_sock(sk);
	smap_release_sock(psock, sk);
	return copied ? copied : err;
accept:
	rcu_read_unlock();
	return tcp_sendpage(sk, page, offset, size, flags);
}
1059
1060static void bpf_tcp_msg_add(struct smap_psock *psock,
1061 struct sock *sk,
1062 struct bpf_prog *tx_msg)
1063{
1064 struct bpf_prog *orig_tx_msg;
1065
1066 orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1067 if (orig_tx_msg)
1068 bpf_prog_put(orig_tx_msg);
1069}
1070
1071static int bpf_tcp_ulp_register(void)
1072{
1073 tcp_bpf_proto = tcp_prot;
1074 tcp_bpf_proto.close = bpf_tcp_close;
1075
1076
1077
1078
1079 return tcp_register_ulp(&bpf_tcp_ulp_ops);
1080}
1081
1082static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1083{
1084 struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1085 int rc;
1086
1087 if (unlikely(!prog))
1088 return __SK_DROP;
1089
1090 skb_orphan(skb);
1091
1092
1093
1094
1095 TCP_SKB_CB(skb)->bpf.map = NULL;
1096 skb->sk = psock->sock;
1097 bpf_compute_data_pointers(skb);
1098 preempt_disable();
1099 rc = (*prog->bpf_func)(skb, prog->insnsi);
1100 preempt_enable();
1101 skb->sk = NULL;
1102
1103
1104 return rc == SK_PASS ?
1105 (TCP_SKB_CB(skb)->bpf.map ? __SK_REDIRECT : __SK_PASS) :
1106 __SK_DROP;
1107}
1108
1109static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1110{
1111 struct sock *sk = psock->sock;
1112 int copied = 0, num_sg;
1113 struct sk_msg_buff *r;
1114
1115 r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1116 if (unlikely(!r))
1117 return -EAGAIN;
1118
1119 if (!sk_rmem_schedule(sk, skb, skb->len)) {
1120 kfree(r);
1121 return -EAGAIN;
1122 }
1123
1124 sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1125 num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1126 if (unlikely(num_sg < 0)) {
1127 kfree(r);
1128 return num_sg;
1129 }
1130 sk_mem_charge(sk, skb->len);
1131 copied = skb->len;
1132 r->sg_start = 0;
1133 r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1134 r->skb = skb;
1135 list_add_tail(&r->list, &psock->ingress);
1136 sk->sk_data_ready(sk);
1137 return copied;
1138}
1139
/* Run the verdict program on @skb and act on the result.
 *
 * On redirect the skb is queued on the target psock's rxqueue and its
 * tx_work is scheduled, provided the target is alive, still running and
 * has room. In every other case the skb is freed.
 */
static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
{
	struct smap_psock *peer;
	struct sock *sk;
	__u32 in;
	int rc;

	rc = smap_verdict_func(psock, skb);
	switch (rc) {
	case __SK_REDIRECT:
		sk = do_sk_redirect_map(skb);
		if (!sk) {
			kfree_skb(skb);
			break;
		}

		peer = smap_psock_sk(sk);
		in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;

		if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
			     !test_bit(SMAP_TX_RUNNING, &peer->state))) {
			kfree_skb(skb);
			break;
		}

		/* Egress needs write space on the target; ingress needs the
		 * target to be within its receive buffer limit.
		 */
		if (!in && sock_writeable(sk)) {
			skb_set_owner_w(skb, sk);
			skb_queue_tail(&peer->rxqueue, skb);
			schedule_work(&peer->tx_work);
			break;
		} else if (in &&
			   atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
			skb_queue_tail(&peer->rxqueue, skb);
			schedule_work(&peer->tx_work);
			break;
		}
		/* Fall through: the target has no room, so free the skb. */
	case __SK_DROP:
	default:
		kfree_skb(skb);
	}
}
1182
1183static void smap_report_sk_error(struct smap_psock *psock, int err)
1184{
1185 struct sock *sk = psock->sock;
1186
1187 sk->sk_err = err;
1188 sk->sk_error_report(sk);
1189}
1190
1191static void smap_read_sock_strparser(struct strparser *strp,
1192 struct sk_buff *skb)
1193{
1194 struct smap_psock *psock;
1195
1196 rcu_read_lock();
1197 psock = container_of(strp, struct smap_psock, strp);
1198 smap_do_verdict(psock, skb);
1199 rcu_read_unlock();
1200}
1201
1202
1203static void smap_data_ready(struct sock *sk)
1204{
1205 struct smap_psock *psock;
1206
1207 rcu_read_lock();
1208 psock = smap_psock_sk(sk);
1209 if (likely(psock)) {
1210 write_lock_bh(&sk->sk_callback_lock);
1211 strp_data_ready(&psock->strp);
1212 write_unlock_bh(&sk->sk_callback_lock);
1213 }
1214 rcu_read_unlock();
1215}
1216
/* Work handler that drains psock->rxqueue.
 *
 * Each skb is either queued for local receive (BPF_F_INGRESS set) or sent
 * on the psock's socket. On -EAGAIN the skb and the current position are
 * saved in the psock so the next run resumes there. Any other failure is
 * reported on the socket, SMAP_TX_RUNNING is cleared and the skb is
 * freed.
 */
static void smap_tx_work(struct work_struct *w)
{
	struct smap_psock *psock;
	struct sk_buff *skb;
	int rem, off, n;

	psock = container_of(w, struct smap_psock, tx_work);

	/* The socket lock is held across the whole drain. */
	lock_sock(psock->sock);
	if (psock->save_skb) {
		/* Resume the skb interrupted by an earlier -EAGAIN. */
		skb = psock->save_skb;
		rem = psock->save_rem;
		off = psock->save_off;
		psock->save_skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->rxqueue))) {
		__u32 flags;

		rem = skb->len;
		off = 0;
start:
		flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
		do {
			if (likely(psock->sock->sk_socket)) {
				if (flags)
					n = smap_do_ingress(psock, skb);
				else
					n = skb_send_sock_locked(psock->sock,
								 skb, off, rem);
			} else {
				n = -EINVAL;
			}

			if (n <= 0) {
				if (n == -EAGAIN) {
					/* Retry later from this position. */
					psock->save_skb = skb;
					psock->save_rem = rem;
					psock->save_off = off;
					goto out;
				}
				/* Hard error: report it and stop transmitting. */
				smap_report_sk_error(psock, n ? -n : EPIPE);
				clear_bit(SMAP_TX_RUNNING, &psock->state);
				kfree_skb(skb);
				goto out;
			}
			rem -= n;
			off += n;
		} while (rem);

		/* An ingress skb is now referenced from the ingress list
		 * (r->skb) and must not be freed here.
		 */
		if (!flags)
			kfree_skb(skb);
	}
out:
	release_sock(psock->sock);
}
1277
1278static void smap_write_space(struct sock *sk)
1279{
1280 struct smap_psock *psock;
1281
1282 rcu_read_lock();
1283 psock = smap_psock_sk(sk);
1284 if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1285 schedule_work(&psock->tx_work);
1286 rcu_read_unlock();
1287}
1288
1289static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1290{
1291 if (!psock->strp_enabled)
1292 return;
1293 sk->sk_data_ready = psock->save_data_ready;
1294 sk->sk_write_space = psock->save_write_space;
1295 psock->save_data_ready = NULL;
1296 psock->save_write_space = NULL;
1297 strp_stop(&psock->strp);
1298 psock->strp_enabled = false;
1299}
1300
1301static void smap_destroy_psock(struct rcu_head *rcu)
1302{
1303 struct smap_psock *psock = container_of(rcu,
1304 struct smap_psock, rcu);
1305
1306
1307
1308
1309
1310
1311
1312 schedule_work(&psock->gc_work);
1313}
1314
1315static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1316{
1317 if (refcount_dec_and_test(&psock->refcnt)) {
1318 tcp_cleanup_ulp(sock);
1319 smap_stop_sock(psock, sock);
1320 clear_bit(SMAP_TX_RUNNING, &psock->state);
1321 rcu_assign_sk_user_data(sock, NULL);
1322 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1323 }
1324}
1325
1326static int smap_parse_func_strparser(struct strparser *strp,
1327 struct sk_buff *skb)
1328{
1329 struct smap_psock *psock;
1330 struct bpf_prog *prog;
1331 int rc;
1332
1333 rcu_read_lock();
1334 psock = container_of(strp, struct smap_psock, strp);
1335 prog = READ_ONCE(psock->bpf_parse);
1336
1337 if (unlikely(!prog)) {
1338 rcu_read_unlock();
1339 return skb->len;
1340 }
1341
1342
1343
1344
1345
1346
1347
1348
1349 skb->sk = psock->sock;
1350 bpf_compute_data_pointers(skb);
1351 rc = (*prog->bpf_func)(skb, prog->insnsi);
1352 skb->sk = NULL;
1353 rcu_read_unlock();
1354 return rc;
1355}
1356
/* strparser read_sock_done callback: passes @err through unchanged. */
static int smap_read_sock_done(struct strparser *strp, int err)
{
	return err;
}
1361
1362static int smap_init_sock(struct smap_psock *psock,
1363 struct sock *sk)
1364{
1365 static const struct strp_callbacks cb = {
1366 .rcv_msg = smap_read_sock_strparser,
1367 .parse_msg = smap_parse_func_strparser,
1368 .read_sock_done = smap_read_sock_done,
1369 };
1370
1371 return strp_init(&psock->strp, sk, &cb);
1372}
1373
1374static void smap_init_progs(struct smap_psock *psock,
1375 struct bpf_stab *stab,
1376 struct bpf_prog *verdict,
1377 struct bpf_prog *parse)
1378{
1379 struct bpf_prog *orig_parse, *orig_verdict;
1380
1381 orig_parse = xchg(&psock->bpf_parse, parse);
1382 orig_verdict = xchg(&psock->bpf_verdict, verdict);
1383
1384 if (orig_verdict)
1385 bpf_prog_put(orig_verdict);
1386 if (orig_parse)
1387 bpf_prog_put(orig_parse);
1388}
1389
1390static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1391{
1392 if (sk->sk_data_ready == smap_data_ready)
1393 return;
1394 psock->save_data_ready = sk->sk_data_ready;
1395 psock->save_write_space = sk->sk_write_space;
1396 sk->sk_data_ready = smap_data_ready;
1397 sk->sk_write_space = smap_write_space;
1398 psock->strp_enabled = true;
1399}
1400
1401static void sock_map_remove_complete(struct bpf_stab *stab)
1402{
1403 bpf_map_area_free(stab->sock_map);
1404 kfree(stab);
1405}
1406
/* Work handler that performs the final teardown of a psock. It is
 * scheduled from smap_destroy_psock().
 *
 * Stops all remaining activity, drops the program references, frees any
 * buffered data and map-list entries, releases the socket reference
 * taken in smap_init_psock() and frees the psock itself.
 */
static void smap_gc_work(struct work_struct *w)
{
	struct smap_psock_map_entry *e, *tmp;
	struct sk_msg_buff *md, *mtmp;
	struct smap_psock *psock;

	psock = container_of(w, struct smap_psock, gc_work);

	/* NOTE(review): smap_stop_sock() clears strp_enabled before the
	 * psock gets here, so this branch may never be taken - confirm.
	 */
	if (psock->strp_enabled)
		strp_done(&psock->strp);

	/* Make sure tx_work is not running before purging its queue. */
	cancel_work_sync(&psock->tx_work);
	__skb_queue_purge(&psock->rxqueue);

	/* Drop the references on whichever programs are still attached. */
	if (psock->bpf_parse)
		bpf_prog_put(psock->bpf_parse);
	if (psock->bpf_verdict)
		bpf_prog_put(psock->bpf_verdict);
	if (psock->bpf_tx_msg)
		bpf_prog_put(psock->bpf_tx_msg);

	if (psock->cork) {
		free_start_sg(psock->sock, psock->cork);
		kfree(psock->cork);
	}

	list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
		list_del(&md->list);
		free_start_sg(psock->sock, md);
		kfree(md);
	}

	list_for_each_entry_safe(e, tmp, &psock->maps, list) {
		list_del(&e->list);
		kfree(e);
	}

	if (psock->sk_redir)
		sock_put(psock->sk_redir);

	sock_put(psock->sock);
	kfree(psock);
}
1452
1453static struct smap_psock *smap_init_psock(struct sock *sock,
1454 struct bpf_stab *stab)
1455{
1456 struct smap_psock *psock;
1457
1458 psock = kzalloc_node(sizeof(struct smap_psock),
1459 GFP_ATOMIC | __GFP_NOWARN,
1460 stab->map.numa_node);
1461 if (!psock)
1462 return ERR_PTR(-ENOMEM);
1463
1464 psock->eval = __SK_NONE;
1465 psock->sock = sock;
1466 skb_queue_head_init(&psock->rxqueue);
1467 INIT_WORK(&psock->tx_work, smap_tx_work);
1468 INIT_WORK(&psock->gc_work, smap_gc_work);
1469 INIT_LIST_HEAD(&psock->maps);
1470 INIT_LIST_HEAD(&psock->ingress);
1471 refcount_set(&psock->refcnt, 1);
1472
1473 rcu_assign_sk_user_data(sock, psock);
1474 sock_hold(sock);
1475 return psock;
1476}
1477
1478static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1479{
1480 struct bpf_stab *stab;
1481 u64 cost;
1482 int err;
1483
1484 if (!capable(CAP_NET_ADMIN))
1485 return ERR_PTR(-EPERM);
1486
1487
1488 if (attr->max_entries == 0 || attr->key_size != 4 ||
1489 attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1490 return ERR_PTR(-EINVAL);
1491
1492 err = bpf_tcp_ulp_register();
1493 if (err && err != -EEXIST)
1494 return ERR_PTR(err);
1495
1496 stab = kzalloc(sizeof(*stab), GFP_USER);
1497 if (!stab)
1498 return ERR_PTR(-ENOMEM);
1499
1500 bpf_map_init_from_attr(&stab->map, attr);
1501
1502
1503 cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1504 err = -EINVAL;
1505 if (cost >= U32_MAX - PAGE_SIZE)
1506 goto free_stab;
1507
1508 stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1509
1510
1511 err = bpf_map_precharge_memlock(stab->map.pages);
1512 if (err)
1513 goto free_stab;
1514
1515 err = -ENOMEM;
1516 stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1517 sizeof(struct sock *),
1518 stab->map.numa_node);
1519 if (!stab->sock_map)
1520 goto free_stab;
1521
1522 return &stab->map;
1523free_stab:
1524 kfree(stab);
1525 return ERR_PTR(err);
1526}
1527
1528static void smap_list_remove(struct smap_psock *psock, struct sock **entry)
1529{
1530 struct smap_psock_map_entry *e, *tmp;
1531
1532 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1533 if (e->entry == entry) {
1534 list_del(&e->list);
1535 break;
1536 }
1537 }
1538}
1539
1540static void sock_map_free(struct bpf_map *map)
1541{
1542 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1543 int i;
1544
1545 synchronize_rcu();
1546
1547
1548
1549
1550
1551
1552
1553
1554 rcu_read_lock();
1555 for (i = 0; i < stab->map.max_entries; i++) {
1556 struct smap_psock *psock;
1557 struct sock *sock;
1558
1559 sock = xchg(&stab->sock_map[i], NULL);
1560 if (!sock)
1561 continue;
1562
1563 write_lock_bh(&sock->sk_callback_lock);
1564 psock = smap_psock_sk(sock);
1565
1566
1567
1568
1569
1570 if (likely(psock)) {
1571 smap_list_remove(psock, &stab->sock_map[i]);
1572 smap_release_sock(psock, sock);
1573 }
1574 write_unlock_bh(&sock->sk_callback_lock);
1575 }
1576 rcu_read_unlock();
1577
1578 sock_map_remove_complete(stab);
1579}
1580
1581static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1582{
1583 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1584 u32 i = key ? *(u32 *)key : U32_MAX;
1585 u32 *next = (u32 *)next_key;
1586
1587 if (i >= stab->map.max_entries) {
1588 *next = 0;
1589 return 0;
1590 }
1591
1592 if (i == stab->map.max_entries - 1)
1593 return -ENOENT;
1594
1595 *next = i + 1;
1596 return 0;
1597}
1598
1599struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1600{
1601 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1602
1603 if (key >= map->max_entries)
1604 return NULL;
1605
1606 return READ_ONCE(stab->sock_map[key]);
1607}
1608
1609static int sock_map_delete_elem(struct bpf_map *map, void *key)
1610{
1611 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1612 struct smap_psock *psock;
1613 int k = *(u32 *)key;
1614 struct sock *sock;
1615
1616 if (k >= map->max_entries)
1617 return -EINVAL;
1618
1619 sock = xchg(&stab->sock_map[k], NULL);
1620 if (!sock)
1621 return -EINVAL;
1622
1623 write_lock_bh(&sock->sk_callback_lock);
1624 psock = smap_psock_sk(sock);
1625 if (!psock)
1626 goto out;
1627
1628 if (psock->bpf_parse)
1629 smap_stop_sock(psock, sock);
1630 smap_list_remove(psock, &stab->sock_map[k]);
1631 smap_release_sock(psock, sock);
1632out:
1633 write_unlock_bh(&sock->sk_callback_lock);
1634 return 0;
1635}
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1666 struct bpf_map *map,
1667 void *key, u64 flags)
1668{
1669 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1670 struct smap_psock_map_entry *e = NULL;
1671 struct bpf_prog *verdict, *parse, *tx_msg;
1672 struct sock *osock, *sock;
1673 struct smap_psock *psock;
1674 u32 i = *(u32 *)key;
1675 bool new = false;
1676 int err;
1677
1678 if (unlikely(flags > BPF_EXIST))
1679 return -EINVAL;
1680
1681 if (unlikely(i >= stab->map.max_entries))
1682 return -E2BIG;
1683
1684 sock = READ_ONCE(stab->sock_map[i]);
1685 if (flags == BPF_EXIST && !sock)
1686 return -ENOENT;
1687 else if (flags == BPF_NOEXIST && sock)
1688 return -EEXIST;
1689
1690 sock = skops->sk;
1691
1692
1693
1694
1695
1696 verdict = READ_ONCE(stab->bpf_verdict);
1697 parse = READ_ONCE(stab->bpf_parse);
1698 tx_msg = READ_ONCE(stab->bpf_tx_msg);
1699
1700 if (parse && verdict) {
1701
1702
1703
1704
1705
1706 verdict = bpf_prog_inc_not_zero(verdict);
1707 if (IS_ERR(verdict))
1708 return PTR_ERR(verdict);
1709
1710 parse = bpf_prog_inc_not_zero(parse);
1711 if (IS_ERR(parse)) {
1712 bpf_prog_put(verdict);
1713 return PTR_ERR(parse);
1714 }
1715 }
1716
1717 if (tx_msg) {
1718 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1719 if (IS_ERR(tx_msg)) {
1720 if (parse && verdict) {
1721 bpf_prog_put(parse);
1722 bpf_prog_put(verdict);
1723 }
1724 return PTR_ERR(tx_msg);
1725 }
1726 }
1727
1728 write_lock_bh(&sock->sk_callback_lock);
1729 psock = smap_psock_sk(sock);
1730
1731
1732
1733
1734
1735
1736
1737 if (psock) {
1738 if (READ_ONCE(psock->bpf_parse) && parse) {
1739 err = -EBUSY;
1740 goto out_progs;
1741 }
1742 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1743 err = -EBUSY;
1744 goto out_progs;
1745 }
1746 if (!refcount_inc_not_zero(&psock->refcnt)) {
1747 err = -EAGAIN;
1748 goto out_progs;
1749 }
1750 } else {
1751 psock = smap_init_psock(sock, stab);
1752 if (IS_ERR(psock)) {
1753 err = PTR_ERR(psock);
1754 goto out_progs;
1755 }
1756
1757 set_bit(SMAP_TX_RUNNING, &psock->state);
1758 new = true;
1759 }
1760
1761 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
1762 if (!e) {
1763 err = -ENOMEM;
1764 goto out_progs;
1765 }
1766 e->entry = &stab->sock_map[i];
1767
1768
1769
1770
1771 if (tx_msg)
1772 bpf_tcp_msg_add(psock, sock, tx_msg);
1773 if (new) {
1774 err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1775 if (err)
1776 goto out_free;
1777 }
1778
1779 if (parse && verdict && !psock->strp_enabled) {
1780 err = smap_init_sock(psock, sock);
1781 if (err)
1782 goto out_free;
1783 smap_init_progs(psock, stab, verdict, parse);
1784 smap_start_sock(psock, sock);
1785 }
1786
1787
1788
1789
1790
1791
1792 list_add_tail(&e->list, &psock->maps);
1793 write_unlock_bh(&sock->sk_callback_lock);
1794
1795 osock = xchg(&stab->sock_map[i], sock);
1796 if (osock) {
1797 struct smap_psock *opsock = smap_psock_sk(osock);
1798
1799 write_lock_bh(&osock->sk_callback_lock);
1800 smap_list_remove(opsock, &stab->sock_map[i]);
1801 smap_release_sock(opsock, osock);
1802 write_unlock_bh(&osock->sk_callback_lock);
1803 }
1804 return 0;
1805out_free:
1806 smap_release_sock(psock, sock);
1807out_progs:
1808 if (parse && verdict) {
1809 bpf_prog_put(parse);
1810 bpf_prog_put(verdict);
1811 }
1812 if (tx_msg)
1813 bpf_prog_put(tx_msg);
1814 write_unlock_bh(&sock->sk_callback_lock);
1815 kfree(e);
1816 return err;
1817}
1818
1819int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
1820{
1821 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1822 struct bpf_prog *orig;
1823
1824 if (unlikely(map->map_type != BPF_MAP_TYPE_SOCKMAP))
1825 return -EINVAL;
1826
1827 switch (type) {
1828 case BPF_SK_MSG_VERDICT:
1829 orig = xchg(&stab->bpf_tx_msg, prog);
1830 break;
1831 case BPF_SK_SKB_STREAM_PARSER:
1832 orig = xchg(&stab->bpf_parse, prog);
1833 break;
1834 case BPF_SK_SKB_STREAM_VERDICT:
1835 orig = xchg(&stab->bpf_verdict, prog);
1836 break;
1837 default:
1838 return -EOPNOTSUPP;
1839 }
1840
1841 if (orig)
1842 bpf_prog_put(orig);
1843
1844 return 0;
1845}
1846
/* map_lookup_elem callback for sockmap: always returns NULL, regardless of
 * @map and @key, so no value is ever handed out through this path.
 */
static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return NULL;
}
1851
1852static int sock_map_update_elem(struct bpf_map *map,
1853 void *key, void *value, u64 flags)
1854{
1855 struct bpf_sock_ops_kern skops;
1856 u32 fd = *(u32 *)value;
1857 struct socket *socket;
1858 int err;
1859
1860 socket = sockfd_lookup(fd, &err);
1861 if (!socket)
1862 return err;
1863
1864 skops.sk = socket->sk;
1865 if (!skops.sk) {
1866 fput(socket->file);
1867 return -EINVAL;
1868 }
1869
1870 if (skops.sk->sk_type != SOCK_STREAM ||
1871 skops.sk->sk_protocol != IPPROTO_TCP) {
1872 fput(socket->file);
1873 return -EOPNOTSUPP;
1874 }
1875
1876 err = sock_map_ctx_update_elem(&skops, map, key, flags);
1877 fput(socket->file);
1878 return err;
1879}
1880
1881static void sock_map_release(struct bpf_map *map)
1882{
1883 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1884 struct bpf_prog *orig;
1885
1886 orig = xchg(&stab->bpf_parse, NULL);
1887 if (orig)
1888 bpf_prog_put(orig);
1889 orig = xchg(&stab->bpf_verdict, NULL);
1890 if (orig)
1891 bpf_prog_put(orig);
1892
1893 orig = xchg(&stab->bpf_tx_msg, NULL);
1894 if (orig)
1895 bpf_prog_put(orig);
1896}
1897
/* Operations table wiring the sockmap callbacks defined in this file into
 * struct bpf_map_ops. Note that .map_lookup_elem points at sock_map_lookup(),
 * which always returns NULL, and that .map_release_uref drops the programs
 * attached to the map.
 */
const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_lookup_elem = sock_map_lookup,
	.map_get_next_key = sock_map_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_release_uref = sock_map_release,
};
1907
/* BPF helper body: insert the socket from the sock_ops context @bpf_sock
 * into @map at @key, honouring @flags. Warns once if it is entered without
 * the RCU read lock held, then returns whatever
 * sock_map_ctx_update_elem() returns.
 */
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());
	return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
}
1914
/* Prototype descriptor for the bpf_sock_map_update helper: returns an
 * integer and takes (context pointer, constant map pointer, pointer to a
 * map key, arbitrary flags value). Not restricted to GPL programs.
 */
const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};
1925