linux/net/core/bpf_sk_storage.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2019 Facebook  */
   3#include <linux/rculist.h>
   4#include <linux/list.h>
   5#include <linux/hash.h>
   6#include <linux/types.h>
   7#include <linux/spinlock.h>
   8#include <linux/bpf.h>
   9#include <net/bpf_sk_storage.h>
  10#include <net/sock.h>
  11#include <uapi/linux/sock_diag.h>
  12#include <uapi/linux/btf.h>
  13
/* Global counter used to pick each new map's slot in
 * bpf_sk_storage->cache[]: bpf_sk_storage_map_alloc() increments it
 * and takes the result modulo BPF_SK_STORAGE_CACHE_SIZE.
 */
static atomic_t cache_idx;

/* The only map_flags accepted at map creation time
 * (see bpf_sk_storage_map_alloc_check()).
 */
#define SK_STORAGE_CREATE_FLAG_MASK                                     \
        (BPF_F_NO_PREALLOC | BPF_F_CLONE)
  18
/* One hash bucket of a bpf_sk_storage_map.
 * @list: elems of the map hashed to this bucket (via elem->map_node).
 * @lock: serializes linking/unlinking elems on @list
 *        (see selem_link_map() and selem_unlink_map()).
 */
struct bucket {
        struct hlist_head list;
        raw_spinlock_t lock;
};
  23
/* The map is not the primary owner of a bpf_sk_storage_elem.
 * Instead, the sk->sk_bpf_storage is.
 *
 * The map (bpf_sk_storage_map) serves two purposes:
 * 1. Define the size of the "sk local storage".  It is
 *    the map's value_size.
 *
 * 2. Maintain a list to keep track of all elems such
 *    that they can be cleaned up during the map destruction.
 *
 * When a bpf local storage is being looked up for a
 * particular sk, the "bpf_map" pointer is actually used
 * as the "key" to search in the list of elem in
 * sk->sk_bpf_storage.
 *
 * Hence, consider sk->sk_bpf_storage is the mini-map
 * with the "bpf_map" pointer as the searching key.
 */
struct bpf_sk_storage_map {
        struct bpf_map map;
        /* Lookup elem does not require accessing the map.
         *
         * Updating/Deleting requires a bucket lock to
         * link/unlink the elem from the map.  Having
         * multiple buckets reduces lock contention.
         */
        struct bucket *buckets;
        u32 bucket_log;         /* log2 of the number of buckets */
        u16 elem_size;          /* sizeof(struct bpf_sk_storage_elem) + value_size */
        u16 cache_idx;          /* This map's slot in bpf_sk_storage->cache[] */
};
  55
/* Data part of an elem: the owning map followed by the value bytes.
 * Embedded in struct bpf_sk_storage_elem as ->sdata.
 */
struct bpf_sk_storage_data {
        /* smap is used as the searching key when looking up
         * from sk->sk_bpf_storage.
         *
         * Put it in the same cacheline as the data to minimize
         * the number of cachelines access during the cache hit case.
         */
        struct bpf_sk_storage_map __rcu *smap;
        u8 data[] __aligned(8);        /* smap->map.value_size bytes */
};
  66
/* Linked to bpf_sk_storage and bpf_sk_storage_map */
struct bpf_sk_storage_elem {
        struct hlist_node map_node;     /* Linked to bpf_sk_storage_map */
        struct hlist_node snode;        /* Linked to bpf_sk_storage */
        struct bpf_sk_storage __rcu *sk_storage; /* Owning sk storage */
        struct rcu_head rcu;            /* Used by kfree_rcu() of this elem */
        /* 8 bytes hole */
        /* The data is stored in another cacheline to minimize
         * the number of cachelines access during a cache hit.
         */
        struct bpf_sk_storage_data sdata ____cacheline_aligned;
};
  79
/* Convert between an elem and its embedded data part. */
#define SELEM(_SDATA) container_of((_SDATA), struct bpf_sk_storage_elem, sdata)
#define SDATA(_SELEM) (&(_SELEM)->sdata)
/* Number of per-sk lookup cache slots.  Maps are assigned slots modulo
 * this value, so several maps may share one slot.
 */
#define BPF_SK_STORAGE_CACHE_SIZE       16
  83
/* Per-sk storage, reachable from sk->sk_bpf_storage. */
struct bpf_sk_storage {
        /* Lookup cache, indexed by smap->cache_idx */
        struct bpf_sk_storage_data __rcu *cache[BPF_SK_STORAGE_CACHE_SIZE];
        struct hlist_head list; /* List of bpf_sk_storage_elem */
        struct sock *sk;        /* The sk that owns the above "list" of
                                 * bpf_sk_storage_elem.
                                 */
        struct rcu_head rcu;    /* Used by kfree_rcu() of this struct */
        raw_spinlock_t lock;    /* Protect adding/removing from the "list" */
};
  93
  94static struct bucket *select_bucket(struct bpf_sk_storage_map *smap,
  95                                    struct bpf_sk_storage_elem *selem)
  96{
  97        return &smap->buckets[hash_ptr(selem, smap->bucket_log)];
  98}
  99
 100static int omem_charge(struct sock *sk, unsigned int size)
 101{
 102        /* same check as in sock_kmalloc() */
 103        if (size <= sysctl_optmem_max &&
 104            atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
 105                atomic_add(size, &sk->sk_omem_alloc);
 106                return 0;
 107        }
 108
 109        return -ENOMEM;
 110}
 111
 112static bool selem_linked_to_sk(const struct bpf_sk_storage_elem *selem)
 113{
 114        return !hlist_unhashed(&selem->snode);
 115}
 116
 117static bool selem_linked_to_map(const struct bpf_sk_storage_elem *selem)
 118{
 119        return !hlist_unhashed(&selem->map_node);
 120}
 121
 122static struct bpf_sk_storage_elem *selem_alloc(struct bpf_sk_storage_map *smap,
 123                                               struct sock *sk, void *value,
 124                                               bool charge_omem)
 125{
 126        struct bpf_sk_storage_elem *selem;
 127
 128        if (charge_omem && omem_charge(sk, smap->elem_size))
 129                return NULL;
 130
 131        selem = kzalloc(smap->elem_size, GFP_ATOMIC | __GFP_NOWARN);
 132        if (selem) {
 133                if (value)
 134                        memcpy(SDATA(selem)->data, value, smap->map.value_size);
 135                return selem;
 136        }
 137
 138        if (charge_omem)
 139                atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
 140
 141        return NULL;
 142}
 143
