linux/drivers/net/wireguard/ratelimiter.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
   4 */
   5
   6#include "ratelimiter.h"
   7#include <linux/siphash.h>
   8#include <linux/mm.h>
   9#include <linux/slab.h>
  10#include <net/ip.h>
  11
  12static struct kmem_cache *entry_cache;
  13static hsiphash_key_t key;
  14static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
  15static DEFINE_MUTEX(init_lock);
  16static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
  17static atomic_t total_entries = ATOMIC_INIT(0);
  18static unsigned int max_entries, table_size;
  19static void wg_ratelimiter_gc_entries(struct work_struct *);
  20static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
  21static struct hlist_head *table_v4;
  22#if IS_ENABLED(CONFIG_IPV6)
  23static struct hlist_head *table_v6;
  24#endif
  25
  26struct ratelimiter_entry {
  27        u64 last_time_ns, tokens, ip;
  28        void *net;
  29        spinlock_t lock;
  30        struct hlist_node hash;
  31        struct rcu_head rcu;
  32};
  33
  34enum {
  35        PACKETS_PER_SECOND = 20,
  36        PACKETS_BURSTABLE = 5,
  37        PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
  38        TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
  39};
  40
  41static void entry_free(struct rcu_head *rcu)
  42{
  43        kmem_cache_free(entry_cache,
  44                        container_of(rcu, struct ratelimiter_entry, rcu));
  45        atomic_dec(&total_entries);
  46}
  47
  48static void entry_uninit(struct ratelimiter_entry *entry)
  49{
  50        hlist_del_rcu(&entry->hash);
  51        call_rcu(&entry->rcu, entry_free);
  52}
  53
  54/* Calling this function with a NULL work uninits all entries. */
  55static void wg_ratelimiter_gc_entries(struct work_struct *work)
  56{
  57        const u64 now = ktime_get_coarse_boottime_ns();
  58        struct ratelimiter_entry *entry;
  59        struct hlist_node *temp;
  60        unsigned int i;
  61
  62        for (i = 0; i < table_size; ++i) {
  63                spin_lock(&table_lock);
  64                hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
  65                        if (unlikely(!work) ||
  66                            now - entry->last_time_ns > NSEC_PER_SEC)
  67                                entry_uninit(entry);
  68                }
  69#if IS_ENABLED(CONFIG_IPV6)
  70                hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
  71                        if (unlikely(!work) ||
  72                            now - entry->last_time_ns > NSEC_PER_SEC)
  73                                entry_uninit(entry);
  74                }
  75#endif
  76                spin_unlock(&table_lock);
  77                if (likely(work))
  78                        cond_resched();
  79        }
  80        if (likely(work))
  81                queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
  82}
  83
  84bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
  85{
  86        /* We only take the bottom half of the net pointer, so that we can hash
  87         * 3 words in the end. This way, siphash's len param fits into the final
  88         * u32, and we don't incur an extra round.
  89         */
  90        const u32 net_word = (unsigned long)net;
  91        struct ratelimiter_entry *entry;
  92        struct hlist_head *bucket;
  93        u64 ip;
  94
  95        if (skb->protocol == htons(ETH_P_IP)) {
  96                ip = (u64 __force)ip_hdr(skb)->saddr;
  97                bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
  98                                   (table_size - 1)];
  99        }
 100#if IS_ENABLED(CONFIG_IPV6)
 101        else if (skb->protocol == htons(ETH_P_IPV6)) {
 102                /* Only use 64 bits, so as to ratelimit the whole /64. */
 103                memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
 104                bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
 105                                   (table_size - 1)];
 106        }
 107#endif
 108        else
 109                return false;
 110        rcu_read_lock();
 111        hlist_for_each_entry_rcu(entry, bucket, hash) {
 112                if (entry->net == net && entry->ip == ip) {
 113                        u64 now, tokens;
 114                        bool ret;
 115                        /* Quasi-inspired by nft_limit.c, but this is actually a
 116                         * slightly different algorithm. Namely, we incorporate
 117                         * the burst as part of the maximum tokens, rather than
 118                         * as part of the rate.
 119                         */
 120                        spin_lock(&entry->lock);
 121                        now = ktime_get_coarse_boottime_ns();
 122                        tokens = min_t(u64, TOKEN_MAX,
 123                                       entry->tokens + now -
 124                                               entry->last_time_ns);
 125                        entry->last_time_ns = now;
 126                        ret = tokens >= PACKET_COST;
 127                        entry->tokens = ret ? tokens - PACKET_COST : tokens;
 128                        spin_unlock(&entry->lock);
 129                        rcu_read_unlock();
 130                        return ret;
 131                }
 132        }
 133        rcu_read_unlock();
 134
 135        if (atomic_inc_return(&total_entries) > max_entries)
 136                goto err_oom;
 137
 138        entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
 139        if (unlikely(!entry))
 140                goto err_oom;
 141
 142        entry->net = net;
 143        entry->ip = ip;
 144        INIT_HLIST_NODE(&entry->hash);
 145        spin_lock_init(&entry->lock);
 146        entry->last_time_ns = ktime_get_coarse_boottime_ns();
 147        entry->tokens = TOKEN_MAX - PACKET_COST;
 148        spin_lock(&table_lock);
 149        hlist_add_head_rcu(&entry->hash, bucket);
 150        spin_unlock(&table_lock);
 151        return true;
 152
 153err_oom:
 154        atomic_dec(&total_entries);
 155        return false;
 156}
 157
 158int wg_ratelimiter_init(void)
 159{
 160        mutex_lock(&init_lock);
 161        if (++init_refcnt != 1)
 162                goto out;
 163
 164        entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
 165        if (!entry_cache)
 166                goto err;
 167
 168        /* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
 169         * but what it shares in common is that it uses a massive hashtable. So,
 170         * we borrow their wisdom about good table sizes on different systems
 171         * dependent on RAM. This calculation here comes from there.
 172         */
 173        table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
 174                max_t(unsigned long, 16, roundup_pow_of_two(
 175                        (totalram_pages() << PAGE_SHIFT) /
 176                        (1U << 14) / sizeof(struct hlist_head)));
 177        max_entries = table_size * 8;
 178
 179        table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
 180        if (unlikely(!table_v4))
 181                goto err_kmemcache;
 182
 183#if IS_ENABLED(CONFIG_IPV6)
 184        table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
 185        if (unlikely(!table_v6)) {
 186                kvfree(table_v4);
 187                goto err_kmemcache;
 188        }
 189#endif
 190
 191        queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
 192        get_random_bytes(&key, sizeof(key));
 193out:
 194        mutex_unlock(&init_lock);
 195        return 0;
 196
 197err_kmemcache:
 198        kmem_cache_destroy(entry_cache);
 199err:
 200        --init_refcnt;
 201        mutex_unlock(&init_lock);
 202        return -ENOMEM;
 203}
 204
 205void wg_ratelimiter_uninit(void)
 206{
 207        mutex_lock(&init_lock);
 208        if (!init_refcnt || --init_refcnt)
 209                goto out;
 210
 211        cancel_delayed_work_sync(&gc_work);
 212        wg_ratelimiter_gc_entries(NULL);
 213        rcu_barrier();
 214        kvfree(table_v4);
 215#if IS_ENABLED(CONFIG_IPV6)
 216        kvfree(table_v6);
 217#endif
 218        kmem_cache_destroy(entry_cache);
 219out:
 220        mutex_unlock(&init_lock);
 221}
 222
 223#include "selftest/ratelimiter.c"
 224