/*
 * Longest prefix match (LPM) trie implementation for BPF maps
 * (BPF_MAP_TYPE_LPM_TRIE).
 */
#include <linux/bpf.h>
#include <linux/err.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/vmalloc.h>
#include <net/ipv6.h>

/* Flag for nodes that were inserted purely to connect two more specific
 * prefixes (intermediate nodes); they carry no value.
 */
#define LPM_TREE_NODE_FLAG_IM BIT(0)

struct lpm_trie_node;

struct lpm_trie_node {
	struct rcu_head rcu;
	struct lpm_trie_node __rcu *child[2];
	u32 prefixlen;
	u32 flags;
	/* Flexible array: data_size bytes of prefix data, followed by
	 * value_size bytes of value for non-intermediate nodes.
	 */
	u8 data[0];
};

struct lpm_trie {
	struct bpf_map map;
	struct lpm_trie_node __rcu *root;
	size_t n_entries;
	size_t max_prefixlen;
	size_t data_size;
	raw_spinlock_t lock;
};

/* This trie implements a longest prefix match algorithm that can be used to
 * match IP addresses to a stored set of ranges.
 *
 * Data stored in @data of struct bpf_lpm_trie_key and struct lpm_trie_node
 * is interpreted as big endian, so data[0] stores the most significant byte.
 *
 * Match ranges are internally stored in instances of struct lpm_trie_node
 * which each contain their prefix length as well as two pointers that may
 * lead to more nodes containing more specific matches. Each node also stores
 * a value that is defined by and returned to userspace via the update_elem
 * and lookup functions.
 *
 * For instance, a trie holding 192.168.0.0/16 and 192.168.0.0/24 stores the
 * /24 node as a descendant of the /16 node. A lookup walks down the trie as
 * long as the traversed nodes fully match the looked-up key, remembering the
 * last matching non-intermediate node; when the walk stops, that remembered
 * node is the longest matching prefix.
 *
 * When a new range is inserted that shares only a partial prefix with an
 * existing node (say, 192.168.1.0/24 next to 192.168.0.0/24), a node flagged
 * with LPM_TREE_NODE_FLAG_IM is created for the longest common prefix (/23
 * in this example) to connect the two. Such intermediate nodes carry no
 * value and are never returned by lookups; they only route the traversal.
 *
 * The maximum prefix length is derived from the key size given at map
 * creation time, so the same implementation serves IPv4 (32 bit), IPv6
 * (128 bit) or any other fixed-size big-endian key.
 */

/* Extract the bit at position @index from the big-endian byte array @data;
 * bit 0 is the most significant bit of data[0].
 */
static inline int extract_bit(const u8 *data, size_t index)
{
	return !!(data[index / 8] & (1 << (7 - (index % 8))));
}
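
/* For example, with data[0] = 0xc0 (binary 1100 0000), extract_bit(data, 0)
 * and extract_bit(data, 1) both return 1, while extract_bit(data, 2)
 * returns 0.
 */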

/**
 * longest_prefix_match() - determine the longest prefix
 * @trie:	The trie to get internal sizes from
 * @node:	The node to operate on
 * @key:	The key to compare to @node
 *
 * Determine the longest prefix of @node that matches the bits in @key.
 */
static size_t longest_prefix_match(const struct lpm_trie *trie,
				   const struct lpm_trie_node *node,
				   const struct bpf_lpm_trie_key *key)
{
	size_t prefixlen = 0;
	size_t i;

	for (i = 0; i < trie->data_size; i++) {
		size_t b;

		/* Number of leading bits in which this byte of @node and
		 * @key agree: fls() finds the most significant differing
		 * bit, so identical bytes contribute a full 8 bits.
		 */
		b = 8 - fls(node->data[i] ^ key->data[i]);
		prefixlen += b;

		if (prefixlen >= node->prefixlen || prefixlen >= key->prefixlen)
			return min(node->prefixlen, key->prefixlen);

		if (b < 8)
			break;
	}

	return prefixlen;
}
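
/* Worked example: for node->data = {192, 168, 0, 0} with prefixlen 16 and
 * key->data = {192, 168, 128, 5} with prefixlen 32, the first two bytes
 * match fully (16 bits), which already covers node->prefixlen, so the
 * function returns min(16, 32) = 16.
 */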

/* Called with rcu_read_lock() */
static void *trie_lookup_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *found = NULL;
	struct bpf_lpm_trie_key *key = _key;

	/* Start walking the trie from the root node ... */

	for (node = rcu_dereference(trie->root); node;) {
		unsigned int next_bit;
		size_t matchlen;

		/* Determine the longest prefix of @node that matches @key.
		 * If it's the maximum possible prefix for this trie, we have
		 * an exact match and can return it directly.
		 */
		matchlen = longest_prefix_match(trie, node, key);
		if (matchlen == trie->max_prefixlen) {
			found = node;
			break;
		}

		/* If the number of bits that match is smaller than the
		 * prefix length of @node, bail out and return the node we
		 * have seen last in the traversal (ie, the parent).
		 */
		if (matchlen < node->prefixlen)
			break;

		/* Consider this node as return candidate unless it is an
		 * artificially added intermediate one.
		 */
		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			found = node;

		/* If the node match is fully satisfied, let's see if we can
		 * become more specific. Determine the next bit in the key
		 * and traverse down.
		 */
		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}

	if (!found)
		return NULL;

	return found->data + trie->data_size;
}
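
/* For instance, with 192.168.0.0/16 and 192.168.1.0/24 stored in the trie,
 * a lookup of 192.168.1.1 walks through the /16 node (a match, remembered
 * in @found), continues to the /24 node (a longer match) and returns the
 * value of the /24 entry. A lookup of 192.168.2.2 only partially matches
 * the /24 node (22 of its 24 bits), so the walk stops there and the value
 * of the /16 entry is returned instead.
 */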

static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
						 const void *value)
{
	struct lpm_trie_node *node;
	size_t size = sizeof(struct lpm_trie_node) + trie->data_size;

	/* Intermediate nodes are allocated without value space */
	if (value)
		size += trie->map.value_size;

	node = kmalloc_node(size, GFP_ATOMIC | __GFP_NOWARN,
			    trie->map.numa_node);
	if (!node)
		return NULL;

	node->flags = 0;

	if (value)
		memcpy(node->data + trie->data_size, value,
		       trie->map.value_size);

	return node;
}

/* Called from syscall or from eBPF program */
static int trie_update_elem(struct bpf_map *map,
			    void *_key, void *value, u64 flags)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
	struct lpm_trie_node __rcu **slot;
	struct bpf_lpm_trie_key *key = _key;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Allocate and fill a new node */

	if (trie->n_entries == trie->map.max_entries) {
		ret = -ENOSPC;
		goto out;
	}

	new_node = lpm_trie_node_alloc(trie, value);
	if (!new_node) {
		ret = -ENOMEM;
		goto out;
	}

	trie->n_entries++;

	new_node->prefixlen = key->prefixlen;
	RCU_INIT_POINTER(new_node->child[0], NULL);
	RCU_INIT_POINTER(new_node->child[1], NULL);
	memcpy(new_node->data, key->data, trie->data_size);

	/* Now find a slot to attach the new node. To do that, walk the tree
	 * from the root and match as many bits as possible for each node
	 * until we either find an empty slot or a slot that needs to be
	 * replaced by an intermediate node.
	 */
	slot = &trie->root;

	while ((node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen ||
		    node->prefixlen == trie->max_prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		slot = &node->child[next_bit];
	}

	/* If the slot is empty (a free child pointer or an empty root),
	 * simply assign the @new_node to that slot and be done.
	 */
	if (!node) {
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* If the slot we picked already exists, replace it with @new_node
	 * which already has the correct data and child pointers.
	 */
	if (node->prefixlen == matchlen) {
		new_node->child[0] = node->child[0];
		new_node->child[1] = node->child[1];

		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			trie->n_entries--;

		rcu_assign_pointer(*slot, new_node);
		kfree_rcu(node, rcu);

		goto out;
	}

	/* If the new node matches the prefix completely, it must be inserted
	 * as an ancestor. Simply insert it between @node and *@slot.
	 */
	if (matchlen == key->prefixlen) {
		next_bit = extract_bit(node->data, matchlen);
		rcu_assign_pointer(new_node->child[next_bit], node);
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* Neither prefix contains the other, so connect @node and @new_node
	 * through an intermediate node covering their common prefix.
	 */
	im_node = lpm_trie_node_alloc(trie, NULL);
	if (!im_node) {
		ret = -ENOMEM;
		goto out;
	}

	im_node->prefixlen = matchlen;
	im_node->flags |= LPM_TREE_NODE_FLAG_IM;
	memcpy(im_node->data, node->data, trie->data_size);

	/* Now determine which child to install in which slot */
	if (extract_bit(key->data, matchlen)) {
		rcu_assign_pointer(im_node->child[0], node);
		rcu_assign_pointer(im_node->child[1], new_node);
	} else {
		rcu_assign_pointer(im_node->child[0], new_node);
		rcu_assign_pointer(im_node->child[1], node);
	}

	/* Finally, assign the intermediate node to the determined slot */
	rcu_assign_pointer(*slot, im_node);

out:
	if (ret) {
		if (new_node)
			trie->n_entries--;

		kfree(new_node);
		kfree(im_node);
	}

	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
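
/* Worked example for the intermediate-node case: with 192.168.0.0/24
 * already stored, inserting 192.168.1.0/24 finds a common prefix of 23
 * bits. An IM node with prefixlen 23 is created; bit 23 of the new key is
 * 1, so the new /24 lands in child[1] and the existing /24 in child[0].
 */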

static int trie_delete_elem(struct bpf_map *map, void *key)
{
	/* TODO: deletion is not implemented yet */
	return -ENOSYS;
}

#define LPM_DATA_SIZE_MAX	256
#define LPM_DATA_SIZE_MIN	1

#define LPM_VAL_SIZE_MAX	(KMALLOC_MAX_SIZE - LPM_DATA_SIZE_MAX - \
				 sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN	1

#define LPM_KEY_SIZE(X)		(sizeof(struct bpf_lpm_trie_key) + (X))
#define LPM_KEY_SIZE_MAX	LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN	LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)
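
/* Key sizing example: struct bpf_lpm_trie_key is a 4-byte prefixlen
 * followed by a flexible data array, so an IPv4 trie uses
 * LPM_KEY_SIZE(4) == 8 bytes and an IPv6 trie LPM_KEY_SIZE(16) == 20 bytes.
 */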

#define LPM_CREATE_FLAG_MASK	(BPF_F_NO_PREALLOC | BPF_F_NUMA_NODE)

static struct bpf_map *trie_alloc(union bpf_attr *attr)
{
	struct lpm_trie *trie;
	u64 cost = sizeof(*trie), cost_per_node;
	int ret;

	if (!capable(CAP_SYS_ADMIN))
		return ERR_PTR(-EPERM);

	/* check sanity of attributes */
	if (attr->max_entries == 0 ||
	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
	    attr->map_flags & ~LPM_CREATE_FLAG_MASK ||
	    attr->key_size < LPM_KEY_SIZE_MIN ||
	    attr->key_size > LPM_KEY_SIZE_MAX ||
	    attr->value_size < LPM_VAL_SIZE_MIN ||
	    attr->value_size > LPM_VAL_SIZE_MAX)
		return ERR_PTR(-EINVAL);

	trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN);
	if (!trie)
		return ERR_PTR(-ENOMEM);

	/* copy mandatory map attributes */
	trie->map.map_type = attr->map_type;
	trie->map.key_size = attr->key_size;
	trie->map.value_size = attr->value_size;
	trie->map.max_entries = attr->max_entries;
	trie->map.map_flags = attr->map_flags;
	trie->map.numa_node = bpf_map_attr_numa_node(attr);
	trie->data_size = attr->key_size -
			  offsetof(struct bpf_lpm_trie_key, data);
	trie->max_prefixlen = trie->data_size * 8;

	/* Charge the worst-case memory consumption against the memlock
	 * rlimit before any nodes are allocated.
	 */
	cost_per_node = sizeof(struct lpm_trie_node) +
			attr->value_size + trie->data_size;
	cost += (u64) attr->max_entries * cost_per_node;
	if (cost >= U32_MAX - PAGE_SIZE) {
		ret = -E2BIG;
		goto out_err;
	}

	trie->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;

	ret = bpf_map_precharge_memlock(trie->map.pages);
	if (ret)
		goto out_err;

	raw_spin_lock_init(&trie->lock);

	return &trie->map;
out_err:
	kfree(trie);
	return ERR_PTR(ret);
}

static void trie_free(struct bpf_map *map)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node __rcu **slot;
	struct lpm_trie_node *node;

	raw_spin_lock(&trie->lock);

	/* Always start at the root and walk down to a node that has no
	 * children. Then free that node, nullify its reference in the
	 * parent and start over.
	 */

	for (;;) {
		slot = &trie->root;

		for (;;) {
			node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock));
			if (!node)
				goto unlock;

			if (rcu_access_pointer(node->child[0])) {
				slot = &node->child[0];
				continue;
			}

			if (rcu_access_pointer(node->child[1])) {
				slot = &node->child[1];
				continue;
			}

			kfree(node);
			RCU_INIT_POINTER(*slot, NULL);
			break;
		}
	}

unlock:
	raw_spin_unlock(&trie->lock);
}

static int trie_get_next_key(struct bpf_map *map, void *key, void *next_key)
{
	return -ENOTSUPP;
}

const struct bpf_map_ops trie_map_ops = {
	.map_alloc = trie_alloc,
	.map_free = trie_free,
	.map_get_next_key = trie_get_next_key,
	.map_lookup_elem = trie_lookup_elem,
	.map_update_elem = trie_update_elem,
	.map_delete_elem = trie_delete_elem,
};
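
/* A minimal userspace sketch (not part of this file) of how such a map is
 * exercised through the bpf(2) syscall wrappers in tools/lib/bpf; names and
 * sizes below are illustrative assumptions for an IPv4 trie:
 *
 *	struct bpf_lpm_trie_key *key;
 *	size_t key_size = sizeof(*key) + 4;	// 4 data bytes for IPv4
 *	__u64 value = 1;
 *	int fd;
 *
 *	fd = bpf_create_map(BPF_MAP_TYPE_LPM_TRIE, key_size,
 *			    sizeof(value), 256, BPF_F_NO_PREALLOC);
 *
 *	key = alloca(key_size);
 *	key->prefixlen = 16;
 *	inet_pton(AF_INET, "192.168.0.0", key->data);
 *	bpf_map_update_elem(fd, key, &value, 0);
 *
 *	key->prefixlen = 32;
 *	inet_pton(AF_INET, "192.168.1.1", key->data);
 *	bpf_map_lookup_elem(fd, key, &value);	// longest prefix match
 */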