1
2
3
4
5
6
7
8
9
10
11
12#include <linux/bpf.h>
13#include <linux/err.h>
14#include <linux/slab.h>
15#include <linux/spinlock.h>
16#include <linux/vmalloc.h>
17#include <net/ipv6.h>
18
19
20#define LPM_TREE_NODE_FLAG_IM BIT(0)
21
22struct lpm_trie_node;
23
24struct lpm_trie_node {
25 struct rcu_head rcu;
26 struct lpm_trie_node __rcu *child[2];
27 u32 prefixlen;
28 u32 flags;
29 u8 data[0];
30};
31
/* Per-map state for a BPF_MAP_TYPE_LPM_TRIE map. */
struct lpm_trie {
	struct bpf_map map;		/* must stay first: container_of() in all ops */
	struct lpm_trie_node __rcu *root;
	size_t n_entries;		/* non-intermediate nodes only */
	size_t max_prefixlen;		/* data_size * 8, see trie_alloc() */
	size_t data_size;		/* key_size minus bpf_lpm_trie_key header */
	raw_spinlock_t lock;		/* serializes update/delete; lookup is RCU-only */
};
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152static inline int extract_bit(const u8 *data, size_t index)
153{
154 return !!(data[index / 8] & (1 << (7 - (index % 8))));
155}
156
157
158
159
160
161
162
163
164
/* Determine the number of leading prefix bits that @node and @key have in
 * common, capped at the shorter of the two prefix lengths.
 */
static size_t longest_prefix_match(const struct lpm_trie *trie,
				   const struct lpm_trie_node *node,
				   const struct bpf_lpm_trie_key *key)
{
	size_t prefixlen = 0;
	size_t i;

	for (i = 0; i < trie->data_size; i++) {
		size_t b;

		/* fls() of the XOR yields the (1-based) position of the
		 * highest differing bit, so 8 - fls() is the number of
		 * leading bits that agree in this byte; fls(0) == 0 gives
		 * b == 8 when the bytes are identical.
		 */
		b = 8 - fls(node->data[i] ^ key->data[i]);
		prefixlen += b;

		/* Once we've matched as far as either prefix extends, the
		 * shorter prefix length is the answer.
		 */
		if (prefixlen >= node->prefixlen || prefixlen >= key->prefixlen)
			return min(node->prefixlen, key->prefixlen);

		/* A mismatching bit inside this byte ends the match */
		if (b < 8)
			break;
	}

	return prefixlen;
}
187
188
/* Longest-prefix-match lookup, called from syscall or from an eBPF program.
 * Lockless: walks the trie under RCU protection only.
 */
static void *trie_lookup_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *found = NULL;
	struct bpf_lpm_trie_key *key = _key;

	/* Start walking the trie from the root node ... */
	for (node = rcu_dereference(trie->root); node;) {
		unsigned int next_bit;
		size_t matchlen;

		/* Determine the longest prefix of @node that matches @key.
		 * If it's the maximum possible prefix for this trie, we have
		 * an exact match and can return it directly.
		 */
		matchlen = longest_prefix_match(trie, node, key);
		if (matchlen == trie->max_prefixlen) {
			found = node;
			break;
		}

		/* If the number of bits that match is smaller than the
		 * prefix length of @node, we cannot descend any further;
		 * the best candidate seen so far (if any) wins.
		 */
		if (matchlen < node->prefixlen)
			break;

		/* Consider this node as return candidate unless it is an
		 * artificially added intermediate one.
		 */
		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			found = node;

		/* The node's prefix is fully matched; try to become more
		 * specific by descending on the key's next bit.
		 */
		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}

	if (!found)
		return NULL;

	/* The value is stored directly behind the prefix bytes */
	return found->data + trie->data_size;
}
237
238static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
239 const void *value)
240{
241 struct lpm_trie_node *node;
242 size_t size = sizeof(struct lpm_trie_node) + trie->data_size;
243
244 if (value)
245 size += trie->map.value_size;
246
247 node = kmalloc_node(size, GFP_ATOMIC | __GFP_NOWARN,
248 trie->map.numa_node);
249 if (!node)
250 return NULL;
251
252 node->flags = 0;
253
254 if (value)
255 memcpy(node->data + trie->data_size, value,
256 trie->map.value_size);
257
258 return node;
259}
260
261
/* Insert or replace the value for @_key.  Called from syscall or from an
 * eBPF program.  Runs under the trie spinlock with IRQs disabled; readers
 * are never blocked — new nodes are published with rcu_assign_pointer()
 * and replaced nodes freed via kfree_rcu().
 */