/* Unlink @selem from @sk_storage->list and free the elem via kfree_rcu().
 *
 * sk_storage->lock must be held and selem->sk_storage == sk_storage.
 * The caller must ensure selem->smap is still valid to be
 * dereferenced for its smap->elem_size and smap->cache_idx.
 *
 * @uncharge_omem: if true, subtract smap->elem_size from the owning
 *                 sk's sk_omem_alloc.
 *
 * If @selem is the last elem on the list, the sk_storage's own charge is
 * dropped as well and sk->sk_bpf_storage is cleared.  However, the
 * sk_storage itself is NOT freed here.
 *
 * Return: true if @selem was the last elem.  In that case the caller
 * must kfree_rcu() the sk_storage after releasing sk_storage->lock.
 */
static bool __selem_unlink_sk(struct bpf_sk_storage *sk_storage,
                              struct bpf_sk_storage_elem *selem,
                              bool uncharge_omem)
{
        struct bpf_sk_storage_map *smap;
        bool free_sk_storage;
        struct sock *sk;

        smap = rcu_dereference(SDATA(selem)->smap);
        sk = sk_storage->sk;

        /* All uncharging on sk->sk_omem_alloc must be done first.
         * sk may be freed once the last selem is unlinked from sk_storage.
         */
        if (uncharge_omem)
                atomic_sub(smap->elem_size, &sk->sk_omem_alloc);

        free_sk_storage = hlist_is_singular_node(&selem->snode,
                                                 &sk_storage->list);
        if (free_sk_storage) {
                atomic_sub(sizeof(struct bpf_sk_storage), &sk->sk_omem_alloc);
                sk_storage->sk = NULL;
                /* After this RCU_INIT, sk may be freed and cannot be used */
                RCU_INIT_POINTER(sk->sk_bpf_storage, NULL);

                /* sk_storage is not freed now.  sk_storage->lock is
                 * still held and raw_spin_unlock_bh(&sk_storage->lock)
                 * will be done by the caller.
                 *
                 * Although the unlock will be done under
                 * rcu_read_lock(), it is more intuitive to
                 * read if kfree_rcu(sk_storage, rcu) is done
                 * after the raw_spin_unlock_bh(&sk_storage->lock).
                 *
                 * Hence, a "bool free_sk_storage" is returned
                 * to the caller which then calls the kfree_rcu()
                 * after unlock.
                 */
        }
        hlist_del_init_rcu(&selem->snode);
        /* Clear this map's cache slot if it still points at this elem */
        if (rcu_access_pointer(sk_storage->cache[smap->cache_idx]) ==
            SDATA(selem))
                RCU_INIT_POINTER(sk_storage->cache[smap->cache_idx], NULL);

        kfree_rcu(selem, rcu);

        return free_sk_storage;
}
 196
 197static void selem_unlink_sk(struct bpf_sk_storage_elem *selem)
 198{
 199        struct bpf_sk_storage *sk_storage;
 200        bool free_sk_storage = false;
 201
 202        if (unlikely(!selem_linked_to_sk(selem)))
 203                /* selem has already been unlinked from sk */
 204                return;
 205
 206        sk_storage = rcu_dereference(selem->sk_storage);
 207        raw_spin_lock_bh(&sk_storage->lock);
 208        if (likely(selem_linked_to_sk(selem)))
 209                free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
 210        raw_spin_unlock_bh(&sk_storage->lock);
 211
 212        if (free_sk_storage)
 213                kfree_rcu(sk_storage, rcu);
 214}
 215
 216static void __selem_link_sk(struct bpf_sk_storage *sk_storage,
 217                            struct bpf_sk_storage_elem *selem)
 218{
 219        RCU_INIT_POINTER(selem->sk_storage, sk_storage);
 220        hlist_add_head(&selem->snode, &sk_storage->list);
 221}
 222
 223static void selem_unlink_map(struct bpf_sk_storage_elem *selem)
 224{
 225        struct bpf_sk_storage_map *smap;
 226        struct bucket *b;
 227
 228        if (unlikely(!selem_linked_to_map(selem)))
 229                /* selem has already be unlinked from smap */
 230                return;
 231
 232        smap = rcu_dereference(SDATA(selem)->smap);
 233        b = select_bucket(smap, selem);
 234        raw_spin_lock_bh(&b->lock);
 235        if (likely(selem_linked_to_map(selem)))
 236                hlist_del_init_rcu(&selem->map_node);
 237        raw_spin_unlock_bh(&b->lock);
 238}
 239
 240static void selem_link_map(struct bpf_sk_storage_map *smap,
 241                           struct bpf_sk_storage_elem *selem)
 242{
 243        struct bucket *b = select_bucket(smap, selem);
 244
 245        raw_spin_lock_bh(&b->lock);
 246        RCU_INIT_POINTER(SDATA(selem)->smap, smap);
 247        hlist_add_head_rcu(&selem->map_node, &b->list);
 248        raw_spin_unlock_bh(&b->lock);
 249}
 250
 251static void selem_unlink(struct bpf_sk_storage_elem *selem)
 252{
 253        /* Always unlink from map before unlinking from sk_storage
 254         * because selem will be freed after successfully unlinked from
 255         * the sk_storage.
 256         */
 257        selem_unlink_map(selem);
 258        selem_unlink_sk(selem);
 259}
 260
