linux/include/linux/skmsg.h
<<
>>
Prefs
   1/* SPDX-License-Identifier: GPL-2.0 */
   2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
   3
   4#ifndef _LINUX_SKMSG_H
   5#define _LINUX_SKMSG_H
   6
   7#include <linux/bpf.h>
   8#include <linux/filter.h>
   9#include <linux/scatterlist.h>
  10#include <linux/skbuff.h>
  11
  12#include <net/sock.h>
  13#include <net/tcp.h>
  14#include <net/strparser.h>
  15
  16#define MAX_MSG_FRAGS                   MAX_SKB_FRAGS
  17#define NR_MSG_FRAG_IDS                 (MAX_MSG_FRAGS + 1)
  18
  19enum __sk_action {
  20        __SK_DROP = 0,
  21        __SK_PASS,
  22        __SK_REDIRECT,
  23        __SK_NONE,
  24};
  25
  26struct sk_msg_sg {
  27        u32                             start;
  28        u32                             curr;
  29        u32                             end;
  30        u32                             size;
  31        u32                             copybreak;
  32        unsigned long                   copy;
  33        /* The extra two elements:
  34         * 1) used for chaining the front and sections when the list becomes
  35         *    partitioned (e.g. end < start). The crypto APIs require the
  36         *    chaining;
  37         * 2) to chain tailer SG entries after the message.
  38         */
  39        struct scatterlist              data[MAX_MSG_FRAGS + 2];
  40};
  41static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
  42
  43/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
  44struct sk_msg {
  45        struct sk_msg_sg                sg;
  46        void                            *data;
  47        void                            *data_end;
  48        u32                             apply_bytes;
  49        u32                             cork_bytes;
  50        u32                             flags;
  51        struct sk_buff                  *skb;
  52        struct sock                     *sk_redir;
  53        struct sock                     *sk;
  54        struct list_head                list;
  55};
  56
  57struct sk_psock_progs {
  58        struct bpf_prog                 *msg_parser;
  59        struct bpf_prog                 *stream_parser;
  60        struct bpf_prog                 *stream_verdict;
  61        struct bpf_prog                 *skb_verdict;
  62};
  63
  64enum sk_psock_state_bits {
  65        SK_PSOCK_TX_ENABLED,
  66};
  67
  68struct sk_psock_link {
  69        struct list_head                list;
  70        struct bpf_map                  *map;
  71        void                            *link_raw;
  72};
  73
  74struct sk_psock_work_state {
  75        struct sk_buff                  *skb;
  76        u32                             len;
  77        u32                             off;
  78};
  79
  80struct sk_psock {
  81        struct sock                     *sk;
  82        struct sock                     *sk_redir;
  83        u32                             apply_bytes;
  84        u32                             cork_bytes;
  85        u32                             eval;
  86        struct sk_msg                   *cork;
  87        struct sk_psock_progs           progs;
  88#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
  89        struct strparser                strp;
  90#endif
  91        struct sk_buff_head             ingress_skb;
  92        struct list_head                ingress_msg;
  93        spinlock_t                      ingress_lock;
  94        unsigned long                   state;
  95        struct list_head                link;
  96        spinlock_t                      link_lock;
  97        refcount_t                      refcnt;
  98        void (*saved_unhash)(struct sock *sk);
  99        void (*saved_close)(struct sock *sk, long timeout);
 100        void (*saved_write_space)(struct sock *sk);
 101        void (*saved_data_ready)(struct sock *sk);
 102        int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
 103                                     bool restore);
 104        struct proto                    *sk_proto;
 105        struct mutex                    work_mutex;
 106        struct sk_psock_work_state      work_state;
 107        struct work_struct              work;
 108        struct rcu_work                 rwork;
 109};
 110
 111int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
 112                 int elem_first_coalesce);
 113int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
 114                 u32 off, u32 len);
 115void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
 116int sk_msg_free(struct sock *sk, struct sk_msg *msg);
 117int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
 118void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
 119void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
 120                                  u32 bytes);
 121
 122void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
 123void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
 124
 125int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
 126                              struct sk_msg *msg, u32 bytes);
 127int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
 128                             struct sk_msg *msg, u32 bytes);
 129int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
 130                   int len, int flags);
 131bool sk_msg_is_readable(struct sock *sk);
 132
 133static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
 134{
 135        WARN_ON(i == msg->sg.end && bytes);
 136}
 137
 138static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
 139{
 140        if (psock->apply_bytes) {
 141                if (psock->apply_bytes < bytes)
 142                        psock->apply_bytes = 0;
 143                else
 144                        psock->apply_bytes -= bytes;
 145        }
 146}
 147
 148static inline u32 sk_msg_iter_dist(u32 start, u32 end)
 149{
 150        return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
 151}
 152
 153#define sk_msg_iter_var_prev(var)                       \
 154        do {                                            \
 155                if (var == 0)                           \
 156                        var = NR_MSG_FRAG_IDS - 1;      \
 157                else                                    \
 158                        var--;                          \
 159        } while (0)
 160
 161#define sk_msg_iter_var_next(var)                       \
 162        do {                                            \
 163                var++;                                  \
 164                if (var == NR_MSG_FRAG_IDS)             \
 165                        var = 0;                        \
 166        } while (0)
 167
 168#define sk_msg_iter_prev(msg, which)                    \
 169        sk_msg_iter_var_prev(msg->sg.which)
 170
 171#define sk_msg_iter_next(msg, which)                    \
 172        sk_msg_iter_var_next(msg->sg.which)
 173
 174static inline void sk_msg_clear_meta(struct sk_msg *msg)
 175{
 176        memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
 177}
 178
 179static inline void sk_msg_init(struct sk_msg *msg)
 180{
 181        BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
 182        memset(msg, 0, sizeof(*msg));
 183        sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
 184}
 185
 186static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
 187                               int which, u32 size)
 188{
 189        dst->sg.data[which] = src->sg.data[which];
 190        dst->sg.data[which].length  = size;
 191        dst->sg.size               += size;
 192        src->sg.size               -= size;
 193        src->sg.data[which].length -= size;
 194        src->sg.data[which].offset += size;
 195}
 196
 197static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
 198{
 199        memcpy(dst, src, sizeof(*src));
 200        sk_msg_init(src);
 201}
 202
 203static inline bool sk_msg_full(const struct sk_msg *msg)
 204{
 205        return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
 206}
 207
 208static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
 209{
 210        return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
 211}
 212
 213static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
 214{
 215        return &msg->sg.data[which];
 216}
 217
 218static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
 219{
 220        return msg->sg.data[which];
 221}
 222
 223static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
 224{
 225        return sg_page(sk_msg_elem(msg, which));
 226}
 227
 228static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
 229{
 230        return msg->flags & BPF_F_INGRESS;
 231}
 232
 233static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
 234{
 235        struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
 236
 237        if (test_bit(msg->sg.start, &msg->sg.copy)) {
 238                msg->data = NULL;
 239                msg->data_end = NULL;
 240        } else {
 241                msg->data = sg_virt(sge);
 242                msg->data_end = msg->data + sge->length;
 243        }
 244}
 245
 246static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
 247                                   u32 len, u32 offset)
 248{
 249        struct scatterlist *sge;
 250
 251        get_page(page);
 252        sge = sk_msg_elem(msg, msg->sg.end);
 253        sg_set_page(sge, page, len, offset);
 254        sg_unmark_end(sge);
 255
 256        __set_bit(msg->sg.end, &msg->sg.copy);
 257        msg->sg.size += len;
 258        sk_msg_iter_next(msg, end);
 259}
 260
 261static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
 262{
 263        do {
 264                if (copy_state)
 265                        __set_bit(i, &msg->sg.copy);
 266                else
 267                        __clear_bit(i, &msg->sg.copy);
 268                sk_msg_iter_var_next(i);
 269                if (i == msg->sg.end)
 270                        break;
 271        } while (1);
 272}
 273
 274static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
 275{
 276        sk_msg_sg_copy(msg, start, true);
 277}
 278
 279static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
 280{
 281        sk_msg_sg_copy(msg, start, false);
 282}
 283
 284static inline struct sk_psock *sk_psock(const struct sock *sk)
 285{
 286        return rcu_dereference_sk_user_data(sk);
 287}
 288
 289static inline void sk_psock_set_state(struct sk_psock *psock,
 290                                      enum sk_psock_state_bits bit)
 291{
 292        set_bit(bit, &psock->state);
 293}
 294
 295static inline void sk_psock_clear_state(struct sk_psock *psock,
 296                                        enum sk_psock_state_bits bit)
 297{
 298        clear_bit(bit, &psock->state);
 299}
 300
 301static inline bool sk_psock_test_state(const struct sk_psock *psock,
 302                                       enum sk_psock_state_bits bit)
 303{
 304        return test_bit(bit, &psock->state);
 305}
 306
 307static inline void sock_drop(struct sock *sk, struct sk_buff *skb)
 308{
 309        sk_drops_add(sk, skb);
 310        kfree_skb(skb);
 311}
 312
 313static inline void drop_sk_msg(struct sk_psock *psock, struct sk_msg *msg)
 314{
 315        if (msg->skb)
 316                sock_drop(psock->sk, msg->skb);
 317        kfree(msg);
 318}
 319
 320static inline void sk_psock_queue_msg(struct sk_psock *psock,
 321                                      struct sk_msg *msg)
 322{
 323        spin_lock_bh(&psock->ingress_lock);
 324        if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
 325                list_add_tail(&msg->list, &psock->ingress_msg);
 326        else
 327                drop_sk_msg(psock, msg);
 328        spin_unlock_bh(&psock->ingress_lock);
 329}
 330
 331static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
 332{
 333        struct sk_msg *msg;
 334
 335        spin_lock_bh(&psock->ingress_lock);
 336        msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
 337        if (msg)
 338                list_del(&msg->list);
 339        spin_unlock_bh(&psock->ingress_lock);
 340        return msg;
 341}
 342
 343static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
 344{
 345        struct sk_msg *msg;
 346
 347        spin_lock_bh(&psock->ingress_lock);
 348        msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
 349        spin_unlock_bh(&psock->ingress_lock);
 350        return msg;
 351}
 352
 353static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
 354                                               struct sk_msg *msg)
 355{
 356        struct sk_msg *ret;
 357
 358        spin_lock_bh(&psock->ingress_lock);
 359        if (list_is_last(&msg->list, &psock->ingress_msg))
 360                ret = NULL;
 361        else
 362                ret = list_next_entry(msg, list);
 363        spin_unlock_bh(&psock->ingress_lock);
 364        return ret;
 365}
 366
 367static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
 368{
 369        return psock ? list_empty(&psock->ingress_msg) : true;
 370}
 371
 372static inline void kfree_sk_msg(struct sk_msg *msg)
 373{
 374        if (msg->skb)
 375                consume_skb(msg->skb);
 376        kfree(msg);
 377}
 378
 379static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 380{
 381        struct sock *sk = psock->sk;
 382
 383        sk->sk_err = err;
 384        sk_error_report(sk);
 385}
 386
 387struct sk_psock *sk_psock_init(struct sock *sk, int node);
 388void sk_psock_stop(struct sk_psock *psock, bool wait);
 389
 390#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
 391int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
 392void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
 393void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
 394#else
 395static inline int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
 396{
 397        return -EOPNOTSUPP;
 398}
 399
 400static inline void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
 401{
 402}
 403
 404static inline void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
 405{
 406}
 407#endif
 408
 409void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
 410void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);
 411
 412int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
 413                         struct sk_msg *msg);
 414
 415static inline struct sk_psock_link *sk_psock_init_link(void)
 416{
 417        return kzalloc(sizeof(struct sk_psock_link),
 418                       GFP_ATOMIC | __GFP_NOWARN);
 419}
 420
 421static inline void sk_psock_free_link(struct sk_psock_link *link)
 422{
 423        kfree(link);
 424}
 425
 426struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
 427
 428static inline void sk_psock_cork_free(struct sk_psock *psock)
 429{
 430        if (psock->cork) {
 431                sk_msg_free(psock->sk, psock->cork);
 432                kfree(psock->cork);
 433                psock->cork = NULL;
 434        }
 435}
 436
 437static inline void sk_psock_restore_proto(struct sock *sk,
 438                                          struct sk_psock *psock)
 439{
 440        if (psock->psock_update_sk_prot)
 441                psock->psock_update_sk_prot(sk, psock, true);
 442}
 443
 444static inline struct sk_psock *sk_psock_get(struct sock *sk)
 445{
 446        struct sk_psock *psock;
 447
 448        rcu_read_lock();
 449        psock = sk_psock(sk);
 450        if (psock && !refcount_inc_not_zero(&psock->refcnt))
 451                psock = NULL;
 452        rcu_read_unlock();
 453        return psock;
 454}
 455
 456void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
 457
 458static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
 459{
 460        if (refcount_dec_and_test(&psock->refcnt))
 461                sk_psock_drop(sk, psock);
 462}
 463
 464static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
 465{
 466        if (psock->saved_data_ready)
 467                psock->saved_data_ready(sk);
 468        else
 469                sk->sk_data_ready(sk);
 470}
 471
 472static inline void psock_set_prog(struct bpf_prog **pprog,
 473                                  struct bpf_prog *prog)
 474{
 475        prog = xchg(pprog, prog);
 476        if (prog)
 477                bpf_prog_put(prog);
 478}
 479
 480static inline int psock_replace_prog(struct bpf_prog **pprog,
 481                                     struct bpf_prog *prog,
 482                                     struct bpf_prog *old)
 483{
 484        if (cmpxchg(pprog, old, prog) != old)
 485                return -ENOENT;
 486
 487        if (old)
 488                bpf_prog_put(old);
 489
 490        return 0;
 491}
 492
 493static inline void psock_progs_drop(struct sk_psock_progs *progs)
 494{
 495        psock_set_prog(&progs->msg_parser, NULL);
 496        psock_set_prog(&progs->stream_parser, NULL);
 497        psock_set_prog(&progs->stream_verdict, NULL);
 498        psock_set_prog(&progs->skb_verdict, NULL);
 499}
 500
 501int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
 502
 503static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
 504{
 505        if (!psock)
 506                return false;
 507        return !!psock->saved_data_ready;
 508}
 509
 510#if IS_ENABLED(CONFIG_NET_SOCK_MSG)
 511
 512/* We only have one bit so far. */
 513#define BPF_F_PTR_MASK ~(BPF_F_INGRESS)
 514
 515static inline bool skb_bpf_ingress(const struct sk_buff *skb)
 516{
 517        unsigned long sk_redir = skb->_sk_redir;
 518
 519        return sk_redir & BPF_F_INGRESS;
 520}
 521
 522static inline void skb_bpf_set_ingress(struct sk_buff *skb)
 523{
 524        skb->_sk_redir |= BPF_F_INGRESS;
 525}
 526
 527static inline void skb_bpf_set_redir(struct sk_buff *skb, struct sock *sk_redir,
 528                                     bool ingress)
 529{
 530        skb->_sk_redir = (unsigned long)sk_redir;
 531        if (ingress)
 532                skb->_sk_redir |= BPF_F_INGRESS;
 533}
 534
 535static inline struct sock *skb_bpf_redirect_fetch(const struct sk_buff *skb)
 536{
 537        unsigned long sk_redir = skb->_sk_redir;
 538
 539        return (struct sock *)(sk_redir & BPF_F_PTR_MASK);
 540}
 541
 542static inline void skb_bpf_redirect_clear(struct sk_buff *skb)
 543{
 544        skb->_sk_redir = 0;
 545}
 546#endif /* CONFIG_NET_SOCK_MSG */
 547#endif /* _LINUX_SKMSG_H */
 548