1
2
3
4
5
6#include "ratelimiter.h"
7#include <linux/siphash.h>
8#include <linux/mm.h>
9#include <linux/slab.h>
10#include <net/ip.h>
11
12static struct kmem_cache *entry_cache;
13static hsiphash_key_t key;
14static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
15static DEFINE_MUTEX(init_lock);
16static u64 init_refcnt;
17static atomic_t total_entries = ATOMIC_INIT(0);
18static unsigned int max_entries, table_size;
19static void wg_ratelimiter_gc_entries(struct work_struct *);
20static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
21static struct hlist_head *table_v4;
22#if IS_ENABLED(CONFIG_IPV6)
23static struct hlist_head *table_v6;
24#endif
25
26struct ratelimiter_entry {
27 u64 last_time_ns, tokens, ip;
28 void *net;
29 spinlock_t lock;
30 struct hlist_node hash;
31 struct rcu_head rcu;
32};
33
34enum {
35 PACKETS_PER_SECOND = 20,
36 PACKETS_BURSTABLE = 5,
37 PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
38 TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
39};
40
41static void entry_free(struct rcu_head *rcu)
42{
43 kmem_cache_free(entry_cache,
44 container_of(rcu, struct ratelimiter_entry, rcu));
45 atomic_dec(&total_entries);
46}
47
48static void entry_uninit(struct ratelimiter_entry *entry)
49{
50 hlist_del_rcu(&entry->hash);
51 call_rcu(&entry->rcu, entry_free);
52}
53
54
55static void wg_ratelimiter_gc_entries(struct work_struct *work)
56{
57 const u64 now = ktime_get_coarse_boottime_ns();
58 struct ratelimiter_entry *entry;
59 struct hlist_node *temp;
60 unsigned int i;
61
62 for (i = 0; i < table_size; ++i) {
63 spin_lock(&table_lock);
64 hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
65 if (unlikely(!work) ||
66 now - entry->last_time_ns > NSEC_PER_SEC)
67 entry_uninit(entry);
68 }
69#if IS_ENABLED(CONFIG_IPV6)
70 hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
71 if (unlikely(!work) ||
72 now - entry->last_time_ns > NSEC_PER_SEC)
73 entry_uninit(entry);
74 }
75#endif
76 spin_unlock(&table_lock);
77 if (likely(work))
78 cond_resched();
79 }
80 if (likely(work))
81 queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
82}
83
84bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
85{
86
87
88
89
90 const u32 net_word = (unsigned long)net;
91 struct ratelimiter_entry *entry;
92 struct hlist_head *bucket;
93 u64 ip;
94
95 if (skb->protocol == htons(ETH_P_IP)) {
96 ip = (u64 __force)ip_hdr(skb)->saddr;
97 bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
98 (table_size - 1)];
99 }
100#if IS_ENABLED(CONFIG_IPV6)
101 else if (skb->protocol == htons(ETH_P_IPV6)) {
102
103 memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
104 bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
105 (table_size - 1)];
106 }
107#endif
108 else
109 return false;
110 rcu_read_lock();
111 hlist_for_each_entry_rcu(entry, bucket, hash) {
112 if (entry->net == net && entry->ip == ip) {
113 u64 now, tokens;
114 bool ret;
115
116
117
118
119
120 spin_lock(&entry->lock);
121 now = ktime_get_coarse_boottime_ns();
122 tokens = min_t(u64, TOKEN_MAX,
123 entry->tokens + now -
124 entry->last_time_ns);
125 entry->last_time_ns = now;
126 ret = tokens >= PACKET_COST;
127 entry->tokens = ret ? tokens - PACKET_COST : tokens;
128 spin_unlock(&entry->lock);
129 rcu_read_unlock();
130 return ret;
131 }
132 }
133 rcu_read_unlock();
134
135 if (atomic_inc_return(&total_entries) > max_entries)
136 goto err_oom;
137
138 entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
139 if (unlikely(!entry))
140 goto err_oom;
141
142 entry->net = net;
143 entry->ip = ip;
144 INIT_HLIST_NODE(&entry->hash);
145 spin_lock_init(&entry->lock);
146 entry->last_time_ns = ktime_get_coarse_boottime_ns();
147 entry->tokens = TOKEN_MAX - PACKET_COST;
148 spin_lock(&table_lock);
149 hlist_add_head_rcu(&entry->hash, bucket);
150 spin_unlock(&table_lock);
151 return true;
152
153err_oom:
154 atomic_dec(&total_entries);
155 return false;
156}
157
158int wg_ratelimiter_init(void)
159{
160 mutex_lock(&init_lock);
161 if (++init_refcnt != 1)
162 goto out;
163
164 entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
165 if (!entry_cache)
166 goto err;
167
168
169
170
171
172
173 table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
174 max_t(unsigned long, 16, roundup_pow_of_two(
175 (totalram_pages() << PAGE_SHIFT) /
176 (1U << 14) / sizeof(struct hlist_head)));
177 max_entries = table_size * 8;
178
179 table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
180 if (unlikely(!table_v4))
181 goto err_kmemcache;
182
183#if IS_ENABLED(CONFIG_IPV6)
184 table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
185 if (unlikely(!table_v6)) {
186 kvfree(table_v4);
187 goto err_kmemcache;
188 }
189#endif
190
191 queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
192 get_random_bytes(&key, sizeof(key));
193out:
194 mutex_unlock(&init_lock);
195 return 0;
196
197err_kmemcache:
198 kmem_cache_destroy(entry_cache);
199err:
200 --init_refcnt;
201 mutex_unlock(&init_lock);
202 return -ENOMEM;
203}
204
205void wg_ratelimiter_uninit(void)
206{
207 mutex_lock(&init_lock);
208 if (!init_refcnt || --init_refcnt)
209 goto out;
210
211 cancel_delayed_work_sync(&gc_work);
212 wg_ratelimiter_gc_entries(NULL);
213 rcu_barrier();
214 kvfree(table_v4);
215#if IS_ENABLED(CONFIG_IPV6)
216 kvfree(table_v6);
217#endif
218 kmem_cache_destroy(entry_cache);
219out:
220 mutex_unlock(&init_lock);
221}
222
223#include "selftest/ratelimiter.c"
224