/* Find the data that @smap owns in @sk_storage.
 *
 * The cache slot sk_storage->cache[smap->cache_idx] is checked first.
 * On a miss, sk_storage->list is walked, comparing each elem's smap.
 *
 * @cacheit_lockit: on a slow-path hit, take sk_storage->lock and
 *                  store the result in the cache slot, but only if the
 *                  elem is still linked to the sk.
 *
 * Return: the matching data, or NULL if @smap has no elem here.
 * NOTE(review): uses rcu_dereference()/hlist_for_each_entry_rcu(), so
 * presumably called under rcu_read_lock() -- confirm at call sites.
 */
static struct bpf_sk_storage_data *
__sk_storage_lookup(struct bpf_sk_storage *sk_storage,
                    struct bpf_sk_storage_map *smap,
                    bool cacheit_lockit)
{
        struct bpf_sk_storage_data *sdata;
        struct bpf_sk_storage_elem *selem;

        /* Fast path (cache hit) */
        sdata = rcu_dereference(sk_storage->cache[smap->cache_idx]);
        /* The slot may be shared with other maps; verify the owner */
        if (sdata && rcu_access_pointer(sdata->smap) == smap)
                return sdata;

        /* Slow path (cache miss) */
        hlist_for_each_entry_rcu(selem, &sk_storage->list, snode)
                if (rcu_access_pointer(SDATA(selem)->smap) == smap)
                        break;

        /* selem is NULL here when the walk finished without a match */
        if (!selem)
                return NULL;

        sdata = SDATA(selem);
        if (cacheit_lockit) {
                /* spinlock is needed to avoid racing with the
                 * parallel delete.  Otherwise, publishing an already
                 * deleted sdata to the cache will become a use-after-free
                 * problem in the next __sk_storage_lookup().
                 */
                raw_spin_lock_bh(&sk_storage->lock);
                if (selem_linked_to_sk(selem))
                        rcu_assign_pointer(sk_storage->cache[smap->cache_idx],
                                           sdata);
                raw_spin_unlock_bh(&sk_storage->lock);
        }

        return sdata;
}
 298
 299static struct bpf_sk_storage_data *
 300sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
 301{
 302        struct bpf_sk_storage *sk_storage;
 303        struct bpf_sk_storage_map *smap;
 304
 305        sk_storage = rcu_dereference(sk->sk_bpf_storage);
 306        if (!sk_storage)
 307                return NULL;
 308
 309        smap = (struct bpf_sk_storage_map *)map;
 310        return __sk_storage_lookup(sk_storage, smap, cacheit_lockit);
 311}
 312
 313static int check_flags(const struct bpf_sk_storage_data *old_sdata,
 314                       u64 map_flags)
 315{
 316        if (old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_NOEXIST)
 317                /* elem already exists */
 318                return -EEXIST;
 319
 320        if (!old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_EXIST)
 321                /* elem doesn't exist, cannot update it */
 322                return -ENOENT;
 323
 324        return 0;
 325}
 326
/* Create @sk's bpf_sk_storage with @first_selem as its only elem, link
 * @first_selem into @smap, and publish the storage in
 * sk->sk_bpf_storage with cmpxchg().
 *
 * Only the sk_storage itself is charged/allocated here.  @first_selem
 * stays owned by the caller, which frees it (and drops its charge)
 * when this returns an error.
 *
 * Return: 0 on success, -ENOMEM if charging or allocating the
 * sk_storage fails, or -EAGAIN if another sk_storage was published
 * for @sk concurrently.
 */
static int sk_storage_alloc(struct sock *sk,
                            struct bpf_sk_storage_map *smap,
                            struct bpf_sk_storage_elem *first_selem)
{
        struct bpf_sk_storage *prev_sk_storage, *sk_storage;
        int err;

        err = omem_charge(sk, sizeof(*sk_storage));
        if (err)
                return err;

        sk_storage = kzalloc(sizeof(*sk_storage), GFP_ATOMIC | __GFP_NOWARN);
        if (!sk_storage) {
                err = -ENOMEM;
                goto uncharge;
        }
        INIT_HLIST_HEAD(&sk_storage->list);
        raw_spin_lock_init(&sk_storage->lock);
        sk_storage->sk = sk;

        /* sk_storage is not published yet, so no lock is taken here */
        __selem_link_sk(sk_storage, first_selem);
        selem_link_map(smap, first_selem);
        /* Publish sk_storage to sk.  sk->sk_lock cannot be acquired.
         * Hence, atomic ops is used to set sk->sk_bpf_storage
         * from NULL to the newly allocated sk_storage ptr.
         *
         * From now on, the sk->sk_bpf_storage pointer is protected
         * by the sk_storage->lock.  Hence,  when freeing
         * the sk->sk_bpf_storage, the sk_storage->lock must
         * be held before setting sk->sk_bpf_storage to NULL.
         */
        prev_sk_storage = cmpxchg((struct bpf_sk_storage **)&sk->sk_bpf_storage,
                                  NULL, sk_storage);
        if (unlikely(prev_sk_storage)) {
                /* Lost the race: undo the map link, then free our copy */
                selem_unlink_map(first_selem);
                err = -EAGAIN;
                goto uncharge;

                /* Note that even first_selem was linked to smap's
                 * bucket->list, first_selem can be freed immediately
                 * (instead of kfree_rcu) because
                 * bpf_sk_storage_map_free() does a
                 * synchronize_rcu() before walking the bucket->list.
                 * Hence, no one is accessing selem from the
                 * bucket->list under rcu_read_lock().
                 */
        }