static int trie_update_elem(struct bpf_map *map,
			    void *_key, void *value, u64 flags)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
	struct lpm_trie_node __rcu **slot;
	struct bpf_lpm_trie_key *key = _key;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Allocate and fill a new node */

	if (trie->n_entries == trie->map.max_entries) {
		ret = -ENOSPC;
		goto out;
	}

	new_node = lpm_trie_node_alloc(trie, value);
	if (!new_node) {
		ret = -ENOMEM;
		goto out;
	}

	/* Counted optimistically; the error path below undoes this */
	trie->n_entries++;

	new_node->prefixlen = key->prefixlen;
	RCU_INIT_POINTER(new_node->child[0], NULL);
	RCU_INIT_POINTER(new_node->child[1], NULL);
	memcpy(new_node->data, key->data, trie->data_size);

	/* Now find a slot to attach the new node.  To do that, walk the tree
	 * from the root, matching as many bits as possible for each node,
	 * until we either find an empty slot or a slot that needs to be
	 * replaced by an intermediate node.
	 */
	slot = &trie->root;

	while ((node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen ||
		    node->prefixlen == trie->max_prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		slot = &node->child[next_bit];
	}

	/* If the slot is empty (a free child pointer or an empty root),
	 * simply assign the @new_node to that slot and be done.
	 */
	if (!node) {
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* If the slot we picked already exists, replace it with @new_node
	 * which already has the correct data and child pointers.  A replaced
	 * intermediate node becomes a real entry, so n_entries stays up.
	 */
	if (node->prefixlen == matchlen) {
		new_node->child[0] = node->child[0];
		new_node->child[1] = node->child[1];

		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			trie->n_entries--;

		rcu_assign_pointer(*slot, new_node);
		kfree_rcu(node, rcu);

		goto out;
	}

	/* If the new node matches the prefix completely, it must be inserted
	 * as an ancestor.  Simply insert it between @node and *@slot.
	 */
	if (matchlen == key->prefixlen) {
		next_bit = extract_bit(node->data, matchlen);
		rcu_assign_pointer(new_node->child[next_bit], node);
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* Neither prefix contains the other: fork via an intermediate node
	 * holding the common prefix of length @matchlen.
	 */
	im_node = lpm_trie_node_alloc(trie, NULL);
	if (!im_node) {
		ret = -ENOMEM;
		goto out;
	}

	im_node->prefixlen = matchlen;
	im_node->flags |= LPM_TREE_NODE_FLAG_IM;
	memcpy(im_node->data, node->data, trie->data_size);

	/* Now determine which child to install in which slot */
	if (extract_bit(key->data, matchlen)) {
		rcu_assign_pointer(im_node->child[0], node);
		rcu_assign_pointer(im_node->child[1], new_node);
	} else {
		rcu_assign_pointer(im_node->child[0], new_node);
		rcu_assign_pointer(im_node->child[1], node);
	}

	/* Finally, publish the intermediate node in the chosen slot */
	rcu_assign_pointer(*slot, im_node);

out:
	if (ret) {
		/* Undo the optimistic n_entries bump taken after the
		 * successful new_node allocation; kfree(NULL) is a no-op.
		 */
		if (new_node)
			trie->n_entries--;

		kfree(new_node);
		kfree(im_node);
	}

	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
391
392
/* Remove the entry that exactly matches @_key (prefix bytes and length).
 * Called from syscall or from an eBPF program; runs under the trie spinlock
 * with IRQs disabled, freeing detached nodes via kfree_rcu() so concurrent
 * lockless lookups stay safe.
 */
static int trie_delete_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key;
	struct lpm_trie_node __rcu **trim, **trim2;
	struct lpm_trie_node *node, *parent;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Walk the tree looking for an exact key/length match and keeping
	 * track of the path we traverse.  @trim is the slot that points to
	 * the node we may delete; @trim2 the slot pointing to its parent
	 * (@parent), needed if the parent turns out to be a collapsible
	 * intermediate node.
	 */
	trim = &trie->root;
	trim2 = trim;
	parent = NULL;
	while ((node = rcu_dereference_protected(
		       *trim, lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		parent = node;
		trim2 = trim;
		next_bit = extract_bit(key->data, node->prefixlen);
		trim = &node->child[next_bit];
	}

	/* Only a real (non-intermediate) node with the exact prefix length
	 * counts as a deletable entry.
	 */
	if (!node || node->prefixlen != key->prefixlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM)) {
		ret = -ENOENT;
		goto out;
	}

	trie->n_entries--;

	/* If the node we are removing has two children, simply mark it
	 * as intermediate and we are done — both subtrees still need it
	 * as a fork point.
	 */
	if (rcu_access_pointer(node->child[0]) &&
	    rcu_access_pointer(node->child[1])) {
		node->flags |= LPM_TREE_NODE_FLAG_IM;
		goto out;
	}

	/* If the parent of the node being deleted is an intermediate node
	 * and the deleted node is a leaf, the parent has lost its purpose:
	 * delete it too and promote its other child into the parent's slot.
	 * This keeps the invariant that every intermediate node has exactly
	 * two children.
	 */
	if (parent && (parent->flags & LPM_TREE_NODE_FLAG_IM) &&
	    !node->child[0] && !node->child[1]) {
		if (node == rcu_access_pointer(parent->child[0]))
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[1]));
		else
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[0]));
		kfree_rcu(parent, rcu);
		kfree_rcu(node, rcu);
		goto out;
	}

	/* The node has zero or one child: splice the child (if any) into
	 * the removed node's slot, otherwise clear the slot, then free the
	 * node after a grace period.
	 */
	if (node->child[0])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[0]));
	else if (node->child[1])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[1]));
	else
		RCU_INIT_POINTER(*trim, NULL);
	kfree_rcu(node, rcu);

