linux/include/linux/skmsg.h
<<
>>
Prefs
   1/* SPDX-License-Identifier: GPL-2.0 */
   2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
   3
   4#ifndef _LINUX_SKMSG_H
   5#define _LINUX_SKMSG_H
   6
   7#include <linux/bpf.h>
   8#include <linux/filter.h>
   9#include <linux/scatterlist.h>
  10#include <linux/skbuff.h>
  11
  12#include <net/sock.h>
  13#include <net/tcp.h>
  14#include <net/strparser.h>
  15
  /* Fragment limit for an sk_msg, defined as an alias of MAX_SKB_FRAGS. It
   * sizes copy[] and data[] in struct sk_msg_sg and is the wrap point used by
   * the ring-index helpers in this header.
   */
  #define MAX_MSG_FRAGS MAX_SKB_FRAGS
  17
  /* Internal action codes with explicitly pinned values.
   * NOTE(review): presumably the outcome of running a verdict program (see
   * sk_psock_msg_verdict()); the consumers are outside this header - confirm.
   */
  enum __sk_action {
          __SK_DROP     = 0,
          __SK_PASS     = 1,
          __SK_REDIRECT = 2,
          __SK_NONE     = 3,
  };
  24
  25struct sk_msg_sg {
  26        u32                             start;
  27        u32                             curr;
  28        u32                             end;
  29        u32                             size;
  30        u32                             copybreak;
  31        bool                            copy[MAX_MSG_FRAGS];
  32        /* The extra element is used for chaining the front and sections when
  33         * the list becomes partitioned (e.g. end < start). The crypto APIs
  34         * require the chaining.
  35         */
  36        struct scatterlist              data[MAX_MSG_FRAGS + 1];
  37};
  38
  39/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
  40struct sk_msg {
  41        struct sk_msg_sg                sg;
  42        void                            *data;
  43        void                            *data_end;
  44        u32                             apply_bytes;
  45        u32                             cork_bytes;
  46        u32                             flags;
  47        struct sk_buff                  *skb;
  48        struct sock                     *sk_redir;
  49        struct sock                     *sk;
  50        struct list_head                list;
  51};
  52
  /**
   * struct sk_psock_progs - BPF program slots
   * @msg_parser:  program pointer, may be NULL
   * @skb_parser:  program pointer, may be NULL
   * @skb_verdict: program pointer, may be NULL
   *
   * psock_progs_drop() empties all three slots and releases whatever program
   * was stored in each.
   */
  struct sk_psock_progs {
          struct bpf_prog *msg_parser;
          struct bpf_prog *skb_parser;
          struct bpf_prog *skb_verdict;
  };
  58
  /* Bit numbers within sk_psock.state, for use with sk_psock_set_state(),
   * sk_psock_clear_state() and sk_psock_test_state().
   */
  enum sk_psock_state_bits {
          SK_PSOCK_TX_ENABLED = 0,
  };
  62
  63struct sk_psock_link {
  64        struct list_head                list;
  65        struct bpf_map                  *map;
  66        void                            *link_raw;
  67};
  68
  69struct sk_psock_parser {
  70        struct strparser                strp;
  71        bool                            enabled;
  72        void (*saved_data_ready)(struct sock *sk);
  73};
  74
  75struct sk_psock_work_state {
  76        struct sk_buff                  *skb;
  77        u32                             len;
  78        u32                             off;
  79};
  80
  81struct sk_psock {
  82        struct sock                     *sk;
  83        struct sock                     *sk_redir;
  84        u32                             apply_bytes;
  85        u32                             cork_bytes;
  86        u32                             eval;
  87        struct sk_msg                   *cork;
  88        struct sk_psock_progs           progs;
  89        struct sk_psock_parser          parser;
  90        struct sk_buff_head             ingress_skb;
  91        struct list_head                ingress_msg;
  92        unsigned long                   state;
  93        struct list_head                link;
  94        spinlock_t                      link_lock;
  95        refcount_t                      refcnt;
  96        void (*saved_unhash)(struct sock *sk);
  97        void (*saved_close)(struct sock *sk, long timeout);
  98        void (*saved_write_space)(struct sock *sk);
  99        struct proto                    *sk_proto;
 100        struct sk_psock_work_state      work_state;
 101        struct work_struct              work;
 102        union {
 103                struct rcu_head         rcu;
 104                struct work_struct      gc;
 105        };
 106};
 107
 108int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
 109                 int elem_first_coalesce);
 110int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
 111                 u32 off, u32 len);
 112void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
 113int sk_msg_free(struct sock *sk, struct sk_msg *msg);
 114int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
 115void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
 116void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
 117                                  u32 bytes);
 118
 119void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
 120void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
 121
 122int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
 123                              struct sk_msg *msg, u32 bytes);
 124int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
 125                             struct sk_msg *msg, u32 bytes);
 126
 127static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
 128{
 129        WARN_ON(i == msg->sg.end && bytes);
 130}
 131
 132static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
 133{
 134        if (psock->apply_bytes) {
 135                if (psock->apply_bytes < bytes)
 136                        psock->apply_bytes = 0;
 137                else
 138                        psock->apply_bytes -= bytes;
 139        }
 140}
 141
 142static inline u32 sk_msg_iter_dist(u32 start, u32 end)
 143{
 144        return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
 145}
 146
 /*
  * Step a ring index one slot backwards, wrapping from 0 to
  * MAX_MSG_FRAGS - 1.
  *
  * @var must be a modifiable lvalue. It is expanded more than once, so it
  * must not have side effects. Every expansion is parenthesised so that an
  * argument such as *idx is stepped as a whole; unparenthesised, "*idx--"
  * would move the pointer instead of the index it points to.
  */
 #define sk_msg_iter_var_prev(var)                       \
         do {                                            \
                 if ((var) == 0)                         \
                         (var) = MAX_MSG_FRAGS - 1;      \
                 else                                    \
                         (var)--;                        \
         } while (0)
 154
 /*
  * Step a ring index one slot forwards, wrapping to 0 once it reaches
  * MAX_MSG_FRAGS.
  *
  * @var must be a modifiable lvalue. It is expanded more than once, so it
  * must not have side effects. Every expansion is parenthesised so that an
  * argument such as *idx is stepped as a whole; unparenthesised, "*idx++"
  * would move the pointer instead of the index it points to.
  */
 #define sk_msg_iter_var_next(var)                       \
         do {                                            \
                 (var)++;                                \
                 if ((var) == MAX_MSG_FRAGS)             \
                         (var) = 0;                      \
         } while (0)
 161
 /*
  * Step the ring index (msg)->sg.<which> one slot backwards, with wrap-around.
  *
  * @msg is a pointer expression and @which names the index member, e.g.
  * start, curr or end. @msg is parenthesised so that compound pointer
  * expressions such as "base + 1" bind correctly. It is expanded more than
  * once, so it must not have side effects.
  */
 #define sk_msg_iter_prev(msg, which)                    \
         sk_msg_iter_var_prev((msg)->sg.which)
 164
 /*
  * Step the ring index (msg)->sg.<which> one slot forwards, with wrap-around.
  *
  * @msg is a pointer expression and @which names the index member, e.g.
  * start, curr or end. @msg is parenthesised so that compound pointer
  * expressions such as "base + 1" bind correctly. It is expanded more than
  * once, so it must not have side effects.
  */
 #define sk_msg_iter_next(msg, which)                    \
         sk_msg_iter_var_next((msg)->sg.which)
 167
 168static inline void sk_msg_clear_meta(struct sk_msg *msg)
 169{
 170        memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
 171}
 172
 173static inline void sk_msg_init(struct sk_msg *msg)
 174{
 175        BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
 176        memset(msg, 0, sizeof(*msg));
 177        sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
 178}
 179
 180static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
 181                               int which, u32 size)
 182{
 183        dst->sg.data[which] = src->sg.data[which];
 184        dst->sg.data[which].length  = size;
 185        dst->sg.size               += size;
 186        src->sg.data[which].length -= size;
 187        src->sg.data[which].offset += size;
 188}
 189
 190static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
 191{
 192        memcpy(dst, src, sizeof(*src));
 193        sk_msg_init(src);
 194}
 195
 196static inline bool sk_msg_full(const struct sk_msg *msg)
 197{
 198        return (msg->sg.end == msg->sg.start) && msg->sg.size;
 199}
 200
 201static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
 202{
 203        if (sk_msg_full(msg))
 204                return MAX_MSG_FRAGS;
 205
 206        return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
 207}
 208
 209static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
 210{
 211        return &msg->sg.data[which];
 212}
 213
 214static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
 215{
 216        return msg->sg.data[which];
 217}
 218
 /**
  * sk_msg_page - page behind one scatterlist element
  * @msg:   message that owns the ring
  * @which: element index; not range-checked here
  *
  * Return: the result of sg_page() for element @which.
  */
 static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
 {
         struct scatterlist *sge = sk_msg_elem(msg, which);

         return sg_page(sge);
 }
 223
 224static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
 225{
 226        return msg->flags & BPF_F_INGRESS;
 227}
 228
 229static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
 230{
 231        struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
 232
 233        if (msg->sg.copy[msg->sg.start]) {
 234                msg->data = NULL;
 235                msg->data_end = NULL;
 236        } else {
 237                msg->data = sg_virt(sge);
 238                msg->data_end = msg->data + sge->length;
 239        }
 240}
 241
 242static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
 243                                   u32 len, u32 offset)
 244{
 245        struct scatterlist *sge;
 246
 247        get_page(page);
 248        sge = sk_msg_elem(msg, msg->sg.end);
 249        sg_set_page(sge, page, len, offset);
 250        sg_unmark_end(sge);
 251
 252        msg->sg.copy[msg->sg.end] = true;
 253        msg->sg.size += len;
 254        sk_msg_iter_next(msg, end);
 255}
 256
 257static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
 258{
 259        do {
 260                msg->sg.copy[i] = copy_state;
 261                sk_msg_iter_var_next(i);
 262                if (i == msg->sg.end)
 263                        break;
 264        } while (1);
 265}
 266
 267static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
 268{
 269        sk_msg_sg_copy(msg, start, true);
 270}
 271
 272static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
 273{
 274        sk_msg_sg_copy(msg, start, false);
 275}
 276
 /**
  * sk_psock - look up the psock stored in a socket's user data
  * @sk: socket to query
  *
  * The pointer is fetched with rcu_dereference_sk_user_data(); the callers
  * in this header do so inside rcu_read_lock()/rcu_read_unlock().
  *
  * Return: the stored pointer, which the callers treat as possibly NULL.
  */
 static inline struct sk_psock *sk_psock(const struct sock *sk)
 {
         struct sk_psock *psock = rcu_dereference_sk_user_data(sk);

         return psock;
 }
 281
 282static inline void sk_psock_queue_msg(struct sk_psock *psock,
 283                                      struct sk_msg *msg)
 284{
 285        list_add_tail(&msg->list, &psock->ingress_msg);
 286}
 287
 288static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
 289{
 290        return psock ? list_empty(&psock->ingress_msg) : true;
 291}
 292
 293static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 294{
 295        struct sock *sk = psock->sk;
 296
 297        sk->sk_err = err;
 298        sk->sk_error_report(sk);
 299}
 300
 /* psock setup, stream parser control and message verdict; declared here,
  * defined outside this header.
  */
 struct sk_psock *sk_psock_init(struct sock *sk, int node);

 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);

 int sk_psock_msg_verdict(struct sock *sk,
                          struct sk_psock *psock,
                          struct sk_msg *msg);
 309
 310static inline struct sk_psock_link *sk_psock_init_link(void)
 311{
 312        return kzalloc(sizeof(struct sk_psock_link),
 313                       GFP_ATOMIC | __GFP_NOWARN);
 314}
 315
 /**
  * sk_psock_free_link - release a link record
  * @link: record to free with kfree()
  */
 static inline void sk_psock_free_link(struct sk_psock_link *link)
 {
         void *mem = link;

         kfree(mem);
 }
 320
 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);

 /* sk_psock_unlink() is only declared when CONFIG_BPF_STREAM_PARSER is
  * defined; otherwise it is an empty inline stub so that callers need no
  * conditional code of their own.
  */
 #ifdef CONFIG_BPF_STREAM_PARSER
 void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
 #else
 static inline void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
 {
 }
 #endif /* CONFIG_BPF_STREAM_PARSER */

 void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
 332
 333static inline void sk_psock_cork_free(struct sk_psock *psock)
 334{
 335        if (psock->cork) {
 336                sk_msg_free(psock->sk, psock->cork);
 337                kfree(psock->cork);
 338                psock->cork = NULL;
 339        }
 340}
 341
 342static inline void sk_psock_update_proto(struct sock *sk,
 343                                         struct sk_psock *psock,
 344                                         struct proto *ops)
 345{
 346        psock->saved_unhash = sk->sk_prot->unhash;
 347        psock->saved_close = sk->sk_prot->close;
 348        psock->saved_write_space = sk->sk_write_space;
 349
 350        psock->sk_proto = sk->sk_prot;
 351        sk->sk_prot = ops;
 352}
 353
 354static inline void sk_psock_restore_proto(struct sock *sk,
 355                                          struct sk_psock *psock)
 356{
 357        sk->sk_write_space = psock->saved_write_space;
 358
 359        if (psock->sk_proto) {
 360                struct inet_connection_sock *icsk = inet_csk(sk);
 361                bool has_ulp = !!icsk->icsk_ulp_data;
 362
 363                if (has_ulp)
 364                        tcp_update_ulp(sk, psock->sk_proto);
 365                else
 366                        sk->sk_prot = psock->sk_proto;
 367                psock->sk_proto = NULL;
 368        }
 369}
 370
 371static inline void sk_psock_set_state(struct sk_psock *psock,
 372                                      enum sk_psock_state_bits bit)
 373{
 374        set_bit(bit, &psock->state);
 375}
 376
 377static inline void sk_psock_clear_state(struct sk_psock *psock,
 378                                        enum sk_psock_state_bits bit)
 379{
 380        clear_bit(bit, &psock->state);
 381}
 382
 383static inline bool sk_psock_test_state(const struct sk_psock *psock,
 384                                       enum sk_psock_state_bits bit)
 385{
 386        return test_bit(bit, &psock->state);
 387}
 388
 389static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
 390{
 391        struct sk_psock *psock;
 392
 393        rcu_read_lock();
 394        psock = sk_psock(sk);
 395        if (psock) {
 396                if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
 397                        psock = ERR_PTR(-EBUSY);
 398                        goto out;
 399                }
 400
 401                if (!refcount_inc_not_zero(&psock->refcnt))
 402                        psock = ERR_PTR(-EBUSY);
 403        }
 404out:
 405        rcu_read_unlock();
 406        return psock;
 407}
 408
 409static inline struct sk_psock *sk_psock_get(struct sock *sk)
 410{
 411        struct sk_psock *psock;
 412
 413        rcu_read_lock();
 414        psock = sk_psock(sk);
 415        if (psock && !refcount_inc_not_zero(&psock->refcnt))
 416                psock = NULL;
 417        rcu_read_unlock();
 418        return psock;
 419}
 420
 /* psock teardown entry points; declared here, defined outside this header.
  * sk_psock_drop() is what sk_psock_put() calls on the final reference.
  */
 void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
 void sk_psock_destroy(struct rcu_head *rcu);
 void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
 424
 425static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
 426{
 427        if (refcount_dec_and_test(&psock->refcnt))
 428                sk_psock_drop(sk, psock);
 429}
 430
 431static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
 432{
 433        if (psock->parser.enabled)
 434                psock->parser.saved_data_ready(sk);
 435        else
 436                sk->sk_data_ready(sk);
 437}
 438
 /**
  * psock_set_prog - install a program in a slot and release the old one
  * @pprog: slot to update
  * @prog:  program to store, or NULL to empty the slot
  *
  * The slot is swapped with xchg(). If the slot held a program before, that
  * program is released with bpf_prog_put().
  */
 static inline void psock_set_prog(struct bpf_prog **pprog,
                                   struct bpf_prog *prog)
 {
         struct bpf_prog *old = xchg(pprog, prog);

         if (old)
                 bpf_prog_put(old);
 }
 446
 447static inline void psock_progs_drop(struct sk_psock_progs *progs)
 448{
 449        psock_set_prog(&progs->msg_parser, NULL);
 450        psock_set_prog(&progs->skb_parser, NULL);
 451        psock_set_prog(&progs->skb_verdict, NULL);
 452}
 453
 454#endif /* _LINUX_SKMSG_H */
 455