linux/net/ipv4/udp_bpf.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
   3
   4#include <linux/skmsg.h>
   5#include <net/sock.h>
   6#include <net/udp.h>
   7#include <net/inet_common.h>
   8
   9#include "udp_impl.h"
  10
  11static struct proto *udpv6_prot_saved __read_mostly;
  12
  13static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
  14                          int noblock, int flags, int *addr_len)
  15{
  16#if IS_ENABLED(CONFIG_IPV6)
  17        if (sk->sk_family == AF_INET6)
  18                return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags,
  19                                                 addr_len);
  20#endif
  21        return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
  22}
  23
  24static bool udp_sk_has_data(struct sock *sk)
  25{
  26        return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
  27               !skb_queue_empty(&sk->sk_receive_queue);
  28}
  29
  30static bool psock_has_data(struct sk_psock *psock)
  31{
  32        return !skb_queue_empty(&psock->ingress_skb) ||
  33               !sk_psock_queue_empty(psock);
  34}
  35
  36#define udp_msg_has_data(__sk, __psock) \
  37                ({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
  38
  39static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
  40                             long timeo, int *err)
  41{
  42        DEFINE_WAIT_FUNC(wait, woken_wake_function);
  43        int ret = 0;
  44
  45        if (sk->sk_shutdown & RCV_SHUTDOWN)
  46                return 1;
  47
  48        if (!timeo)
  49                return ret;
  50
  51        add_wait_queue(sk_sleep(sk), &wait);
  52        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  53        ret = udp_msg_has_data(sk, psock);
  54        if (!ret) {
  55                wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
  56                ret = udp_msg_has_data(sk, psock);
  57        }
  58        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  59        remove_wait_queue(sk_sleep(sk), &wait);
  60        return ret;
  61}
  62
  63static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
  64                           int nonblock, int flags, int *addr_len)
  65{
  66        struct sk_psock *psock;
  67        int copied, ret;
  68
  69        if (unlikely(flags & MSG_ERRQUEUE))
  70                return inet_recv_error(sk, msg, len, addr_len);
  71
  72        psock = sk_psock_get(sk);
  73        if (unlikely(!psock))
  74                return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
  75
  76        if (!psock_has_data(psock)) {
  77                ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
  78                goto out;
  79        }
  80
  81msg_bytes_ready:
  82        copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
  83        if (!copied) {
  84                int data, err = 0;
  85                long timeo;
  86
  87                timeo = sock_rcvtimeo(sk, nonblock);
  88                data = udp_msg_wait_data(sk, psock, flags, timeo, &err);
  89                if (data) {
  90                        if (psock_has_data(psock))
  91                                goto msg_bytes_ready;
  92                        ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
  93                        goto out;
  94                }
  95                if (err) {
  96                        ret = err;
  97                        goto out;
  98                }
  99                copied = -EAGAIN;
 100        }
 101        ret = copied;
 102out:
 103        sk_psock_put(sk, psock);
 104        return ret;
 105}
 106
 107enum {
 108        UDP_BPF_IPV4,
 109        UDP_BPF_IPV6,
 110        UDP_BPF_NUM_PROTS,
 111};
 112
 113static DEFINE_SPINLOCK(udpv6_prot_lock);
 114static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
 115
 116static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
 117{
 118        *prot        = *base;
 119        prot->unhash = sock_map_unhash;
 120        prot->close  = sock_map_close;
 121        prot->recvmsg = udp_bpf_recvmsg;
 122}
 123
 124static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
 125{
 126        if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
 127                spin_lock_bh(&udpv6_prot_lock);
 128                if (likely(ops != udpv6_prot_saved)) {
 129                        udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
 130                        smp_store_release(&udpv6_prot_saved, ops);
 131                }
 132                spin_unlock_bh(&udpv6_prot_lock);
 133        }
 134}
 135
 136static int __init udp_bpf_v4_build_proto(void)
 137{
 138        udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
 139        return 0;
 140}
 141late_initcall(udp_bpf_v4_build_proto);
 142
 143int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 144{
 145        int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
 146
 147        if (restore) {
 148                sk->sk_write_space = psock->saved_write_space;
 149                WRITE_ONCE(sk->sk_prot, psock->sk_proto);
 150                return 0;
 151        }
 152
 153        if (sk->sk_family == AF_INET6)
 154                udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 155
 156        WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
 157        return 0;
 158}
 159EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
 160