out:
	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
486
/* Bounds on the number of prefix bytes a key may carry */
#define LPM_DATA_SIZE_MAX	256
#define LPM_DATA_SIZE_MIN	1

/* Value size is capped so node + prefix + value still fits in kmalloc */
#define LPM_VAL_SIZE_MAX	(KMALLOC_MAX_SIZE - LPM_DATA_SIZE_MAX - \
				 sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN	1

/* Key size as seen from userspace: bpf_lpm_trie_key header + X data bytes */
#define LPM_KEY_SIZE(X)		(sizeof(struct bpf_lpm_trie_key) + (X))
#define LPM_KEY_SIZE_MAX	LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN	LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)

/* map_flags accepted at creation time */
#define LPM_CREATE_FLAG_MASK	(BPF_F_NO_PREALLOC | BPF_F_NUMA_NODE |	\
				 BPF_F_RDONLY | BPF_F_WRONLY)
500
501static struct bpf_map *trie_alloc(union bpf_attr *attr)
502{
503 struct lpm_trie *trie;
504 u64 cost = sizeof(*trie), cost_per_node;
505 int ret;
506
507 if (!capable(CAP_SYS_ADMIN))
508 return ERR_PTR(-EPERM);
509
510
511 if (attr->max_entries == 0 ||
512 !(attr->map_flags & BPF_F_NO_PREALLOC) ||
513 attr->map_flags & ~LPM_CREATE_FLAG_MASK ||
514 attr->key_size < LPM_KEY_SIZE_MIN ||
515 attr->key_size > LPM_KEY_SIZE_MAX ||
516 attr->value_size < LPM_VAL_SIZE_MIN ||
517 attr->value_size > LPM_VAL_SIZE_MAX)
518 return ERR_PTR(-EINVAL);
519
520 trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN);
521 if (!trie)
522 return ERR_PTR(-ENOMEM);
523
524
525 bpf_map_init_from_attr(&trie->map, attr);
526 trie->data_size = attr->key_size -
527 offsetof(struct bpf_lpm_trie_key, data);
528 trie->max_prefixlen = trie->data_size * 8;
529
530 cost_per_node = sizeof(struct lpm_trie_node) +
531 attr->value_size + trie->data_size;
532 cost += (u64) attr->max_entries * cost_per_node;
533 if (cost >= U32_MAX - PAGE_SIZE) {
534 ret = -E2BIG;
535 goto out_err;
536 }
537
538 trie->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
539
540 ret = bpf_map_precharge_memlock(trie->map.pages);
541 if (ret)
542 goto out_err;
543
544 raw_spin_lock_init(&trie->lock);
545
546 return &trie->map;
547out_err:
548 kfree(trie);
549 return ERR_PTR(ret);
550}
551
/* Free the whole trie when the map is destroyed. */
static void trie_free(struct bpf_map *map)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node __rcu **slot;
	struct lpm_trie_node *node;

	/* Wait for outstanding programs to complete update/lookup/delete
	 * operations on the trie before it may be freed.
	 */
	synchronize_rcu();

	/* Always start at the root and walk down to a node that has no
	 * children.  This way the tree is dismantled bottom-up: each pass
	 * detaches and frees one leaf, then restarts from the root, so no
	 * parent bookkeeping is needed.  Plain kfree() is safe here because
	 * no readers remain after synchronize_rcu().
	 */
	for (;;) {
		slot = &trie->root;

		for (;;) {
			node = rcu_dereference_protected(*slot, 1);
			if (!node)
				goto out;

			if (rcu_access_pointer(node->child[0])) {
				slot = &node->child[0];
				continue;
			}

			if (rcu_access_pointer(node->child[1])) {
				slot = &node->child[1];
				continue;
			}

			kfree(node);
			RCU_INIT_POINTER(*slot, NULL);
			break;
		}
	}

