linux/net/ipv4/udp_bpf.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
   3
   4#include <linux/skmsg.h>
   5#include <net/sock.h>
   6#include <net/udp.h>
   7#include <net/inet_common.h>
   8
   9#include "udp_impl.h"
  10
  11static struct proto *udpv6_prot_saved __read_mostly;
  12
  13static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
  14                          int noblock, int flags, int *addr_len)
  15{
  16#if IS_ENABLED(CONFIG_IPV6)
  17        if (sk->sk_family == AF_INET6)
  18                return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags,
  19                                                 addr_len);
  20#endif
  21        return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
  22}
  23
  24static bool udp_sk_has_data(struct sock *sk)
  25{
  26        return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
  27               !skb_queue_empty(&sk->sk_receive_queue);
  28}
  29
  30static bool psock_has_data(struct sk_psock *psock)
  31{
  32        return !skb_queue_empty(&psock->ingress_skb) ||
  33               !sk_psock_queue_empty(psock);
  34}
  35
  36#define udp_msg_has_data(__sk, __psock) \
  37                ({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
  38
  39static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
  40                             long timeo)
  41{
  42        DEFINE_WAIT_FUNC(wait, woken_wake_function);
  43        int ret = 0;
  44
  45        if (sk->sk_shutdown & RCV_SHUTDOWN)
  46                return 1;
  47
  48        if (!timeo)
  49                return ret;
  50
  51        add_wait_queue(sk_sleep(sk), &wait);
  52        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  53        ret = udp_msg_has_data(sk, psock);
  54        if (!ret) {
  55                wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
  56                ret = udp_msg_has_data(sk, psock);
  57        }
  58        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  59        remove_wait_queue(sk_sleep(sk), &wait);
  60        return ret;
  61}
  62
  63static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
  64                           int nonblock, int flags, int *addr_len)
  65{
  66        struct sk_psock *psock;
  67        int copied, ret;
  68
  69        if (unlikely(flags & MSG_ERRQUEUE))
  70                return inet_recv_error(sk, msg, len, addr_len);
  71
  72        psock = sk_psock_get(sk);
  73        if (unlikely(!psock))
  74                return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
  75
  76        if (!psock_has_data(psock)) {
  77                ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
  78                goto out;
  79        }
  80
  81msg_bytes_ready:
  82        copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
  83        if (!copied) {
  84                long timeo;
  85                int data;
  86
  87                timeo = sock_rcvtimeo(sk, nonblock);
  88                data = udp_msg_wait_data(sk, psock, timeo);
  89                if (data) {
  90                        if (psock_has_data(psock))
  91                                goto msg_bytes_ready;
  92                        ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
  93                        goto out;
  94                }
  95                copied = -EAGAIN;
  96        }
  97        ret = copied;
  98out:
  99        sk_psock_put(sk, psock);
 100        return ret;
 101}
 102
 103enum {
 104        UDP_BPF_IPV4,
 105        UDP_BPF_IPV6,
 106        UDP_BPF_NUM_PROTS,
 107};
 108
 109static DEFINE_SPINLOCK(udpv6_prot_lock);
 110static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
 111
 112static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
 113{
 114        *prot        = *base;
 115        prot->close  = sock_map_close;
 116        prot->recvmsg = udp_bpf_recvmsg;
 117        prot->sock_is_readable = sk_msg_is_readable;
 118}
 119
 120static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
 121{
 122        if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
 123                spin_lock_bh(&udpv6_prot_lock);
 124                if (likely(ops != udpv6_prot_saved)) {
 125                        udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
 126                        smp_store_release(&udpv6_prot_saved, ops);
 127                }
 128                spin_unlock_bh(&udpv6_prot_lock);
 129        }
 130}
 131
 132static int __init udp_bpf_v4_build_proto(void)
 133{
 134        udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
 135        return 0;
 136}
 137late_initcall(udp_bpf_v4_build_proto);
 138
 139int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 140{
 141        int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
 142
 143        if (restore) {
 144                sk->sk_write_space = psock->saved_write_space;
 145                WRITE_ONCE(sk->sk_prot, psock->sk_proto);
 146                return 0;
 147        }
 148
 149        if (sk->sk_family == AF_INET6)
 150                udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 151
 152        WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
 153        return 0;
 154}
 155EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
 156