/* BPF sockmap and sockhash: map types that hold established TCP sockets.
 * Sockets added to one of these maps inherit the map's BPF programs: a
 * strparser parse/verdict program pair on the receive path and an SK_MSG
 * program on the transmit path, which can pass, drop, or redirect data to
 * another socket in the map.
 */
31#include <linux/bpf.h>
32#include <net/sock.h>
33#include <linux/filter.h>
34#include <linux/errno.h>
35#include <linux/file.h>
36#include <linux/kernel.h>
37#include <linux/net.h>
38#include <linux/skbuff.h>
39#include <linux/workqueue.h>
40#include <linux/list.h>
41#include <linux/mm.h>
42#include <net/strparser.h>
43#include <net/tcp.h>
44#include <linux/ptr_ring.h>
45#include <net/inet_common.h>
46#include <linux/sched/signal.h>
47
48#define SOCK_CREATE_FLAG_MASK \
49 (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
51struct bpf_sock_progs {
52 struct bpf_prog *bpf_tx_msg;
53 struct bpf_prog *bpf_parse;
54 struct bpf_prog *bpf_verdict;
55};
56
57struct bpf_stab {
58 struct bpf_map map;
59 struct sock **sock_map;
60 struct bpf_sock_progs progs;
61 raw_spinlock_t lock;
62};
63
64struct bucket {
65 struct hlist_head head;
66 raw_spinlock_t lock;
67};
68
69struct bpf_htab {
70 struct bpf_map map;
71 struct bucket *buckets;
72 atomic_t count;
73 u32 n_buckets;
74 u32 elem_size;
75 struct bpf_sock_progs progs;
76 struct rcu_head rcu;
77};
78
79struct htab_elem {
80 struct rcu_head rcu;
81 struct hlist_node hash_node;
82 u32 hash;
83 struct sock *sk;
84 char key[0];
85};
86
87enum smap_psock_state {
88 SMAP_TX_RUNNING,
89};
90
91struct smap_psock_map_entry {
92 struct list_head list;
93 struct bpf_map *map;
94 struct sock **entry;
95 struct htab_elem __rcu *hash_link;
96};
97
98struct smap_psock {
99 struct rcu_head rcu;
100 refcount_t refcnt;

 /* datapath variables */
103 struct sk_buff_head rxqueue;
104 bool strp_enabled;

 /* datapath error path cache across tx work invocations */
107 int save_rem;
108 int save_off;
109 struct sk_buff *save_skb;

 /* datapath variables for the tx_msg ULP */
112 struct sock *sk_redir;
113 int apply_bytes;
114 int cork_bytes;
115 int sg_size;
116 int eval;
117 struct sk_msg_buff *cork;
118 struct list_head ingress;
119
120 struct strparser strp;
121 struct bpf_prog *bpf_tx_msg;
122 struct bpf_prog *bpf_parse;
123 struct bpf_prog *bpf_verdict;
124 struct list_head maps;
125 spinlock_t maps_lock;

 /* Back reference used when socket callbacks trigger sockmap operations */
128 struct sock *sock;
129 unsigned long state;
130
131 struct work_struct tx_work;
132 struct work_struct gc_work;
133
134 struct proto *sk_proto;
135 void (*save_unhash)(struct sock *sk);
136 void (*save_close)(struct sock *sk, long timeout);
137 void (*save_data_ready)(struct sock *sk);
138 void (*save_write_space)(struct sock *sk);
139};
140
141static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
142static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
143 int nonblock, int flags, int *addr_len);
144static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
145static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
146 int offset, size_t size, int flags);
147static void bpf_tcp_unhash(struct sock *sk);
148static void bpf_tcp_close(struct sock *sk, long timeout);
149
150static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
151{
152 return rcu_dereference_sk_user_data(sk);
153}
154
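/* stream_memory_read hook: report the socket as readable when BPF-redirected
 * messages are queued on the psock ingress list.
 */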
155static bool bpf_tcp_stream_read(const struct sock *sk)
156{
157 struct smap_psock *psock;
158 bool empty = true;
159
160 rcu_read_lock();
161 psock = smap_psock_sk(sk);
162 if (unlikely(!psock))
163 goto out;
164 empty = list_empty(&psock->ingress);
165out:
166 rcu_read_unlock();
167 return !empty;
168}
169
170enum {
171 SOCKMAP_IPV4,
172 SOCKMAP_IPV6,
173 SOCKMAP_NUM_PROTS,
174};
175
176enum {
177 SOCKMAP_BASE,
178 SOCKMAP_TX,
179 SOCKMAP_NUM_CONFIGS,
180};
181
182static struct proto *saved_tcpv6_prot __read_mostly;
183static DEFINE_SPINLOCK(tcpv6_prot_lock);
184static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
185static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
186 struct proto *base)
187{
188 prot[SOCKMAP_BASE] = *base;
189 prot[SOCKMAP_BASE].unhash = bpf_tcp_unhash;
190 prot[SOCKMAP_BASE].close = bpf_tcp_close;
191 prot[SOCKMAP_BASE].recvmsg = bpf_tcp_recvmsg;
192 prot[SOCKMAP_BASE].stream_memory_read = bpf_tcp_stream_read;
193
194 prot[SOCKMAP_TX] = prot[SOCKMAP_BASE];
195 prot[SOCKMAP_TX].sendmsg = bpf_tcp_sendmsg;
196 prot[SOCKMAP_TX].sendpage = bpf_tcp_sendpage;
197}
198
199static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
200{
201 int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
202 int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
203
204 sk->sk_prot = &bpf_tcp_prots[family][conf];
205}
206
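/* ULP init callback: save the socket's original proto ops so they can be
 * restored on release, lazily build the IPv6 proto variants, and switch the
 * socket over to the bpf_tcp proto that layers the sockmap hooks on top.
 */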
207static int bpf_tcp_init(struct sock *sk)
208{
209 struct smap_psock *psock;
210
211 rcu_read_lock();
212 psock = smap_psock_sk(sk);
213 if (unlikely(!psock)) {
214 rcu_read_unlock();
215 return -EINVAL;
216 }
217
218 if (unlikely(psock->sk_proto)) {
219 rcu_read_unlock();
220 return -EBUSY;
221 }
222
223 psock->save_unhash = sk->sk_prot->unhash;
224 psock->save_close = sk->sk_prot->close;
225 psock->sk_proto = sk->sk_prot;

 /* Rebuild the IPv6 protos whenever the address of tcpv6_prot changes. */
228 if (sk->sk_family == AF_INET6 &&
229 unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
230 spin_lock_bh(&tcpv6_prot_lock);
231 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
232 build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
233 smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
234 }
235 spin_unlock_bh(&tcpv6_prot_lock);
236 }
237 update_sk_prot(sk, psock);
238 rcu_read_unlock();
239 return 0;
240}
241
242static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
243static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
244
245static void bpf_tcp_release(struct sock *sk)
246{
247 struct smap_psock *psock;
248
249 rcu_read_lock();
250 psock = smap_psock_sk(sk);
251 if (unlikely(!psock))
252 goto out;
253
254 if (psock->cork) {
255 free_start_sg(psock->sock, psock->cork, true);
256 kfree(psock->cork);
257 psock->cork = NULL;
258 }
259
260 if (psock->sk_proto) {
261 sk->sk_prot = psock->sk_proto;
262 psock->sk_proto = NULL;
263 }
264out:
265 rcu_read_unlock();
266}
267
268static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
269 u32 hash, void *key, u32 key_size)
270{
271 struct htab_elem *l;
272
273 hlist_for_each_entry_rcu(l, head, hash_node) {
274 if (l->hash == hash && !memcmp(&l->key, key, key_size))
275 return l;
276 }
277
278 return NULL;
279}
280
281static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
282{
283 return &htab->buckets[hash & (htab->n_buckets - 1)];
284}
285
286static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
287{
288 return &__select_bucket(htab, hash)->head;
289}
290
291static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
292{
293 atomic_dec(&htab->count);
294 kfree_rcu(l, rcu);
295}
296
297static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
298 struct smap_psock *psock)
299{
300 struct smap_psock_map_entry *e;
301
302 spin_lock_bh(&psock->maps_lock);
303 e = list_first_entry_or_null(&psock->maps,
304 struct smap_psock_map_entry,
305 list);
306 if (e)
307 list_del(&e->list);
308 spin_unlock_bh(&psock->maps_lock);
309 return e;
310}
311
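/* Drop any corked and ingress-queued data held by the psock and unlink the
 * socket from every sockmap/sockhash entry that still points at it.
 */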
312static void bpf_tcp_remove(struct sock *sk, struct smap_psock *psock)
313{
314 struct smap_psock_map_entry *e;
315 struct sk_msg_buff *md, *mtmp;
316 struct sock *osk;
317
318 if (psock->cork) {
319 free_start_sg(psock->sock, psock->cork, true);
320 kfree(psock->cork);
321 psock->cork = NULL;
322 }
323
324 list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
325 list_del(&md->list);
326 free_start_sg(psock->sock, md, true);
327 kfree(md);
328 }
329
330 e = psock_map_pop(sk, psock);
331 while (e) {
332 if (e->entry) {
333 struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
334
335 raw_spin_lock_bh(&stab->lock);
336 osk = *e->entry;
337 if (osk == sk) {
338 *e->entry = NULL;
339 smap_release_sock(psock, sk);
340 }
341 raw_spin_unlock_bh(&stab->lock);
342 } else {
343 struct htab_elem *link = rcu_dereference(e->hash_link);
344 struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
345 struct hlist_head *head;
346 struct htab_elem *l;
347 struct bucket *b;
348
349 b = __select_bucket(htab, link->hash);
350 head = &b->head;
351 raw_spin_lock_bh(&b->lock);
352 l = lookup_elem_raw(head,
353 link->hash, link->key,
354 htab->map.key_size);

 /* If another thread already deleted this element, skip the delete
  * here; the psock refcnt may or may not have hit zero.
  */
358 if (l && l == link) {
359 hlist_del_rcu(&link->hash_node);
360 smap_release_sock(psock, link->sk);
361 free_htab_elem(htab, link);
362 }
363 raw_spin_unlock_bh(&b->lock);
364 }
365 kfree(e);
366 e = psock_map_pop(sk, psock);
367 }
368}
369
370static void bpf_tcp_unhash(struct sock *sk)
371{
372 void (*unhash_fun)(struct sock *sk);
373 struct smap_psock *psock;
374
375 rcu_read_lock();
376 psock = smap_psock_sk(sk);
377 if (unlikely(!psock)) {
378 rcu_read_unlock();
379 if (sk->sk_prot->unhash)
380 sk->sk_prot->unhash(sk);
381 return;
382 }
383 unhash_fun = psock->save_unhash;
384 bpf_tcp_remove(sk, psock);
385 rcu_read_unlock();
386 unhash_fun(sk);
387}
388
389static void bpf_tcp_close(struct sock *sk, long timeout)
390{
391 void (*close_fun)(struct sock *sk, long timeout);
392 struct smap_psock *psock;
393
394 lock_sock(sk);
395 rcu_read_lock();
396 psock = smap_psock_sk(sk);
397 if (unlikely(!psock)) {
398 rcu_read_unlock();
399 release_sock(sk);
400 return sk->sk_prot->close(sk, timeout);
401 }
402 close_fun = psock->save_close;
403 bpf_tcp_remove(sk, psock);
404 rcu_read_unlock();
405 release_sock(sk);
406 close_fun(sk, timeout);
407}
408
409enum __sk_action {
410 __SK_DROP = 0,
411 __SK_PASS,
412 __SK_REDIRECT,
413 __SK_NONE,
414};
415
416static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
417 .name = "bpf_tcp",
418 .uid = TCP_ULP_BPF,
419 .user_visible = false,
420 .owner = NULL,
421 .init = bpf_tcp_init,
422 .release = bpf_tcp_release,
423};
424
425static int memcopy_from_iter(struct sock *sk,
426 struct sk_msg_buff *md,
427 struct iov_iter *from, int bytes)
428{
429 struct scatterlist *sg = md->sg_data;
430 int i = md->sg_curr, rc = -ENOSPC;
431
432 do {
433 int copy;
434 char *to;
435
436 if (md->sg_copybreak >= sg[i].length) {
437 md->sg_copybreak = 0;
438
439 if (++i == MAX_SKB_FRAGS)
440 i = 0;
441
442 if (i == md->sg_end)
443 break;
444 }
445
446 copy = sg[i].length - md->sg_copybreak;
447 to = sg_virt(&sg[i]) + md->sg_copybreak;
448 md->sg_copybreak += copy;
449
450 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
451 rc = copy_from_iter_nocache(to, copy, from);
452 else
453 rc = copy_from_iter(to, copy, from);
454
455 if (rc != copy) {
456 rc = -EFAULT;
457 goto out;
458 }
459
460 bytes -= copy;
461 if (!bytes)
462 break;
463
464 md->sg_copybreak = 0;
465 if (++i == MAX_SKB_FRAGS)
466 i = 0;
467 } while (i != md->sg_end);
468out:
469 md->sg_curr = i;
470 return rc;
471}
472
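/* Transmit scatterlist data starting at sg_start via do_tcp_sendpages(),
 * bounded by apply_bytes when set, optionally uncharging socket memory as
 * pages are consumed.
 */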
473static int bpf_tcp_push(struct sock *sk, int apply_bytes,
474 struct sk_msg_buff *md,
475 int flags, bool uncharge)
476{
477 bool apply = apply_bytes;
478 struct scatterlist *sg;
479 int offset, ret = 0;
480 struct page *p;
481 size_t size;
482
483 while (1) {
484 sg = md->sg_data + md->sg_start;
485 size = (apply && apply_bytes < sg->length) ?
486 apply_bytes : sg->length;
487 offset = sg->offset;
488
489 tcp_rate_check_app_limited(sk);
490 p = sg_page(sg);
491retry:
492 ret = do_tcp_sendpages(sk, p, offset, size, flags);
493 if (ret != size) {
494 if (ret > 0) {
495 if (apply)
496 apply_bytes -= ret;
497
498 sg->offset += ret;
499 sg->length -= ret;
500 size -= ret;
501 offset += ret;
502 if (uncharge)
503 sk_mem_uncharge(sk, ret);
504 goto retry;
505 }
506
507 return ret;
508 }
509
510 if (apply)
511 apply_bytes -= ret;
512 sg->offset += ret;
513 sg->length -= ret;
514 if (uncharge)
515 sk_mem_uncharge(sk, ret);
516
517 if (!sg->length) {
518 put_page(p);
519 md->sg_start++;
520 if (md->sg_start == MAX_SKB_FRAGS)
521 md->sg_start = 0;
522 sg_init_table(sg, 1);
523
524 if (md->sg_start == md->sg_end)
525 break;
526 }
527
528 if (apply && !apply_bytes)
529 break;
530 }
531 return 0;
532}
533
534static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
535{
536 struct scatterlist *sg = md->sg_data + md->sg_start;
537
538 if (md->sg_copy[md->sg_start]) {
539 md->data = md->data_end = 0;
540 } else {
541 md->data = sg_virt(sg);
542 md->data_end = md->data + sg->length;
543 }
544}
545
546static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
547{
548 struct scatterlist *sg = md->sg_data;
549 int i = md->sg_start;
550
551 do {
552 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
553
554 sk_mem_uncharge(sk, uncharge);
555 bytes -= uncharge;
556 if (!bytes)
557 break;
558 i++;
559 if (i == MAX_SKB_FRAGS)
560 i = 0;
561 } while (i != md->sg_end);
562}
563
564static void free_bytes_sg(struct sock *sk, int bytes,
565 struct sk_msg_buff *md, bool charge)
566{
567 struct scatterlist *sg = md->sg_data;
568 int i = md->sg_start, free;
569
570 while (bytes && sg[i].length) {
571 free = sg[i].length;
572 if (bytes < free) {
573 sg[i].length -= bytes;
574 sg[i].offset += bytes;
575 if (charge)
576 sk_mem_uncharge(sk, bytes);
577 break;
578 }
579
580 if (charge)
581 sk_mem_uncharge(sk, sg[i].length);
582 put_page(sg_page(&sg[i]));
583 bytes -= sg[i].length;
584 sg[i].length = 0;
585 sg[i].page_link = 0;
586 sg[i].offset = 0;
587 i++;
588
589 if (i == MAX_SKB_FRAGS)
590 i = 0;
591 }
592 md->sg_start = i;
593}
594
595static int free_sg(struct sock *sk, int start,
596 struct sk_msg_buff *md, bool charge)
597{
598 struct scatterlist *sg = md->sg_data;
599 int i = start, free = 0;
600
601 while (sg[i].length) {
602 free += sg[i].length;
603 if (charge)
604 sk_mem_uncharge(sk, sg[i].length);
605 if (!md->skb)
606 put_page(sg_page(&sg[i]));
607 sg[i].length = 0;
608 sg[i].page_link = 0;
609 sg[i].offset = 0;
610 i++;
611
612 if (i == MAX_SKB_FRAGS)
613 i = 0;
614 }
615 if (md->skb)
616 consume_skb(md->skb);
617
618 return free;
619}
620
621static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
622{
623 int free = free_sg(sk, md->sg_start, md, charge);
624
625 md->sg_start = md->sg_end;
626 return free;
627}
628
629static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
630{
631 return free_sg(sk, md->sg_curr, md, true);
632}
633
634static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
635{
636 return ((_rc == SK_PASS) ?
637 (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
638 __SK_DROP);
639}
640
641static unsigned int smap_do_tx_msg(struct sock *sk,
642 struct smap_psock *psock,
643 struct sk_msg_buff *md)
644{
645 struct bpf_prog *prog;
646 unsigned int rc, _rc;
647
648 preempt_disable();
649 rcu_read_lock();

 /* If the tx_msg program was removed mid-stream, default to SK_PASS. */
652 prog = READ_ONCE(psock->bpf_tx_msg);
653 if (unlikely(!prog)) {
654 _rc = SK_PASS;
655 goto verdict;
656 }
657
658 bpf_compute_data_pointers_sg(md);
659 md->sk = sk;
660 rc = (*prog->bpf_func)(md, prog->insnsi);
661 psock->apply_bytes = md->apply_bytes;

 /* Map the UAPI SK_PASS/SK_DROP return code to the internal action. */
664 _rc = bpf_map_msg_verdict(rc, md);

 /* On redirect, resolve the target socket now and hold a reference on
  * it; the reference is dropped once the apply_bytes for this verdict
  * are consumed (see bpf_md_init()). If no target exists, drop.
  */
671 if (_rc == __SK_REDIRECT) {
672 if (psock->sk_redir)
673 sock_put(psock->sk_redir);
674 psock->sk_redir = do_msg_redirect_map(md);
675 if (!psock->sk_redir) {
676 _rc = __SK_DROP;
677 goto verdict;
678 }
679 sock_hold(psock->sk_redir);
680 }
681verdict:
682 rcu_read_unlock();
683 preempt_enable();
684
685 return _rc;
686}
687
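/* Redirect message data to the ingress path of @sk: charge the receiving
 * socket, copy the scatterlist entries into a new sk_msg_buff, queue it on
 * the psock ingress list and wake the receiver.
 */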
688static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
689 struct smap_psock *psock,
690 struct sk_msg_buff *md, int flags)
691{
692 bool apply = apply_bytes;
693 size_t size, copied = 0;
694 struct sk_msg_buff *r;
695 int err = 0, i;
696
697 r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
698 if (unlikely(!r))
699 return -ENOMEM;
700
701 lock_sock(sk);
702 r->sg_start = md->sg_start;
703 i = md->sg_start;
704
705 do {
706 size = (apply && apply_bytes < md->sg_data[i].length) ?
707 apply_bytes : md->sg_data[i].length;
708
709 if (!sk_wmem_schedule(sk, size)) {
710 if (!copied)
711 err = -ENOMEM;
712 break;
713 }
714
715 sk_mem_charge(sk, size);
716 r->sg_data[i] = md->sg_data[i];
717 r->sg_data[i].length = size;
718 md->sg_data[i].length -= size;
719 md->sg_data[i].offset += size;
720 copied += size;
721
722 if (md->sg_data[i].length) {
723 get_page(sg_page(&r->sg_data[i]));
724 r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
725 } else {
726 i++;
727 if (i == MAX_SKB_FRAGS)
728 i = 0;
729 r->sg_end = i;
730 }
731
732 if (apply) {
733 apply_bytes -= size;
734 if (!apply_bytes)
735 break;
736 }
737 } while (i != md->sg_end);
738
739 md->sg_start = i;
740
741 if (!err) {
742 list_add_tail(&r->list, &psock->ingress);
743 sk->sk_data_ready(sk);
744 } else {
745 free_start_sg(sk, r, true);
746 kfree(r);
747 }
748
749 release_sock(sk);
750 return err;
751}
752
753static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
754 struct sk_msg_buff *md,
755 int flags)
756{
757 bool ingress = !!(md->flags & BPF_F_INGRESS);
758 struct smap_psock *psock;
759 int err = 0;
760
761 rcu_read_lock();
762 psock = smap_psock_sk(sk);
763 if (unlikely(!psock))
764 goto out_rcu;
765
766 if (!refcount_inc_not_zero(&psock->refcnt))
767 goto out_rcu;
768
769 rcu_read_unlock();
770
771 if (ingress) {
772 err = bpf_tcp_ingress(sk, send, psock, md, flags);
773 } else {
774 lock_sock(sk);
775 err = bpf_tcp_push(sk, send, md, flags, false);
776 release_sock(sk);
777 }
778 smap_release_sock(psock, sk);
779 return err;
780out_rcu:
781 rcu_read_unlock();
782 return 0;
783}
784
785static inline void bpf_md_init(struct smap_psock *psock)
786{
787 if (!psock->apply_bytes) {
788 psock->eval = __SK_NONE;
789 if (psock->sk_redir) {
790 sock_put(psock->sk_redir);
791 psock->sk_redir = NULL;
792 }
793 }
794}
795
796static void apply_bytes_dec(struct smap_psock *psock, int i)
797{
798 if (psock->apply_bytes) {
799 if (psock->apply_bytes < i)
800 psock->apply_bytes = 0;
801 else
802 psock->apply_bytes -= i;
803 }
804}
805
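/* Run the SK_MSG verdict program (unless a verdict is already cached in
 * psock->eval) and carry out its decision: pass data to the TCP stack,
 * redirect it to another socket, or drop it. Handles corking and
 * apply_bytes accounting across repeated invocations.
 */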
806static int bpf_exec_tx_verdict(struct smap_psock *psock,
807 struct sk_msg_buff *m,
808 struct sock *sk,
809 int *copied, int flags)
810{
811 bool cork = false, enospc = (m->sg_start == m->sg_end);
812 struct sock *redir;
813 int err = 0;
814 int send;
815
816more_data:
817 if (psock->eval == __SK_NONE)
818 psock->eval = smap_do_tx_msg(sk, psock, m);
819
820 if (m->cork_bytes &&
821 m->cork_bytes > psock->sg_size && !enospc) {
822 psock->cork_bytes = m->cork_bytes - psock->sg_size;
823 if (!psock->cork) {
824 psock->cork = kcalloc(1,
825 sizeof(struct sk_msg_buff),
826 GFP_ATOMIC | __GFP_NOWARN);
827
828 if (!psock->cork) {
829 err = -ENOMEM;
830 goto out_err;
831 }
832 }
833 memcpy(psock->cork, m, sizeof(*m));
834 goto out_err;
835 }
836
837 send = psock->sg_size;
838 if (psock->apply_bytes && psock->apply_bytes < send)
839 send = psock->apply_bytes;
840
841 switch (psock->eval) {
842 case __SK_PASS:
843 err = bpf_tcp_push(sk, send, m, flags, true);
844 if (unlikely(err)) {
845 *copied -= free_start_sg(sk, m, true);
846 break;
847 }
848
849 apply_bytes_dec(psock, send);
850 psock->sg_size -= send;
851 break;
852 case __SK_REDIRECT:
853 redir = psock->sk_redir;
854 apply_bytes_dec(psock, send);
855
856 if (psock->cork) {
857 cork = true;
858 psock->cork = NULL;
859 }
860
861 return_mem_sg(sk, send, m);
862 release_sock(sk);
863
864 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
865 lock_sock(sk);
866
867 if (unlikely(err < 0)) {
868 int free = free_start_sg(sk, m, false);
869
870 psock->sg_size = 0;
871 if (!cork)
872 *copied -= free;
873 } else {
874 psock->sg_size -= send;
875 }
876
877 if (cork) {
878 free_start_sg(sk, m, true);
879 psock->sg_size = 0;
880 kfree(m);
881 m = NULL;
882 err = 0;
883 }
884 break;
885 case __SK_DROP:
886 default:
887 free_bytes_sg(sk, send, m, true);
888 apply_bytes_dec(psock, send);
889 *copied -= send;
890 psock->sg_size -= send;
891 err = -EACCES;
892 break;
893 }
894
895 if (likely(!err)) {
896 bpf_md_init(psock);
897 if (m &&
898 m->sg_data[m->sg_start].page_link &&
899 m->sg_data[m->sg_start].length)
900 goto more_data;
901 }
902
903out_err:
904 return err;
905}
906
907static int bpf_wait_data(struct sock *sk,
908 struct smap_psock *psk, int flags,
909 long timeo, int *err)
910{
911 int rc;
912
913 DEFINE_WAIT_FUNC(wait, woken_wake_function);
914
915 add_wait_queue(sk_sleep(sk), &wait);
916 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
917 rc = sk_wait_event(sk, &timeo,
918 !list_empty(&psk->ingress) ||
919 !skb_queue_empty(&sk->sk_receive_queue),
920 &wait);
921 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
922 remove_wait_queue(sk_sleep(sk), &wait);
923
924 return rc;
925}
926
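/* recvmsg replacement: copy data queued on the psock ingress list to
 * userspace, falling back to tcp_recvmsg() when the normal receive queue
 * has data or no psock is attached.
 */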
927static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
928 int nonblock, int flags, int *addr_len)
929{
930 struct iov_iter *iter = &msg->msg_iter;
931 struct smap_psock *psock;
932 int copied = 0;
933
934 if (unlikely(flags & MSG_ERRQUEUE))
935 return inet_recv_error(sk, msg, len, addr_len);
936 if (!skb_queue_empty(&sk->sk_receive_queue))
937 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
938
939 rcu_read_lock();
940 psock = smap_psock_sk(sk);
941 if (unlikely(!psock))
942 goto out;
943
944 if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
945 goto out;
946 rcu_read_unlock();
947
948 lock_sock(sk);
949bytes_ready:
950 while (copied != len) {
951 struct scatterlist *sg;
952 struct sk_msg_buff *md;
953 int i;
954
955 md = list_first_entry_or_null(&psock->ingress,
956 struct sk_msg_buff, list);
957 if (unlikely(!md))
958 break;
959 i = md->sg_start;
960 do {
961 struct page *page;
962 int n, copy;
963
964 sg = &md->sg_data[i];
965 copy = sg->length;
966 page = sg_page(sg);
967
968 if (copied + copy > len)
969 copy = len - copied;
970
971 n = copy_page_to_iter(page, sg->offset, copy, iter);
972 if (n != copy) {
973 md->sg_start = i;
974 release_sock(sk);
975 smap_release_sock(psock, sk);
976 return -EFAULT;
977 }
978
979 copied += copy;
980 sg->offset += copy;
981 sg->length -= copy;
982 sk_mem_uncharge(sk, copy);
983
984 if (!sg->length) {
985 i++;
986 if (i == MAX_SKB_FRAGS)
987 i = 0;
988 if (!md->skb)
989 put_page(page);
990 }
991 if (copied == len)
992 break;
993 } while (i != md->sg_end);
994 md->sg_start = i;
995
996 if (!sg->length && md->sg_start == md->sg_end) {
997 list_del(&md->list);
998 if (md->skb)
999 consume_skb(md->skb);
1000 kfree(md);
1001 }
1002 }
1003
1004 if (!copied) {
1005 long timeo;
1006 int data;
1007 int err = 0;
1008
1009 timeo = sock_rcvtimeo(sk, nonblock);
1010 data = bpf_wait_data(sk, psock, flags, timeo, &err);
1011
1012 if (data) {
1013 if (!skb_queue_empty(&sk->sk_receive_queue)) {
1014 release_sock(sk);
1015 smap_release_sock(psock, sk);
1016 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1017 return copied;
1018 }
1019 goto bytes_ready;
1020 }
1021
1022 if (err)
1023 copied = err;
1024 }
1025
1026 release_sock(sk);
1027 smap_release_sock(psock, sk);
1028 return copied;
1029out:
1030 rcu_read_unlock();
1031 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1032}
1033
1034
1035static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1036{
1037 int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1038 struct sk_msg_buff md = {0};
1039 unsigned int sg_copy = 0;
1040 struct smap_psock *psock;
1041 int copied = 0, err = 0;
1042 struct scatterlist *sg;
1043 long timeo;

 /* If the psock was removed but sk_prot has not been restored yet, fall
  * back to the regular tcp_sendmsg() path.
  */
1050 rcu_read_lock();
1051 psock = smap_psock_sk(sk);
1052 if (unlikely(!psock)) {
1053 rcu_read_unlock();
1054 return tcp_sendmsg(sk, msg, size);
1055 }

 /* Hold a psock reference across the send so a concurrent map delete
  * cannot free it under us; if the refcnt already hit zero the psock is
  * being destroyed, so fall back to tcp_sendmsg().
  */
1062 if (!refcount_inc_not_zero(&psock->refcnt)) {
1063 rcu_read_unlock();
1064 return tcp_sendmsg(sk, msg, size);
1065 }
1066
1067 sg = md.sg_data;
1068 sg_init_marker(sg, MAX_SKB_FRAGS);
1069 rcu_read_unlock();
1070
1071 lock_sock(sk);
1072 timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1073
1074 while (msg_data_left(msg)) {
1075 struct sk_msg_buff *m = NULL;
1076 bool enospc = false;
1077 int copy;
1078
1079 if (sk->sk_err) {
1080 err = -sk->sk_err;
1081 goto out_err;
1082 }
1083
1084 copy = msg_data_left(msg);
1085 if (!sk_stream_memory_free(sk))
1086 goto wait_for_sndbuf;
1087
1088 m = psock->cork_bytes ? psock->cork : &md;
1089 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1090 err = sk_alloc_sg(sk, copy, m->sg_data,
1091 m->sg_start, &m->sg_end, &sg_copy,
1092 m->sg_end - 1);
1093 if (err) {
1094 if (err != -ENOSPC)
1095 goto wait_for_memory;
1096 enospc = true;
1097 copy = sg_copy;
1098 }
1099
1100 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1101 if (err < 0) {
1102 free_curr_sg(sk, m);
1103 goto out_err;
1104 }
1105
1106 psock->sg_size += copy;
1107 copied += copy;
1108 sg_copy = 0;

 /* While corking, new bytes are only accounted against the remaining
  * cork budget; the verdict program runs once the requested cork_bytes
  * have been collected or the scatterlist ring is full, otherwise we
  * return and wait for more data.
  */
1121 if (psock->cork_bytes) {
1122 if (copy > psock->cork_bytes)
1123 psock->cork_bytes = 0;
1124 else
1125 psock->cork_bytes -= copy;
1126
1127 if (psock->cork_bytes && !enospc)
1128 goto out_cork;

 /* All cork bytes are accounted for; rerun the verdict program. */
1131 psock->eval = __SK_NONE;
1132 psock->cork_bytes = 0;
1133 }
1134
1135 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1136 if (unlikely(err < 0))
1137 goto out_err;
1138 continue;
1139wait_for_sndbuf:
1140 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1141wait_for_memory:
1142 err = sk_stream_wait_memory(sk, &timeo);
1143 if (err) {
1144 if (m && m != psock->cork)
1145 free_start_sg(sk, m, true);
1146 goto out_err;
1147 }
1148 }
1149out_err:
1150 if (err < 0)
1151 err = sk_stream_error(sk, msg->msg_flags, err);
1152out_cork:
1153 release_sock(sk);
1154 smap_release_sock(psock, sk);
1155 return copied ? copied : err;
1156}
1157
1158static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1159 int offset, size_t size, int flags)
1160{
1161 struct sk_msg_buff md = {0}, *m = NULL;
1162 int err = 0, copied = 0;
1163 struct smap_psock *psock;
1164 struct scatterlist *sg;
1165 bool enospc = false;
1166
1167 rcu_read_lock();
1168 psock = smap_psock_sk(sk);
1169 if (unlikely(!psock))
1170 goto accept;
1171
1172 if (!refcount_inc_not_zero(&psock->refcnt))
1173 goto accept;
1174 rcu_read_unlock();
1175
1176 lock_sock(sk);
1177
1178 if (psock->cork_bytes) {
1179 m = psock->cork;
1180 sg = &m->sg_data[m->sg_end];
1181 } else {
1182 m = &md;
1183 sg = m->sg_data;
1184 sg_init_marker(sg, MAX_SKB_FRAGS);
1185 }

 /* Bail out if the scatterlist ring is already full. */
1188 if (unlikely(m->sg_end == m->sg_start &&
1189 m->sg_data[m->sg_end].length))
1190 goto out_err;
1191
1192 psock->sg_size += size;
1193 sg_set_page(sg, page, size, offset);
1194 get_page(page);
1195 m->sg_copy[m->sg_end] = true;
1196 sk_mem_charge(sk, size);
1197 m->sg_end++;
1198 copied = size;
1199
1200 if (m->sg_end == MAX_SKB_FRAGS)
1201 m->sg_end = 0;
1202
1203 if (m->sg_end == m->sg_start)
1204 enospc = true;
1205
1206 if (psock->cork_bytes) {
1207 if (size > psock->cork_bytes)
1208 psock->cork_bytes = 0;
1209 else
1210 psock->cork_bytes -= size;
1211
1212 if (psock->cork_bytes && !enospc)
1213 goto out_err;

 /* All cork bytes are accounted for; rerun the verdict program. */
1216 psock->eval = __SK_NONE;
1217 psock->cork_bytes = 0;
1218 }
1219
1220 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1221out_err:
1222 release_sock(sk);
1223 smap_release_sock(psock, sk);
1224 return copied ? copied : err;
1225accept:
1226 rcu_read_unlock();
1227 return tcp_sendpage(sk, page, offset, size, flags);
1228}
1229
1230static void bpf_tcp_msg_add(struct smap_psock *psock,
1231 struct sock *sk,
1232 struct bpf_prog *tx_msg)
1233{
1234 struct bpf_prog *orig_tx_msg;
1235
1236 orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1237 if (orig_tx_msg)
1238 bpf_prog_put(orig_tx_msg);
1239}
1240
1241static int bpf_tcp_ulp_register(void)
1242{
1243 build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);

 /* The IPv6 variants are built lazily in bpf_tcp_init() the first time
  * an IPv6 socket is added. Registering the ULP a second time returns
  * -EEXIST, which callers treat as success.
  */
1248 return tcp_register_ulp(&bpf_tcp_ulp_ops);
1249}
1250
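/* Run the SK_SKB verdict program on an skb delivered by the strparser and
 * translate its return code into __SK_PASS/__SK_REDIRECT/__SK_DROP.
 */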
1251static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1252{
1253 struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1254 int rc;
1255
1256 if (unlikely(!prog))
1257 return __SK_DROP;
1258
1259 skb_orphan(skb);

 /* Clear the BPF redirect metadata after orphaning the skb so a stale
  * map reference from a previous program run cannot leak through.
  */
1264 TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1265 skb->sk = psock->sock;
1266 bpf_compute_data_end_sk_skb(skb);
1267 preempt_disable();
1268 rc = (*prog->bpf_func)(skb, prog->insnsi);
1269 preempt_enable();
1270 skb->sk = NULL;

 /* Map the UAPI return code to the internal action. */
1273 return rc == SK_PASS ?
1274 (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1275 __SK_DROP;
1276}
1277
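/* Convert an skb into a sk_msg_buff and queue it on the ingress list of the
 * receiving psock, charging receive memory to the socket.
 */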
1278static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1279{
1280 struct sock *sk = psock->sock;
1281 int copied = 0, num_sg;
1282 struct sk_msg_buff *r;
1283
1284 r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1285 if (unlikely(!r))
1286 return -EAGAIN;
1287
1288 if (!sk_rmem_schedule(sk, skb, skb->len)) {
1289 kfree(r);
1290 return -EAGAIN;
1291 }
1292
1293 sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1294 num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1295 if (unlikely(num_sg < 0)) {
1296 kfree(r);
1297 return num_sg;
1298 }
1299 sk_mem_charge(sk, skb->len);
1300 copied = skb->len;
1301 r->sg_start = 0;
1302 r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1303 r->skb = skb;
1304 list_add_tail(&r->list, &psock->ingress);
1305 sk->sk_data_ready(sk);
1306 return copied;
1307}
1308
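/* Apply the verdict for an skb: on redirect, queue it to the peer psock and
 * kick its tx_work, which either transmits it (egress) or moves it to the
 * peer's ingress list; otherwise free it.
 */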
1309static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1310{
1311 struct smap_psock *peer;
1312 struct sock *sk;
1313 __u32 in;
1314 int rc;
1315
1316 rc = smap_verdict_func(psock, skb);
1317 switch (rc) {
1318 case __SK_REDIRECT:
1319 sk = do_sk_redirect_map(skb);
1320 if (!sk) {
1321 kfree_skb(skb);
1322 break;
1323 }
1324
1325 peer = smap_psock_sk(sk);
1326 in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1327
1328 if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1329 !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1330 kfree_skb(skb);
1331 break;
1332 }
1333
1334 if (!in && sock_writeable(sk)) {
1335 skb_set_owner_w(skb, sk);
1336 skb_queue_tail(&peer->rxqueue, skb);
1337 schedule_work(&peer->tx_work);
1338 break;
1339 } else if (in &&
1340 atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1341 skb_queue_tail(&peer->rxqueue, skb);
1342 schedule_work(&peer->tx_work);
1343 break;
1344 }
 /* Fall through and free the skb otherwise. */
1346 case __SK_DROP:
1347 default:
1348 kfree_skb(skb);
1349 }
1350}
1351
1352static void smap_report_sk_error(struct smap_psock *psock, int err)
1353{
1354 struct sock *sk = psock->sock;
1355
1356 sk->sk_err = err;
1357 sk->sk_error_report(sk);
1358}
1359
1360static void smap_read_sock_strparser(struct strparser *strp,
1361 struct sk_buff *skb)
1362{
1363 struct smap_psock *psock;
1364
1365 rcu_read_lock();
1366 psock = container_of(strp, struct smap_psock, strp);
1367 smap_do_verdict(psock, skb);
1368 rcu_read_unlock();
1369}
1370
1371
1372static void smap_data_ready(struct sock *sk)
1373{
1374 struct smap_psock *psock;
1375
1376 rcu_read_lock();
1377 psock = smap_psock_sk(sk);
1378 if (likely(psock)) {
1379 write_lock_bh(&sk->sk_callback_lock);
1380 strp_data_ready(&psock->strp);
1381 write_unlock_bh(&sk->sk_callback_lock);
1382 }
1383 rcu_read_unlock();
1384}
1385
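/* Workqueue handler that drains the psock rx queue, either sending skbs out
 * the socket or moving them to the local ingress path; a partially sent skb
 * is saved and resumed on the next run.
 */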
1386static void smap_tx_work(struct work_struct *w)
1387{
1388 struct smap_psock *psock;
1389 struct sk_buff *skb;
1390 int rem, off, n;
1391
1392 psock = container_of(w, struct smap_psock, tx_work);

 /* Lock the socket so sk_socket cannot disappear while we iterate. */
1395 lock_sock(psock->sock);
1396 if (psock->save_skb) {
1397 skb = psock->save_skb;
1398 rem = psock->save_rem;
1399 off = psock->save_off;
1400 psock->save_skb = NULL;
1401 goto start;
1402 }
1403
1404 while ((skb = skb_dequeue(&psock->rxqueue))) {
1405 __u32 flags;
1406
1407 rem = skb->len;
1408 off = 0;
1409start:
1410 flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1411 do {
1412 if (likely(psock->sock->sk_socket)) {
1413 if (flags)
1414 n = smap_do_ingress(psock, skb);
1415 else
1416 n = skb_send_sock_locked(psock->sock,
1417 skb, off, rem);
1418 } else {
1419 n = -EINVAL;
1420 }
1421
1422 if (n <= 0) {
1423 if (n == -EAGAIN) {
 /* Retry when space is available. */
1425 psock->save_skb = skb;
1426 psock->save_rem = rem;
1427 psock->save_off = off;
1428 goto out;
1429 }
1430
1431 smap_report_sk_error(psock, n ? -n : EPIPE);
1432 clear_bit(SMAP_TX_RUNNING, &psock->state);
1433 kfree_skb(skb);
1434 goto out;
1435 }
1436 rem -= n;
1437 off += n;
1438 } while (rem);
1439
1440 if (!flags)
1441 kfree_skb(skb);
1442 }
1443out:
1444 release_sock(psock->sock);
1445}
1446
1447static void smap_write_space(struct sock *sk)
1448{
1449 struct smap_psock *psock;
1450 void (*write_space)(struct sock *sk);
1451
1452 rcu_read_lock();
1453 psock = smap_psock_sk(sk);
1454 if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1455 schedule_work(&psock->tx_work);
1456 write_space = psock->save_write_space;
1457 rcu_read_unlock();
1458 write_space(sk);
1459}
1460
1461static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1462{
1463 if (!psock->strp_enabled)
1464 return;
1465 sk->sk_data_ready = psock->save_data_ready;
1466 sk->sk_write_space = psock->save_write_space;
1467 psock->save_data_ready = NULL;
1468 psock->save_write_space = NULL;
1469 strp_stop(&psock->strp);
1470 psock->strp_enabled = false;
1471}
1472
1473static void smap_destroy_psock(struct rcu_head *rcu)
1474{
1475 struct smap_psock *psock = container_of(rcu,
1476 struct smap_psock, rcu);

 /* Now that an RCU grace period has passed, no datapath user can still
  * reach this psock. The actual teardown may sleep (strp_done, memory
  * uncharging, sock_put), so defer it to the gc workqueue.
  */
1484 schedule_work(&psock->gc_work);
1485}
1486
1487static bool psock_is_smap_sk(struct sock *sk)
1488{
1489 return inet_csk(sk)->icsk_ulp_ops == &bpf_tcp_ulp_ops;
1490}
1491
1492static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1493{
1494 if (refcount_dec_and_test(&psock->refcnt)) {
1495 if (psock_is_smap_sk(sock))
1496 tcp_cleanup_ulp(sock);
1497 write_lock_bh(&sock->sk_callback_lock);
1498 smap_stop_sock(psock, sock);
1499 write_unlock_bh(&sock->sk_callback_lock);
1500 clear_bit(SMAP_TX_RUNNING, &psock->state);
1501 rcu_assign_sk_user_data(sock, NULL);
1502 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1503 }
1504}
1505
1506static int smap_parse_func_strparser(struct strparser *strp,
1507 struct sk_buff *skb)
1508{
1509 struct smap_psock *psock;
1510 struct bpf_prog *prog;
1511 int rc;
1512
1513 rcu_read_lock();
1514 psock = container_of(strp, struct smap_psock, strp);
1515 prog = READ_ONCE(psock->bpf_parse);
1516
1517 if (unlikely(!prog)) {
1518 rcu_read_unlock();
1519 return skb->len;
1520 }

 /* Attach the socket so the BPF program can use it; this is safe because
  * the strparser hands us a clone that has already been orphaned. sk is
  * cleared again before returning so later skb/sk operations do not trip
  * over the borrowed pointer.
  */
1529 skb->sk = psock->sock;
1530 bpf_compute_data_end_sk_skb(skb);
1531 rc = (*prog->bpf_func)(skb, prog->insnsi);
1532 skb->sk = NULL;
1533 rcu_read_unlock();
1534 return rc;
1535}
1536
1537static int smap_read_sock_done(struct strparser *strp, int err)
1538{
1539 return err;
1540}
1541
1542static int smap_init_sock(struct smap_psock *psock,
1543 struct sock *sk)
1544{
1545 static const struct strp_callbacks cb = {
1546 .rcv_msg = smap_read_sock_strparser,
1547 .parse_msg = smap_parse_func_strparser,
1548 .read_sock_done = smap_read_sock_done,
1549 };
1550
1551 return strp_init(&psock->strp, sk, &cb);
1552}
1553
1554static void smap_init_progs(struct smap_psock *psock,
1555 struct bpf_prog *verdict,
1556 struct bpf_prog *parse)
1557{
1558 struct bpf_prog *orig_parse, *orig_verdict;
1559
1560 orig_parse = xchg(&psock->bpf_parse, parse);
1561 orig_verdict = xchg(&psock->bpf_verdict, verdict);
1562
1563 if (orig_verdict)
1564 bpf_prog_put(orig_verdict);
1565 if (orig_parse)
1566 bpf_prog_put(orig_parse);
1567}
1568
1569static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1570{
1571 if (sk->sk_data_ready == smap_data_ready)
1572 return;
1573 psock->save_data_ready = sk->sk_data_ready;
1574 psock->save_write_space = sk->sk_write_space;
1575 sk->sk_data_ready = smap_data_ready;
1576 sk->sk_write_space = smap_write_space;
1577 psock->strp_enabled = true;
1578}
1579
1580static void sock_map_remove_complete(struct bpf_stab *stab)
1581{
1582 bpf_map_area_free(stab->sock_map);
1583 kfree(stab);
1584}
1585
1586static void smap_gc_work(struct work_struct *w)
1587{
1588 struct smap_psock_map_entry *e, *tmp;
1589 struct sk_msg_buff *md, *mtmp;
1590 struct smap_psock *psock;
1591
1592 psock = container_of(w, struct smap_psock, gc_work);

 /* No callback lock is needed here: the sockmap callbacks were already
  * detached when the psock was released.
  */
1595 if (psock->strp_enabled)
1596 strp_done(&psock->strp);
1597
1598 cancel_work_sync(&psock->tx_work);
1599 __skb_queue_purge(&psock->rxqueue);

 /* Strparser and tx work are quiesced; drop the program references. */
1602 if (psock->bpf_parse)
1603 bpf_prog_put(psock->bpf_parse);
1604 if (psock->bpf_verdict)
1605 bpf_prog_put(psock->bpf_verdict);
1606 if (psock->bpf_tx_msg)
1607 bpf_prog_put(psock->bpf_tx_msg);
1608
1609 if (psock->cork) {
1610 free_start_sg(psock->sock, psock->cork, true);
1611 kfree(psock->cork);
1612 }
1613
1614 list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1615 list_del(&md->list);
1616 free_start_sg(psock->sock, md, true);
1617 kfree(md);
1618 }
1619
1620 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1621 list_del(&e->list);
1622 kfree(e);
1623 }
1624
1625 if (psock->sk_redir)
1626 sock_put(psock->sk_redir);
1627
1628 sock_put(psock->sock);
1629 kfree(psock);
1630}
1631
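/* Allocate a psock for @sock, attach it via sk_user_data and take a
 * reference on the socket.
 */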
1632static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1633{
1634 struct smap_psock *psock;
1635
1636 psock = kzalloc_node(sizeof(struct smap_psock),
1637 GFP_ATOMIC | __GFP_NOWARN,
1638 node);
1639 if (!psock)
1640 return ERR_PTR(-ENOMEM);
1641
1642 psock->eval = __SK_NONE;
1643 psock->sock = sock;
1644 skb_queue_head_init(&psock->rxqueue);
1645 INIT_WORK(&psock->tx_work, smap_tx_work);
1646 INIT_WORK(&psock->gc_work, smap_gc_work);
1647 INIT_LIST_HEAD(&psock->maps);
1648 INIT_LIST_HEAD(&psock->ingress);
1649 refcount_set(&psock->refcnt, 1);
1650 spin_lock_init(&psock->maps_lock);
1651
1652 rcu_assign_sk_user_data(sock, psock);
1653 sock_hold(sock);
1654 return psock;
1655}
1656
1657static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1658{
1659 struct bpf_stab *stab;
1660 u64 cost;
1661 int err;
1662
1663 if (!capable(CAP_NET_ADMIN))
1664 return ERR_PTR(-EPERM);

 /* check sanity of attributes */
1667 if (attr->max_entries == 0 || attr->key_size != 4 ||
1668 attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1669 return ERR_PTR(-EINVAL);
1670
1671 err = bpf_tcp_ulp_register();
1672 if (err && err != -EEXIST)
1673 return ERR_PTR(err);
1674
1675 stab = kzalloc(sizeof(*stab), GFP_USER);
1676 if (!stab)
1677 return ERR_PTR(-ENOMEM);
1678
1679 bpf_map_init_from_attr(&stab->map, attr);
1680 raw_spin_lock_init(&stab->lock);

 /* make sure page count doesn't overflow */
1683 cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1684 err = -EINVAL;
1685 if (cost >= U32_MAX - PAGE_SIZE)
1686 goto free_stab;
1687
1688 stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;

 /* if map size is bigger than memlock limit, reject it early */
1691 err = bpf_map_precharge_memlock(stab->map.pages);
1692 if (err)
1693 goto free_stab;
1694
1695 err = -ENOMEM;
1696 stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1697 sizeof(struct sock *),
1698 stab->map.numa_node);
1699 if (!stab->sock_map)
1700 goto free_stab;
1701
1702 return &stab->map;
1703free_stab:
1704 kfree(stab);
1705 return ERR_PTR(err);
1706}
1707
1708static void smap_list_map_remove(struct smap_psock *psock,
1709 struct sock **entry)
1710{
1711 struct smap_psock_map_entry *e, *tmp;
1712
1713 spin_lock_bh(&psock->maps_lock);
1714 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1715 if (e->entry == entry) {
1716 list_del(&e->list);
1717 kfree(e);
1718 }
1719 }
1720 spin_unlock_bh(&psock->maps_lock);
1721}
1722
1723static void smap_list_hash_remove(struct smap_psock *psock,
1724 struct htab_elem *hash_link)
1725{
1726 struct smap_psock_map_entry *e, *tmp;
1727
1728 spin_lock_bh(&psock->maps_lock);
1729 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1730 struct htab_elem *c = rcu_dereference(e->hash_link);
1731
1732 if (c == hash_link) {
1733 list_del(&e->list);
1734 kfree(e);
1735 }
1736 }
1737 spin_unlock_bh(&psock->maps_lock);
1738}
1739
1740static void sock_map_free(struct bpf_map *map)
1741{
1742 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1743 int i;
1744
1745 synchronize_rcu();

 /* At this point no update, lookup or delete operations can happen.
  * However, socket events and data-ready callbacks that reference the
  * psock via sk_user_data may still be in flight, as may the psock work
  * items, so smap_release_sock() only frees the psock after the workers
  * are cancelled and a grace period has elapsed.
  */
1754 rcu_read_lock();
1755 raw_spin_lock_bh(&stab->lock);
1756 for (i = 0; i < stab->map.max_entries; i++) {
1757 struct smap_psock *psock;
1758 struct sock *sock;
1759
1760 sock = stab->sock_map[i];
1761 if (!sock)
1762 continue;
1763 stab->sock_map[i] = NULL;
1764 psock = smap_psock_sk(sock);

 /* This check handles a racing sock event that can get the
  * sk_callback_lock before this case but after the xchg happens,
  * causing the refcnt to hit zero and sk_user_data (psock) to be NULL
  * and queued for garbage collection.
  */
1770 if (likely(psock)) {
1771 smap_list_map_remove(psock, &stab->sock_map[i]);
1772 smap_release_sock(psock, sock);
1773 }
1774 }
1775 raw_spin_unlock_bh(&stab->lock);
1776 rcu_read_unlock();
1777
1778 sock_map_remove_complete(stab);
1779}
1780
1781static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1782{
1783 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1784 u32 i = key ? *(u32 *)key : U32_MAX;
1785 u32 *next = (u32 *)next_key;
1786
1787 if (i >= stab->map.max_entries) {
1788 *next = 0;
1789 return 0;
1790 }
1791
1792 if (i == stab->map.max_entries - 1)
1793 return -ENOENT;
1794
1795 *next = i + 1;
1796 return 0;
1797}
1798
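/* Datapath lookup used by the BPF sk/msg redirect helpers. */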
1799struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1800{
1801 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1802
1803 if (key >= map->max_entries)
1804 return NULL;
1805
1806 return READ_ONCE(stab->sock_map[key]);
1807}
1808
1809static int sock_map_delete_elem(struct bpf_map *map, void *key)
1810{
1811 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1812 struct smap_psock *psock;
1813 int k = *(u32 *)key;
1814 struct sock *sock;
1815
1816 if (k >= map->max_entries)
1817 return -EINVAL;
1818
1819 raw_spin_lock_bh(&stab->lock);
1820 sock = stab->sock_map[k];
1821 stab->sock_map[k] = NULL;
1822 raw_spin_unlock_bh(&stab->lock);
1823 if (!sock)
1824 return -EINVAL;
1825
1826 psock = smap_psock_sk(sock);
1827 if (!psock)
1828 return 0;
1829 if (psock->bpf_parse) {
1830 write_lock_bh(&sock->sk_callback_lock);
1831 smap_stop_sock(psock, sock);
1832 write_unlock_bh(&sock->sk_callback_lock);
1833 }
1834 smap_list_map_remove(psock, &stab->sock_map[k]);
1835 smap_release_sock(psock, sock);
1836 return 0;
1837}

/* Shared update path for sockmap and sockhash. Adding a socket to a map
 * attaches the map's BPF programs to it:
 *
 *  - a parse + verdict program pair enables the strparser receive path,
 *    where the verdict decides to pass, drop, or redirect each message;
 *  - a tx_msg program enables the SK_MSG sendmsg/sendpage hooks.
 *
 * A reference is taken on each program per socket. A socket that already
 * carries a conflicting program of the same type cannot be added (-EBUSY).
 * The psock is created on first use and reference counted; it is torn down
 * when the last map entry pointing at the socket is removed or the socket
 * is unhashed/closed.
 */
1868static int __sock_map_ctx_update_elem(struct bpf_map *map,
1869 struct bpf_sock_progs *progs,
1870 struct sock *sock,
1871 void *key)
1872{
1873 struct bpf_prog *verdict, *parse, *tx_msg;
1874 struct smap_psock *psock;
1875 bool new = false;
1876 int err = 0;

 /* 1. If the map has BPF programs they are inherited by the socket being
  * added. If the socket already has conflicting programs attached, the
  * update fails (checked under 2. below).
  */
1882 verdict = READ_ONCE(progs->bpf_verdict);
1883 parse = READ_ONCE(progs->bpf_parse);
1884 tx_msg = READ_ONCE(progs->bpf_tx_msg);
1885
1886 if (parse && verdict) {
 /* The prog refcnt may already be zero if a concurrent detach drops the
  * program after the READ_ONCE() above but before we take our
  * reference; abort with an error in that case.
  */
1892 verdict = bpf_prog_inc_not_zero(verdict);
1893 if (IS_ERR(verdict))
1894 return PTR_ERR(verdict);
1895
1896 parse = bpf_prog_inc_not_zero(parse);
1897 if (IS_ERR(parse)) {
1898 bpf_prog_put(verdict);
1899 return PTR_ERR(parse);
1900 }
1901 }
1902
1903 if (tx_msg) {
1904 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1905 if (IS_ERR(tx_msg)) {
1906 if (parse && verdict) {
1907 bpf_prog_put(parse);
1908 bpf_prog_put(verdict);
1909 }
1910 return PTR_ERR(tx_msg);
1911 }
1912 }
1913
1914 psock = smap_psock_sk(sock);

 /* 2. A psock may already exist if the socket is present in another map.
  * Verify it really is a sockmap-managed socket, reject program types
  * that are already attached, and take a reference so the psock cannot
  * vanish while the update completes.
  */
1922 if (psock) {
1923 if (!psock_is_smap_sk(sock)) {
1924 err = -EBUSY;
1925 goto out_progs;
1926 }
1927 if (READ_ONCE(psock->bpf_parse) && parse) {
1928 err = -EBUSY;
1929 goto out_progs;
1930 }
1931 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1932 err = -EBUSY;
1933 goto out_progs;
1934 }
1935 if (!refcount_inc_not_zero(&psock->refcnt)) {
1936 err = -EAGAIN;
1937 goto out_progs;
1938 }
1939 } else {
1940 psock = smap_init_psock(sock, map->numa_node);
1941 if (IS_ERR(psock)) {
1942 err = PTR_ERR(psock);
1943 goto out_progs;
1944 }
1945
1946 set_bit(SMAP_TX_RUNNING, &psock->state);
1947 new = true;
1948 }

 /* 3. At this point we hold a reference to a valid, running psock;
  * attach any BPF programs needed.
  */
1953 if (tx_msg)
1954 bpf_tcp_msg_add(psock, sock, tx_msg);
1955 if (new) {
1956 err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1957 if (err)
1958 goto out_free;
1959 }
1960
1961 if (parse && verdict && !psock->strp_enabled) {
1962 err = smap_init_sock(psock, sock);
1963 if (err)
1964 goto out_free;
1965 smap_init_progs(psock, verdict, parse);
1966 write_lock_bh(&sock->sk_callback_lock);
1967 smap_start_sock(psock, sock);
1968 write_unlock_bh(&sock->sk_callback_lock);
1969 }
1970
1971 return err;
1972out_free:
1973 smap_release_sock(psock, sock);
1974out_progs:
1975 if (parse && verdict) {
1976 bpf_prog_put(parse);
1977 bpf_prog_put(verdict);
1978 }
1979 if (tx_msg)
1980 bpf_prog_put(tx_msg);
1981 return err;
1982}
1983
1984static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1985 struct bpf_map *map,
1986 void *key, u64 flags)
1987{
1988 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1989 struct bpf_sock_progs *progs = &stab->progs;
1990 struct sock *osock, *sock = skops->sk;
1991 struct smap_psock_map_entry *e;
1992 struct smap_psock *psock;
1993 u32 i = *(u32 *)key;
1994 int err;
1995
1996 if (unlikely(flags > BPF_EXIST))
1997 return -EINVAL;
1998 if (unlikely(i >= stab->map.max_entries))
1999 return -E2BIG;
2000
2001 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2002 if (!e)
2003 return -ENOMEM;
2004
2005 err = __sock_map_ctx_update_elem(map, progs, sock, key);
2006 if (err)
2007 goto out;

 /* psock is guaranteed to be present after a successful ctx update */
2010 psock = smap_psock_sk(sock);
2011 raw_spin_lock_bh(&stab->lock);
2012 osock = stab->sock_map[i];
2013 if (osock && flags == BPF_NOEXIST) {
2014 err = -EEXIST;
2015 goto out_unlock;
2016 }
2017 if (!osock && flags == BPF_EXIST) {
2018 err = -ENOENT;
2019 goto out_unlock;
2020 }
2021
2022 e->entry = &stab->sock_map[i];
2023 e->map = map;
2024 spin_lock_bh(&psock->maps_lock);
2025 list_add_tail(&e->list, &psock->maps);
2026 spin_unlock_bh(&psock->maps_lock);
2027
2028 stab->sock_map[i] = sock;
2029 if (osock) {
2030 psock = smap_psock_sk(osock);
2031 smap_list_map_remove(psock, &stab->sock_map[i]);
2032 smap_release_sock(psock, osock);
2033 }
2034 raw_spin_unlock_bh(&stab->lock);
2035 return 0;
2036out_unlock:
2037 smap_release_sock(psock, sock);
2038 raw_spin_unlock_bh(&stab->lock);
2039out:
2040 kfree(e);
2041 return err;
2042}
2043
2044int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
2045{
2046 struct bpf_sock_progs *progs;
2047 struct bpf_prog *orig;
2048
2049 if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2050 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2051
2052 progs = &stab->progs;
2053 } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
2054 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2055
2056 progs = &htab->progs;
2057 } else {
2058 return -EINVAL;
2059 }
2060
2061 switch (type) {
2062 case BPF_SK_MSG_VERDICT:
2063 orig = xchg(&progs->bpf_tx_msg, prog);
2064 break;
2065 case BPF_SK_SKB_STREAM_PARSER:
2066 orig = xchg(&progs->bpf_parse, prog);
2067 break;
2068 case BPF_SK_SKB_STREAM_VERDICT:
2069 orig = xchg(&progs->bpf_verdict, prog);
2070 break;
2071 default:
2072 return -EOPNOTSUPP;
2073 }
2074
2075 if (orig)
2076 bpf_prog_put(orig);
2077
2078 return 0;
2079}
2080
2081int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2082 struct bpf_prog *prog)
2083{
2084 int ufd = attr->target_fd;
2085 struct bpf_map *map;
2086 struct fd f;
2087 int err;
2088
2089 f = fdget(ufd);
2090 map = __bpf_map_get(f);
2091 if (IS_ERR(map))
2092 return PTR_ERR(map);
2093
2094 err = sock_map_prog(map, prog, attr->attach_type);
2095 fdput(f);
2096 return err;
2097}
2098
2099static void *sock_map_lookup(struct bpf_map *map, void *key)
2100{
2101 return NULL;
2102}
2103
2104static int sock_map_update_elem(struct bpf_map *map,
2105 void *key, void *value, u64 flags)
2106{
2107 struct bpf_sock_ops_kern skops;
2108 u32 fd = *(u32 *)value;
2109 struct socket *socket;
2110 int err;
2111
2112 socket = sockfd_lookup(fd, &err);
2113 if (!socket)
2114 return err;
2115
2116 skops.sk = socket->sk;
2117 if (!skops.sk) {
2118 fput(socket->file);
2119 return -EINVAL;
2120 }
2121

 /* ULPs are currently supported only for TCP sockets in ESTABLISHED
  * state.
  */
2125 if (skops.sk->sk_type != SOCK_STREAM ||
2126 skops.sk->sk_protocol != IPPROTO_TCP ||
2127 skops.sk->sk_state != TCP_ESTABLISHED) {
2128 fput(socket->file);
2129 return -EOPNOTSUPP;
2130 }
2131
2132 lock_sock(skops.sk);
2133 preempt_disable();
2134 rcu_read_lock();
2135 err = sock_map_ctx_update_elem(&skops, map, key, flags);
2136 rcu_read_unlock();
2137 preempt_enable();
2138 release_sock(skops.sk);
2139 fput(socket->file);
2140 return err;
2141}
2142
2143static void sock_map_release(struct bpf_map *map)
2144{
2145 struct bpf_sock_progs *progs;
2146 struct bpf_prog *orig;
2147
2148 if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2149 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2150
2151 progs = &stab->progs;
2152 } else {
2153 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2154
2155 progs = &htab->progs;
2156 }
2157
2158 orig = xchg(&progs->bpf_parse, NULL);
2159 if (orig)
2160 bpf_prog_put(orig);
2161 orig = xchg(&progs->bpf_verdict, NULL);
2162 if (orig)
2163 bpf_prog_put(orig);
2164
2165 orig = xchg(&progs->bpf_tx_msg, NULL);
2166 if (orig)
2167 bpf_prog_put(orig);
2168}
2169
2170static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2171{
2172 struct bpf_htab *htab;
2173 int i, err;
2174 u64 cost;
2175
2176 if (!capable(CAP_NET_ADMIN))
2177 return ERR_PTR(-EPERM);

 /* check sanity of attributes */
2180 if (attr->max_entries == 0 ||
2181 attr->key_size == 0 ||
2182 attr->value_size != 4 ||
2183 attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2184 return ERR_PTR(-EINVAL);
2185
2186 if (attr->key_size > MAX_BPF_STACK)
 /* eBPF programs initialize keys on stack, so they cannot be larger
  * than max stack size
  */
2190 return ERR_PTR(-E2BIG);
2191
2192 err = bpf_tcp_ulp_register();
2193 if (err && err != -EEXIST)
2194 return ERR_PTR(err);
2195
2196 htab = kzalloc(sizeof(*htab), GFP_USER);
2197 if (!htab)
2198 return ERR_PTR(-ENOMEM);
2199
2200 bpf_map_init_from_attr(&htab->map, attr);
2201
2202 htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2203 htab->elem_size = sizeof(struct htab_elem) +
2204 round_up(htab->map.key_size, 8);
2205 err = -EINVAL;
2206 if (htab->n_buckets == 0 ||
2207 htab->n_buckets > U32_MAX / sizeof(struct bucket))
2208 goto free_htab;
2209
2210 cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2211 (u64) htab->elem_size * htab->map.max_entries;
2212
2213 if (cost >= U32_MAX - PAGE_SIZE)
2214 goto free_htab;
2215
2216 htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2217 err = bpf_map_precharge_memlock(htab->map.pages);
2218 if (err)
2219 goto free_htab;
2220
2221 err = -ENOMEM;
2222 htab->buckets = bpf_map_area_alloc(
2223 htab->n_buckets * sizeof(struct bucket),
2224 htab->map.numa_node);
2225 if (!htab->buckets)
2226 goto free_htab;
2227
2228 for (i = 0; i < htab->n_buckets; i++) {
2229 INIT_HLIST_HEAD(&htab->buckets[i].head);
2230 raw_spin_lock_init(&htab->buckets[i].lock);
2231 }
2232
2233 return &htab->map;
2234free_htab:
2235 kfree(htab);
2236 return ERR_PTR(err);
2237}
2238
2239static void __bpf_htab_free(struct rcu_head *rcu)
2240{
2241 struct bpf_htab *htab;
2242
2243 htab = container_of(rcu, struct bpf_htab, rcu);
2244 bpf_map_area_free(htab->buckets);
2245 kfree(htab);
2246}
2247
2248static void sock_hash_free(struct bpf_map *map)
2249{
2250 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2251 int i;
2252
2253 synchronize_rcu();

 /* At this point no update, lookup or delete operations can happen.
  * However, socket events and data-ready callbacks that reference the
  * psock via sk_user_data may still be in flight, as may the psock work
  * items, so smap_release_sock() only frees the psock after the workers
  * are cancelled and a grace period has elapsed.
  */
2262 rcu_read_lock();
2263 for (i = 0; i < htab->n_buckets; i++) {
2264 struct bucket *b = __select_bucket(htab, i);
2265 struct hlist_head *head;
2266 struct hlist_node *n;
2267 struct htab_elem *l;
2268
2269 raw_spin_lock_bh(&b->lock);
2270 head = &b->head;
2271 hlist_for_each_entry_safe(l, n, head, hash_node) {
2272 struct sock *sock = l->sk;
2273 struct smap_psock *psock;
2274
2275 hlist_del_rcu(&l->hash_node);
2276 psock = smap_psock_sk(sock);

 /* This check handles a racing sock event that can get the
  * sk_callback_lock before this case but after the xchg happens,
  * causing the refcnt to hit zero and sk_user_data (psock) to be NULL
  * and queued for garbage collection.
  */
2282 if (likely(psock)) {
2283 smap_list_hash_remove(psock, l);
2284 smap_release_sock(psock, sock);
2285 }
2286 free_htab_elem(htab, l);
2287 }
2288 raw_spin_unlock_bh(&b->lock);
2289 }
2290 rcu_read_unlock();
2291 call_rcu(&htab->rcu, __bpf_htab_free);
2292}
2293
2294static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2295 void *key, u32 key_size, u32 hash,
2296 struct sock *sk,
2297 struct htab_elem *old_elem)
2298{
2299 struct htab_elem *l_new;
2300
2301 if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2302 if (!old_elem) {
2303 atomic_dec(&htab->count);
2304 return ERR_PTR(-E2BIG);
2305 }
2306 }
2307 l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2308 htab->map.numa_node);
2309 if (!l_new) {
2310 atomic_dec(&htab->count);
2311 return ERR_PTR(-ENOMEM);
2312 }
2313
2314 memcpy(l_new->key, key, key_size);
2315 l_new->sk = sk;
2316 l_new->hash = hash;
2317 return l_new;
2318}
2319
2320static inline u32 htab_map_hash(const void *key, u32 key_len)
2321{
2322 return jhash(key, key_len, 0);
2323}
2324
2325static int sock_hash_get_next_key(struct bpf_map *map,
2326 void *key, void *next_key)
2327{
2328 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2329 struct htab_elem *l, *next_l;
2330 struct hlist_head *h;
2331 u32 hash, key_size;
2332 int i = 0;
2333
2334 WARN_ON_ONCE(!rcu_read_lock_held());
2335
2336 key_size = map->key_size;
2337 if (!key)
2338 goto find_first_elem;
2339 hash = htab_map_hash(key, key_size);
2340 h = select_bucket(htab, hash);
2341
2342 l = lookup_elem_raw(h, hash, key, key_size);
2343 if (!l)
2344 goto find_first_elem;
2345 next_l = hlist_entry_safe(
2346 rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2347 struct htab_elem, hash_node);
2348 if (next_l) {
2349 memcpy(next_key, next_l->key, key_size);
2350 return 0;
2351 }

 /* No more elements in this hash chain, go to the next bucket. */
2354 i = hash & (htab->n_buckets - 1);
2355 i++;
2356
2357find_first_elem:
 /* iterate over buckets */
2359 for (; i < htab->n_buckets; i++) {
2360 h = select_bucket(htab, i);

 /* pick the first element in the bucket */
2363 next_l = hlist_entry_safe(
2364 rcu_dereference_raw(hlist_first_rcu(h)),
2365 struct htab_elem, hash_node);
2366 if (next_l) {
 /* if the bucket is not empty, just return its first key */
2368 memcpy(next_key, next_l->key, key_size);
2369 return 0;
2370 }
2371 }

 /* iterated over all buckets and all elements */
2374 return -ENOENT;
2375}
2376
2377static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2378 struct bpf_map *map,
2379 void *key, u64 map_flags)
2380{
2381 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2382 struct bpf_sock_progs *progs = &htab->progs;
2383 struct htab_elem *l_new = NULL, *l_old;
2384 struct smap_psock_map_entry *e = NULL;
2385 struct hlist_head *head;
2386 struct smap_psock *psock;
2387 u32 key_size, hash;
2388 struct sock *sock;
2389 struct bucket *b;
2390 int err;
2391
2392 sock = skops->sk;
2393
2394 if (sock->sk_type != SOCK_STREAM ||
2395 sock->sk_protocol != IPPROTO_TCP)
2396 return -EOPNOTSUPP;
2397
2398 if (unlikely(map_flags > BPF_EXIST))
2399 return -EINVAL;
2400
2401 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2402 if (!e)
2403 return -ENOMEM;
2404
2405 WARN_ON_ONCE(!rcu_read_lock_held());
2406 key_size = map->key_size;
2407 hash = htab_map_hash(key, key_size);
2408 b = __select_bucket(htab, hash);
2409 head = &b->head;
2410
2411 err = __sock_map_ctx_update_elem(map, progs, sock, key);
2412 if (err)
2413 goto err;

 /* The psock is guaranteed valid here because the ctx update above
  * succeeded and we still hold the reference it took.
  */
2418 psock = smap_psock_sk(sock);
2419 raw_spin_lock_bh(&b->lock);
2420 l_old = lookup_elem_raw(head, hash, key, key_size);
2421 if (l_old && map_flags == BPF_NOEXIST) {
2422 err = -EEXIST;
2423 goto bucket_err;
2424 }
2425 if (!l_old && map_flags == BPF_EXIST) {
2426 err = -ENOENT;
2427 goto bucket_err;
2428 }
2429
2430 l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2431 if (IS_ERR(l_new)) {
2432 err = PTR_ERR(l_new);
2433 goto bucket_err;
2434 }
2435
2436 rcu_assign_pointer(e->hash_link, l_new);
2437 e->map = map;
2438 spin_lock_bh(&psock->maps_lock);
2439 list_add_tail(&e->list, &psock->maps);
2440 spin_unlock_bh(&psock->maps_lock);

 /* Add the new element to the head of the list so a concurrent lookup
  * finds it before the old element.
  */
2445 hlist_add_head_rcu(&l_new->hash_node, head);
2446 if (l_old) {
2447 psock = smap_psock_sk(l_old->sk);
2448
2449 hlist_del_rcu(&l_old->hash_node);
2450 smap_list_hash_remove(psock, l_old);
2451 smap_release_sock(psock, l_old->sk);
2452 free_htab_elem(htab, l_old);
2453 }
2454 raw_spin_unlock_bh(&b->lock);
2455 return 0;
2456bucket_err:
2457 smap_release_sock(psock, sock);
2458 raw_spin_unlock_bh(&b->lock);
2459err:
2460 kfree(e);
2461 return err;
2462}
2463
2464static int sock_hash_update_elem(struct bpf_map *map,
2465 void *key, void *value, u64 flags)
2466{
2467 struct bpf_sock_ops_kern skops;
2468 u32 fd = *(u32 *)value;
2469 struct socket *socket;
2470 int err;
2471
2472 socket = sockfd_lookup(fd, &err);
2473 if (!socket)
2474 return err;
2475
2476 skops.sk = socket->sk;
2477 if (!skops.sk) {
2478 fput(socket->file);
2479 return -EINVAL;
2480 }

 /* ULPs are currently supported only for TCP sockets in ESTABLISHED
  * state.
  */
2485 if (skops.sk->sk_type != SOCK_STREAM ||
2486 skops.sk->sk_protocol != IPPROTO_TCP ||
2487 skops.sk->sk_state != TCP_ESTABLISHED) {
2488 fput(socket->file);
2489 return -EOPNOTSUPP;
2490 }
2491
2492 lock_sock(skops.sk);
2493 preempt_disable();
2494 rcu_read_lock();
2495 err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2496 rcu_read_unlock();
2497 preempt_enable();
2498 release_sock(skops.sk);
2499 fput(socket->file);
2500 return err;
2501}
2502
2503static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2504{
2505 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2506 struct hlist_head *head;
2507 struct bucket *b;
2508 struct htab_elem *l;
2509 u32 hash, key_size;
2510 int ret = -ENOENT;
2511
2512 key_size = map->key_size;
2513 hash = htab_map_hash(key, key_size);
2514 b = __select_bucket(htab, hash);
2515 head = &b->head;
2516
2517 raw_spin_lock_bh(&b->lock);
2518 l = lookup_elem_raw(head, hash, key, key_size);
2519 if (l) {
2520 struct sock *sock = l->sk;
2521 struct smap_psock *psock;
2522
2523 hlist_del_rcu(&l->hash_node);
2524 psock = smap_psock_sk(sock);

 /* This check handles a racing sock event that can get the
  * sk_callback_lock before this case but after the xchg happens,
  * causing the refcnt to hit zero and sk_user_data (psock) to be NULL
  * and queued for garbage collection.
  */
2530 if (likely(psock)) {
2531 smap_list_hash_remove(psock, l);
2532 smap_release_sock(psock, sock);
2533 }
2534 free_htab_elem(htab, l);
2535 ret = 0;
2536 }
2537 raw_spin_unlock_bh(&b->lock);
2538 return ret;
2539}
2540
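/* Datapath lookup used by the BPF sk/msg redirect helpers for sockhash. */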
2541struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2542{
2543 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2544 struct hlist_head *head;
2545 struct htab_elem *l;
2546 u32 key_size, hash;
2547 struct bucket *b;
2548 struct sock *sk;
2549
2550 key_size = map->key_size;
2551 hash = htab_map_hash(key, key_size);
2552 b = __select_bucket(htab, hash);
2553 head = &b->head;
2554
2555 l = lookup_elem_raw(head, hash, key, key_size);
2556 sk = l ? l->sk : NULL;
2557 return sk;
2558}
2559
2560const struct bpf_map_ops sock_map_ops = {
2561 .map_alloc = sock_map_alloc,
2562 .map_free = sock_map_free,
2563 .map_lookup_elem = sock_map_lookup,
2564 .map_get_next_key = sock_map_get_next_key,
2565 .map_update_elem = sock_map_update_elem,
2566 .map_delete_elem = sock_map_delete_elem,
2567 .map_release_uref = sock_map_release,
2568 .map_check_btf = map_check_no_btf,
2569};
2570
2571const struct bpf_map_ops sock_hash_ops = {
2572 .map_alloc = sock_hash_alloc,
2573 .map_free = sock_hash_free,
2574 .map_lookup_elem = sock_map_lookup,
2575 .map_get_next_key = sock_hash_get_next_key,
2576 .map_update_elem = sock_hash_update_elem,
2577 .map_delete_elem = sock_hash_delete_elem,
2578 .map_release_uref = sock_map_release,
2579 .map_check_btf = map_check_no_btf,
2580};
2581
2582static bool bpf_is_valid_sock_op(struct bpf_sock_ops_kern *ops)
2583{
2584 return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
2585 ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
2586}
2587BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2588 struct bpf_map *, map, void *, key, u64, flags)
2589{
2590 WARN_ON_ONCE(!rcu_read_lock_held());

 /* The helper may only run from sock_ops callbacks that indicate the
  * socket is (or is about to be) ESTABLISHED, since the ULP only
  * supports established TCP sockets.
  */
2596 if (!bpf_is_valid_sock_op(bpf_sock))
2597 return -EOPNOTSUPP;
2598 return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2599}
2600
2601const struct bpf_func_proto bpf_sock_map_update_proto = {
2602 .func = bpf_sock_map_update,
2603 .gpl_only = false,
2604 .pkt_access = true,
2605 .ret_type = RET_INTEGER,
2606 .arg1_type = ARG_PTR_TO_CTX,
2607 .arg2_type = ARG_CONST_MAP_PTR,
2608 .arg3_type = ARG_PTR_TO_MAP_KEY,
2609 .arg4_type = ARG_ANYTHING,
2610};
2611
2612BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2613 struct bpf_map *, map, void *, key, u64, flags)
2614{
2615 WARN_ON_ONCE(!rcu_read_lock_held());
2616
2617 if (!bpf_is_valid_sock_op(bpf_sock))
2618 return -EOPNOTSUPP;
2619 return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2620}
2621
2622const struct bpf_func_proto bpf_sock_hash_update_proto = {
2623 .func = bpf_sock_hash_update,
2624 .gpl_only = false,
2625 .pkt_access = true,
2626 .ret_type = RET_INTEGER,
2627 .arg1_type = ARG_PTR_TO_CTX,
2628 .arg2_type = ARG_CONST_MAP_PTR,
2629 .arg3_type = ARG_PTR_TO_MAP_KEY,
2630 .arg4_type = ARG_ANYTHING,
2631};
2632