        return 0;

uncharge:
        /* kfree(NULL) is a no-op for the allocation-failure path */
        kfree(sk_storage);
        atomic_sub(sizeof(*sk_storage), &sk->sk_omem_alloc);
        return err;
}
 382
/* Create or update @map's elem for @sk with @value.
 *
 * sk cannot be going away because it is linking new elem
 * to sk->sk_bpf_storage. (i.e. sk->sk_refcnt cannot be 0).
 * Otherwise, it will become a leak (and other memory issues
 * during map destruction).
 *
 * @map_flags: BPF_EXIST, BPF_NOEXIST or neither, optionally OR'ed with
 *             BPF_F_LOCK (only allowed when the value has a spin lock).
 *
 * With BPF_F_LOCK and an existing elem, the value is copied in place
 * with copy_map_value_locked().  Otherwise, a new elem is allocated and
 * linked to the map and to the sk_storage, and the old elem (if any) is
 * unlinked.
 *
 * Return: the created/updated data, or an ERR_PTR(): -EINVAL for bad
 * flags, -EEXIST/-ENOENT from check_flags(), -ENOMEM, or -EAGAIN when
 * racing with the sk_storage being created or torn down.
 */
static struct bpf_sk_storage_data *sk_storage_update(struct sock *sk,
                                                     struct bpf_map *map,
                                                     void *value,
                                                     u64 map_flags)
{
        struct bpf_sk_storage_data *old_sdata = NULL;
        struct bpf_sk_storage_elem *selem;
        struct bpf_sk_storage *sk_storage;
        struct bpf_sk_storage_map *smap;
        int err;

        /* BPF_EXIST and BPF_NOEXIST cannot be both set */
        if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
            /* BPF_F_LOCK can only be used in a value with spin_lock */
            unlikely((map_flags & BPF_F_LOCK) && !map_value_has_spin_lock(map)))
                return ERR_PTR(-EINVAL);

        smap = (struct bpf_sk_storage_map *)map;
        sk_storage = rcu_dereference(sk->sk_bpf_storage);
        if (!sk_storage || hlist_empty(&sk_storage->list)) {
                /* Very first elem for this sk */
                err = check_flags(NULL, map_flags);
                if (err)
                        return ERR_PTR(err);

                selem = selem_alloc(smap, sk, value, true);
                if (!selem)
                        return ERR_PTR(-ENOMEM);

                err = sk_storage_alloc(sk, smap, selem);
                if (err) {
                        /* sk_storage_alloc() left selem unlinked from the
                         * map; free it and drop selem_alloc()'s charge.
                         */
                        kfree(selem);
                        atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
                        return ERR_PTR(err);
                }

                return SDATA(selem);
        }

        if ((map_flags & BPF_F_LOCK) && !(map_flags & BPF_NOEXIST)) {
                /* Hoping to find an old_sdata to do inline update
                 * such that it can avoid taking the sk_storage->lock
                 * and changing the lists.
                 */
                old_sdata = __sk_storage_lookup(sk_storage, smap, false);
                err = check_flags(old_sdata, map_flags);
                if (err)
                        return ERR_PTR(err);
                if (old_sdata && selem_linked_to_sk(SELEM(old_sdata))) {
                        copy_map_value_locked(map, old_sdata->data,
                                              value, false);
                        return old_sdata;
                }
        }

        raw_spin_lock_bh(&sk_storage->lock);

        /* Recheck sk_storage->list under sk_storage->lock */
        if (unlikely(hlist_empty(&sk_storage->list))) {
                /* A parallel del is happening and sk_storage is going
                 * away.  It has just been checked before, so very
                 * unlikely.  Return instead of retry to keep things
                 * simple.
                 */
                err = -EAGAIN;
                goto unlock_err;
        }

        old_sdata = __sk_storage_lookup(sk_storage, smap, false);
        err = check_flags(old_sdata, map_flags);
        if (err)
                goto unlock_err;

        if (old_sdata && (map_flags & BPF_F_LOCK)) {
                copy_map_value_locked(map, old_sdata->data, value, false);
                selem = SELEM(old_sdata);
                goto unlock;
        }

        /* sk_storage->lock is held.  Hence, we are sure
         * we can unlink and uncharge the old_sdata successfully
         * later.  Hence, instead of charging the new selem now
         * and then uncharge the old selem later (which may cause
         * a potential but unnecessary charge failure),  avoid taking
         * a charge at all here (the "!old_sdata" check) and the
         * old_sdata will not be uncharged later during __selem_unlink_sk().
         */
        selem = selem_alloc(smap, sk, value, !old_sdata);
        if (!selem) {
                err = -ENOMEM;
                goto unlock_err;
        }

        /* First, link the new selem to the map */
        selem_link_map(smap, selem);

        /* Second, link (and publish) the new selem to sk_storage */
        __selem_link_sk(sk_storage, selem);

        /* Third, remove old selem, SELEM(old_sdata) */
        if (old_sdata) {
                selem_unlink_map(SELEM(old_sdata));
                __selem_unlink_sk(sk_storage, SELEM(old_sdata), false);
        }

unlock:
        raw_spin_unlock_bh(&sk_storage->lock);
        return SDATA(selem);

unlock_err:
        raw_spin_unlock_bh(&sk_storage->lock);
        return ERR_PTR(err);
}
 501
 502static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
 503{
 504        struct bpf_sk_storage_data *sdata;
 505
 506        sdata = sk_storage_lookup(sk, map, false);
 507        if (!sdata)
 508                return -ENOENT;
 509
 510        selem_unlink(SELEM(sdata));
 511
 512        return 0;
 513}
 514
/* Unlink every elem from @sk's storage (and from each elem's map), then
 * free the storage itself with kfree_rcu().  Does nothing if @sk has
 * no storage.
 *
 * Called by __sk_destruct(), and by the caller of bpf_sk_storage_clone()
 * when cloning fails (bpf_sk_storage_clone() leaves that cleanup to its
 * caller).
 */