out:
	kfree(trie);
}
595
596static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
597{
598 struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
599 struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
600 struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
601 struct lpm_trie_node **node_stack = NULL;
602 int err = 0, stack_ptr = -1;
603 unsigned int next_bit;
604 size_t matchlen;
605
606
607
608
609
610
611
612
613
614
615
616
617
618 search_root = rcu_dereference(trie->root);
619 if (!search_root)
620 return -ENOENT;
621
622
623 if (!key || key->prefixlen > trie->max_prefixlen)
624 goto find_leftmost;
625
626 node_stack = kmalloc(trie->max_prefixlen * sizeof(struct lpm_trie_node *),
627 GFP_ATOMIC | __GFP_NOWARN);
628 if (!node_stack)
629 return -ENOMEM;
630
631
632 for (node = search_root; node;) {
633 node_stack[++stack_ptr] = node;
634 matchlen = longest_prefix_match(trie, node, key);
635 if (node->prefixlen != matchlen ||
636 node->prefixlen == key->prefixlen)
637 break;
638
639 next_bit = extract_bit(key->data, node->prefixlen);
640 node = rcu_dereference(node->child[next_bit]);
641 }
642 if (!node || node->prefixlen != key->prefixlen ||
643 (node->flags & LPM_TREE_NODE_FLAG_IM))
644 goto find_leftmost;
645
646
647
648
649 node = node_stack[stack_ptr];
650 while (stack_ptr > 0) {
651 parent = node_stack[stack_ptr - 1];
652 if (rcu_dereference(parent->child[0]) == node) {
653 search_root = rcu_dereference(parent->child[1]);
654 if (search_root)
655 goto find_leftmost;
656 }
657 if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) {
658 next_node = parent;
659 goto do_copy;
660 }
661
662 node = parent;
663 stack_ptr--;
664 }
665
666
667 err = -ENOENT;
668 goto free_stack;
669
670find_leftmost:
671
672
673
674 for (node = search_root; node;) {
675 if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
676 next_node = node;
677 node = rcu_dereference(node->child[0]);
678 }
679do_copy:
680 next_key->prefixlen = next_node->prefixlen;
681 memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data),
682 next_node->data, trie->data_size);
683free_stack:
684 kfree(node_stack);
685 return err;
686}
687
/* Operations table registered for BPF_MAP_TYPE_LPM_TRIE maps */
const struct bpf_map_ops trie_map_ops = {
	.map_alloc = trie_alloc,
	.map_free = trie_free,
	.map_get_next_key = trie_get_next_key,
	.map_lookup_elem = trie_lookup_elem,
	.map_update_elem = trie_update_elem,
	.map_delete_elem = trie_delete_elem,
};
696