void bpf_sk_storage_free(struct sock *sk)
{
        struct bpf_sk_storage_elem *selem;
        struct bpf_sk_storage *sk_storage;
        bool free_sk_storage = false;
        struct hlist_node *n;

        rcu_read_lock();
        sk_storage = rcu_dereference(sk->sk_bpf_storage);
        if (!sk_storage) {
                rcu_read_unlock();
                return;
        }

        /* Neither the bpf_prog nor the bpf-map's syscall
         * could be modifying the sk_storage->list now.
         * Thus, no elem can be added-to or deleted-from the
         * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
         *
         * It is racing with bpf_sk_storage_map_free() alone
         * when unlinking elem from the sk_storage->list and
         * the map's bucket->list.
         */
        raw_spin_lock_bh(&sk_storage->lock);
        hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
                /* Always unlink from map before unlinking from
                 * sk_storage.
                 */
                selem_unlink_map(selem);
                /* Only the final (last-elem) unlink returns true */
                free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
        }
        raw_spin_unlock_bh(&sk_storage->lock);
        rcu_read_unlock();

        if (free_sk_storage)
                kfree_rcu(sk_storage, rcu);
}
 553
/* Free a sk storage map: unlink every remaining elem of this map from
 * both its bucket and its owning sk_storage, then free the bucket array
 * and the map.  Blocks in synchronize_rcu() (twice).
 */
static void bpf_sk_storage_map_free(struct bpf_map *map)
{
        struct bpf_sk_storage_elem *selem;
        struct bpf_sk_storage_map *smap;
        struct bucket *b;
        unsigned int i;

        smap = (struct bpf_sk_storage_map *)map;

        /* Note that this map might be concurrently cloned from
         * bpf_sk_storage_clone. Wait for any existing bpf_sk_storage_clone
         * RCU read section to finish before proceeding. New RCU
         * read sections should be prevented via bpf_map_inc_not_zero.
         */
        synchronize_rcu();

        /* bpf prog and the userspace can no longer access this map
         * now.  No new selem (of this map) can be added
         * to the sk->sk_bpf_storage or to the map bucket's list.
         *
         * The elem of this map can be cleaned up here
         * or
         * by bpf_sk_storage_free() during __sk_destruct().
         */
        for (i = 0; i < (1U << smap->bucket_log); i++) {
                b = &smap->buckets[i];

                rcu_read_lock();
                /* No one is adding to b->list now.  Keep taking the
                 * first elem; selem_unlink() removes it from b->list,
                 * so the loop ends once the bucket is empty.
                 */
                while ((selem = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(&b->list)),
                                                 struct bpf_sk_storage_elem,
                                                 map_node))) {
                        selem_unlink(selem);
                        cond_resched_rcu();
                }
                rcu_read_unlock();
        }

        /* bpf_sk_storage_free() may still need to access the map.
         * e.g. bpf_sk_storage_free() has unlinked selem from the map
         * which then made the above while((selem = ...)) loop
         * exited immediately.
         *
         * However, the bpf_sk_storage_free() still needs to access
         * the smap->elem_size to do the uncharging in
         * __selem_unlink_sk().
         *
         * Hence, wait another rcu grace period for the
         * bpf_sk_storage_free() to finish.
         */
        synchronize_rcu();

        kvfree(smap->buckets);
        kfree(map);
}
 609
/* Largest value_size accepted by bpf_sk_storage_map_alloc_check().
 *
 * An elem is sizeof(struct bpf_sk_storage_elem) + value_size bytes and
 * its size is kept in the u16 smap->elem_size, so the total must fit in
 * U16_MAX.  It must also be kmalloc()-able (KMALLOC_MAX_SIZE).
 * NOTE(review): the reason MAX_BPF_STACK is also subtracted is not
 * visible here -- confirm.
 *
 * U16_MAX is much more than enough for sk local storage
 * considering a tcp_sock is ~2k.
 */
#define MAX_VALUE_SIZE                                                  \
        min_t(u32,                                                      \
              (KMALLOC_MAX_SIZE - MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem)), \
              (U16_MAX - sizeof(struct bpf_sk_storage_elem)))
 617
 618static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
 619{
 620        if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
 621            !(attr->map_flags & BPF_F_NO_PREALLOC) ||
 622            attr->max_entries ||
 623            attr->key_size != sizeof(int) || !attr->value_size ||
 624            /* Enforce BTF for userspace sk dumping */
 625            !attr->btf_key_type_id || !attr->btf_value_type_id)
 626                return -EINVAL;
 627
 628        if (!capable(CAP_SYS_ADMIN))
 629                return -EPERM;
 630
 631        if (attr->value_size > MAX_VALUE_SIZE)
 632                return -E2BIG;
 633
 634        return 0;
 635}
 636
 637static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
 638{
 639        struct bpf_sk_storage_map *smap;
 640        unsigned int i;
 641        u32 nbuckets;
 642        u64 cost;
 643        int ret;
 644
 645        smap = kzalloc(sizeof(*smap), GFP_USER | __GFP_NOWARN);
 646        if (!smap)
 647                return ERR_PTR(-ENOMEM);
 648        bpf_map_init_from_attr(&smap->map, attr);
 649
 650        nbuckets = roundup_pow_of_two(num_possible_cpus());
 651        /* Use at least 2 buckets, select_bucket() is undefined behavior with 1 bucket */
 652        nbuckets = max_t(u32, 2, nbuckets);
 653        smap->bucket_log = ilog2(nbuckets);
 654        cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
 655
 656        ret = bpf_map_charge_init(&smap->map.memory, cost);
 657        if (ret < 0) {
 658                kfree(smap);
 659                return ERR_PTR(ret);
 660        }
 661
 662        smap->buckets = kvcalloc(sizeof(*smap->buckets), nbuckets,
 663                                 GFP_USER | __GFP_NOWARN);
 664        if (!smap->buckets) {
 665                bpf_map_charge_finish(&smap->map.memory);
 666                kfree(smap);
 667                return ERR_PTR(-ENOMEM);
 668        }
 669
 670        for (i = 0; i < nbuckets; i++) {
 671                INIT_HLIST_HEAD(&smap->buckets[i].list);
 672                raw_spin_lock_init(&smap->buckets[i].lock);
 673        }
 674
 675        smap->elem_size = sizeof(struct bpf_sk_storage_elem) + attr->value_size;
 676        smap->cache_idx = (unsigned int)atomic_inc_return(&cache_idx) %
 677                BPF_SK_STORAGE_CACHE_SIZE;
 678
 679        return &smap->map;
 680}
 681
 682static int notsupp_get_next_key(struct bpf_map *map, void *key,
 683                                void *next_key)
 684{
 685        return -ENOTSUPP;
 686}
 687
 688static int bpf_sk_storage_map_check_btf(const struct bpf_map *map,
 689                                        const struct btf *btf,
 690                                        const struct btf_type *key_type,
 691                                        const struct btf_type *value_type)
 692{
 693        u32 int_data;
 694
 695        if (BTF_INFO_KIND(key_type->info) != BTF_KIND_INT)
 696                return -EINVAL;
 697
 698        int_data = *(u32 *)(key_type + 1);
 699        if (BTF_INT_BITS(int_data) != 32 || BTF_INT_OFFSET(int_data))
 700                return -EINVAL;
 701
 702        return 0;
 703}
 704
 705static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
 706{
 707        struct bpf_sk_storage_data *sdata;
 708        struct socket *sock;
 709        int fd, err;
 710
 711        fd = *(int *)key;
 712        sock = sockfd_lookup(fd, &err);
 713        if (sock) {
 714                sdata = sk_storage_lookup(sock->sk, map, true);
 715                sockfd_put(sock);
 716                return sdata ? sdata->data : NULL;
 717        }
 718
 719        return ERR_PTR(err);
 720}
 721
 722static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
 723                                         void *value, u64 map_flags)
 724{
 725        struct bpf_sk_storage_data *sdata;
 726        struct socket *sock;
 727        int fd, err;
 728
 729        fd = *(int *)key;
 730        sock = sockfd_lookup(fd, &err);
 731        if (sock) {
 732                sdata = sk_storage_update(sock->sk, map, value, map_flags);
 733                sockfd_put(sock);
 734                return PTR_ERR_OR_ZERO(sdata);
 735        }
 736
 737        return err;
 738}
 739
 740static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
 741{
 742        struct socket *sock;
 743        int fd, err;
 744
 745        fd = *(int *)key;
 746        sock = sockfd_lookup(fd, &err);
 747        if (sock) {
 748                err = sk_storage_delete(sock->sk, map);
 749                sockfd_put(sock);
 750                return err;
 751        }
 752
 753        return err;
 754}
 755
 756static struct bpf_sk_storage_elem *
 757bpf_sk_storage_clone_elem(struct sock *newsk,
 758                          struct bpf_sk_storage_map *smap,
 759                          struct bpf_sk_storage_elem *selem)
 760{
 761        struct bpf_sk_storage_elem *copy_selem;
 762
 763        copy_selem = selem_alloc(smap, newsk, NULL, true);
 764        if (!copy_selem)
 765                return NULL;
 766
 767        if (map_value_has_spin_lock(&smap->map))
 768                copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
 769                                      SDATA(selem)->data, true);
 770        else
 771                copy_map_value(&smap->map, SDATA(copy_selem)->data,
 772                               SDATA(selem)->data);
 773
 774        return copy_selem;
 775}
 776
/* Copy @sk's elems that belong to BPF_F_CLONE maps into @newsk.
 *
 * newsk->sk_bpf_storage is first reset to NULL.  Elems of maps without
 * BPF_F_CLONE are skipped, as are elems of maps whose refcount can no
 * longer be taken (bpf_map_inc_not_zero() fails).
 *
 * Return: 0 on success (including when nothing is cloned), or a
 * negative errno if allocating an elem or the new sk_storage fails.
 * On error, @newsk may hold partially cloned storage.  The caller is
 * responsible for calling bpf_sk_storage_free() (presumably on @newsk).
 */
int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
{
        struct bpf_sk_storage *new_sk_storage = NULL;
        struct bpf_sk_storage *sk_storage;
        struct bpf_sk_storage_elem *selem;
        int ret = 0;

        RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);

        rcu_read_lock();
        sk_storage = rcu_dereference(sk->sk_bpf_storage);

        if (!sk_storage || hlist_empty(&sk_storage->list))
                goto out;

        hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
                struct bpf_sk_storage_elem *copy_selem;
                struct bpf_sk_storage_map *smap;
                struct bpf_map *map;

                smap = rcu_dereference(SDATA(selem)->smap);
                if (!(smap->map.map_flags & BPF_F_CLONE))
                        continue;

                /* Note that for lockless listeners adding new element
                 * here can race with cleanup in bpf_sk_storage_map_free.
                 * Try to grab map refcnt to make sure that it's still
                 * alive and prevent concurrent removal.
                 */
                map = bpf_map_inc_not_zero(&smap->map);
                if (IS_ERR(map))
                        continue;

                copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
                if (!copy_selem) {
                        ret = -ENOMEM;
                        bpf_map_put(map);
                        goto out;
                }

                if (new_sk_storage) {
                        selem_link_map(smap, copy_selem);
                        __selem_link_sk(new_sk_storage, copy_selem);
                } else {
                        /* First cloned elem: creates newsk's sk_storage */
                        ret = sk_storage_alloc(newsk, smap, copy_selem);
                        if (ret) {
                                kfree(copy_selem);
                                atomic_sub(smap->elem_size,
                                           &newsk->sk_omem_alloc);
                                bpf_map_put(map);
                                goto out;
                        }

                        new_sk_storage = rcu_dereference(copy_selem->sk_storage);
                }
                bpf_map_put(map);
        }

out:
        rcu_read_unlock();

        /* In case of an error, don't free anything explicitly here, the
         * caller is responsible to call bpf_sk_storage_free.
         */

        return ret;
}
 844
 845BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
 846           void *, value, u64, flags)
 847{
 848        struct bpf_sk_storage_data *sdata;
 849
 850        if (flags > BPF_SK_STORAGE_GET_F_CREATE)
 851                return (unsigned long)NULL;
 852
 853        sdata = sk_storage_lookup(sk, map, true);
 854        if (sdata)
 855                return (unsigned long)sdata->data;
 856
 857        if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
 858            /* Cannot add new elem to a going away sk.
 859             * Otherwise, the new elem may become a leak
 860             * (and also other memory issues during map
 861             *  destruction).
 862             */
 863            refcount_inc_not_zero(&sk->sk_refcnt)) {
 864                sdata = sk_storage_update(sk, map, value, BPF_NOEXIST);
 865                /* sk must be a fullsock (guaranteed by verifier),
 866                 * so sock_gen_put() is unnecessary.
 867                 */
 868                sock_put(sk);
 869                return IS_ERR(sdata) ?
 870                        (unsigned long)NULL : (unsigned long)sdata->data;
 871        }
 872
 873        return (unsigned long)NULL;
 874}
 875
 876BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
 877{
 878        if (refcount_inc_not_zero(&sk->sk_refcnt)) {
 879                int err;
 880
 881                err = sk_storage_delete(sk, map);
 882                sock_put(sk);
 883                return err;
 884        }
 885
 886        return -ENOENT;
 887}
 888
 889const struct bpf_map_ops sk_storage_map_ops = {
 890        .map_alloc_check = bpf_sk_storage_map_alloc_check,
 891        .map_alloc = bpf_sk_storage_map_alloc,
 892        .map_free = bpf_sk_storage_map_free,
 893        .map_get_next_key = notsupp_get_next_key,
 894        .map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
 895        .map_update_elem = bpf_fd_sk_storage_update_elem,
 896        .map_delete_elem = bpf_fd_sk_storage_delete_elem,
 897        .map_check_btf = bpf_sk_storage_map_check_btf,
 898};
 899
 900const struct bpf_func_proto bpf_sk_storage_get_proto = {
 901        .func           = bpf_sk_storage_get,
 902        .gpl_only       = false,
 903        .ret_type       = RET_PTR_TO_MAP_VALUE_OR_NULL,
 904        .arg1_type      = ARG_CONST_MAP_PTR,
 905        .arg2_type      = ARG_PTR_TO_SOCKET,
 906        .arg3_type      = ARG_PTR_TO_MAP_VALUE_OR_NULL,
 907        .arg4_type      = ARG_ANYTHING,
 908};
 909
 910const struct bpf_func_proto bpf_sk_storage_delete_proto = {
 911        .func           = bpf_sk_storage_delete,
 912        .gpl_only       = false,
 913        .ret_type       = RET_INTEGER,
 914        .arg1_type      = ARG_CONST_MAP_PTR,
 915        .arg2_type      = ARG_PTR_TO_SOCKET,
 916};
 917
 918struct bpf_sk_storage_diag {
 919        u32 nr_maps;
 920        struct bpf_map *maps[];
 921};
 922
 923/* The reply will be like:
 924 * INET_DIAG_BPF_SK_STORAGES (nla_nest)
 925 *      SK_DIAG_BPF_STORAGE (nla_nest)
 926 *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
 927 *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
 928 *      SK_DIAG_BPF_STORAGE (nla_nest)
 929 *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
 930 *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
 931 *      ....
 932 */
 933static int nla_value_size(u32 value_size)
 934{
 935        /* SK_DIAG_BPF_STORAGE (nla_nest)
 936         *      SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
 937         *      SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
 938         */
 939        return nla_total_size(0) + nla_total_size(sizeof(u32)) +
 940                nla_total_size_64bit(value_size);
 941}
 942
 943void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
 944{
 945        u32 i;
 946
 947        if (!diag)
 948                return;
 949
 950        for (i = 0; i < diag->nr_maps; i++)
 951                bpf_map_put(diag->maps[i]);
 952
 953        kfree(diag);
 954}
 955EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
 956
 957static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
 958                           const struct bpf_map *map)
 959{
 960        u32 i;
 961
 962        for (i = 0; i < diag->nr_maps; i++) {
 963                if (diag->maps[i] == map)
 964                        return true;
 965        }
 966
 967        return false;
 968}
 969
 970struct bpf_sk_storage_diag *
 971bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
 972{
 973        struct bpf_sk_storage_diag *diag;
 974        struct nlattr *nla;
 975        u32 nr_maps = 0;
 976        int rem, err;
 977
 978        /* bpf_sk_storage_map is currently limited to CAP_SYS_ADMIN as
 979         * the map_alloc_check() side also does.
 980         */
 981        if (!capable(CAP_SYS_ADMIN))
 982                return ERR_PTR(-EPERM);
 983
 984        nla_for_each_nested(nla, nla_stgs, rem) {
 985                if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
 986                        nr_maps++;
 987        }
 988
 989        diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
 990                       GFP_KERNEL);
 991        if (!diag)
 992                return ERR_PTR(-ENOMEM);
 993
 994        nla_for_each_nested(nla, nla_stgs, rem) {
 995                struct bpf_map *map;
 996                int map_fd;
 997
 998                if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
 999                        continue;
1000
1001                map_fd = nla_get_u32(nla);
1002                map = bpf_map_get(map_fd);
1003                if (IS_ERR(map)) {
1004                        err = PTR_ERR(map);
1005                        goto err_free;
1006                }
1007                if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
1008                        bpf_map_put(map);
1009                        err = -EINVAL;
1010                        goto err_free;
1011                }
1012                if (diag_check_dup(diag, map)) {
1013                        bpf_map_put(map);
1014                        err = -EEXIST;
1015                        goto err_free;
1016                }
1017                diag->maps[diag->nr_maps++] = map;
1018        }
1019
1020        return diag;
1021
1022err_free:
1023        bpf_sk_storage_diag_free(diag);
1024        return ERR_PTR(err);
1025}
1026EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
1027
1028static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
1029{
1030        struct nlattr *nla_stg, *nla_value;
1031        struct bpf_sk_storage_map *smap;
1032
1033        /* It cannot exceed max nlattr's payload */
1034        BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
1035
1036        nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
1037        if (!nla_stg)
1038                return -EMSGSIZE;
1039
1040        smap = rcu_dereference(sdata->smap);
1041        if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
1042                goto errout;
1043
1044        nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
1045                                      smap->map.value_size,
1046                                      SK_DIAG_BPF_STORAGE_PAD);
1047        if (!nla_value)
1048                goto errout;
1049
1050        if (map_value_has_spin_lock(&smap->map))
1051                copy_map_value_locked(&smap->map, nla_data(nla_value),
1052                                      sdata->data, true);
1053        else
1054                copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
1055
1056        nla_nest_end(skb, nla_stg);
1057        return 0;
1058
1059errout:
1060        nla_nest_cancel(skb, nla_stg);
1061        return -EMSGSIZE;
1062}
1063
1064static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
1065                                       int stg_array_type,
1066                                       unsigned int *res_diag_size)
1067{
1068        /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1069        unsigned int diag_size = nla_total_size(0);
1070        struct bpf_sk_storage *sk_storage;
1071        struct bpf_sk_storage_elem *selem;
1072        struct bpf_sk_storage_map *smap;
1073        struct nlattr *nla_stgs;
1074        unsigned int saved_len;
1075        int err = 0;
1076
1077        rcu_read_lock();
1078
1079        sk_storage = rcu_dereference(sk->sk_bpf_storage);
1080        if (!sk_storage || hlist_empty(&sk_storage->list)) {
1081                rcu_read_unlock();
1082                return 0;
1083        }
1084
1085        nla_stgs = nla_nest_start(skb, stg_array_type);
1086        if (!nla_stgs)
1087                /* Continue to learn diag_size */
1088                err = -EMSGSIZE;
1089
1090        saved_len = skb->len;
1091        hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
1092                smap = rcu_dereference(SDATA(selem)->smap);
1093                diag_size += nla_value_size(smap->map.value_size);
1094
1095                if (nla_stgs && diag_get(SDATA(selem), skb))
1096                        /* Continue to learn diag_size */
1097                        err = -EMSGSIZE;
1098        }
1099
1100        rcu_read_unlock();
1101
1102        if (nla_stgs) {
1103                if (saved_len == skb->len)
1104                        nla_nest_cancel(skb, nla_stgs);
1105                else
1106                        nla_nest_end(skb, nla_stgs);
1107        }
1108
1109        if (diag_size == nla_total_size(0)) {
1110                *res_diag_size = 0;
1111                return 0;
1112        }
1113
1114        *res_diag_size = diag_size;
1115        return err;
1116}
1117
1118int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
1119                            struct sock *sk, struct sk_buff *skb,
1120                            int stg_array_type,
1121                            unsigned int *res_diag_size)
1122{
1123        /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1124        unsigned int diag_size = nla_total_size(0);
1125        struct bpf_sk_storage *sk_storage;
1126        struct bpf_sk_storage_data *sdata;
1127        struct nlattr *nla_stgs;
1128        unsigned int saved_len;
1129        int err = 0;
1130        u32 i;
1131
1132        *res_diag_size = 0;
1133
1134        /* No map has been specified.  Dump all. */
1135        if (!diag->nr_maps)
1136                return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
1137                                                   res_diag_size);
1138
1139        rcu_read_lock();
1140        sk_storage = rcu_dereference(sk->sk_bpf_storage);
1141        if (!sk_storage || hlist_empty(&sk_storage->list)) {
1142                rcu_read_unlock();
1143                return 0;
1144        }
1145
1146        nla_stgs = nla_nest_start(skb, stg_array_type);
1147        if (!nla_stgs)
1148                /* Continue to learn diag_size */
1149                err = -EMSGSIZE;
1150
1151        saved_len = skb->len;
1152        for (i = 0; i < diag->nr_maps; i++) {
1153                sdata = __sk_storage_lookup(sk_storage,
1154                                (struct bpf_sk_storage_map *)diag->maps[i],
1155                                false);
1156
1157                if (!sdata)
1158                        continue;
1159
1160                diag_size += nla_value_size(diag->maps[i]->value_size);
1161
1162                if (nla_stgs && diag_get(sdata, skb))
1163                        /* Continue to learn diag_size */
1164                        err = -EMSGSIZE;
1165        }
1166        rcu_read_unlock();
1167
1168        if (nla_stgs) {
1169                if (saved_len == skb->len)
1170                        nla_nest_cancel(skb, nla_stgs);
1171                else
1172                        nla_nest_end(skb, nla_stgs);
1173        }
1174
1175        if (diag_size == nla_total_size(0)) {
1176                *res_diag_size = 0;
1177                return 0;
1178        }
1179
1180        *res_diag_size = diag_size;
1181        return err;
1182}
1183EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);
1184