/*
 * Longest prefix match list implementation
 *
 * Copyright (c) 2016,2017 Daniel Mack
 * Copyright (c) 2016 David Herrmann
 *
 * This file is subject to the terms and conditions of version 2 of the GNU
 * General Public License.  See the file COPYING in the main directory of the
 * Linux distribution for more details.
 */

#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/err.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/vmalloc.h>
#include <net/ipv6.h>
#include <uapi/linux/btf.h>

/* Intermediate node */
#define LPM_TREE_NODE_FLAG_IM BIT(0)

struct lpm_trie_node;

struct lpm_trie_node {
	struct rcu_head rcu;
	struct lpm_trie_node __rcu	*child[2];
	u32				prefixlen;
	u32				flags;
	u8				data[0];	/* prefix bytes, then value */
};

struct lpm_trie {
	struct bpf_map			map;
	struct lpm_trie_node __rcu	*root;
	size_t				n_entries;
	size_t				max_prefixlen;
	size_t				data_size;
	raw_spinlock_t			lock;
};

/* This trie implements a longest prefix match algorithm that can be used to
 * match IP addresses to network routes, or for other purposes.
 *
 * Data stored in @data of struct bpf_lpm_trie_key and struct lpm_trie_node
 * is interpreted as big endian, so data[0] stores the most significant byte.
 *
 * Match ranges are internally stored in instances of struct lpm_trie_node
 * which each contain their prefix length as well as two pointers that may
 * lead to more nodes containing more specific matches. Each node also stores
 * a value that is defined by and returned to userspace via the update_elem
 * and lookup functions.
 *
 * As an example, consider a trie created with a prefix length of 32, so it
 * can be used for IPv4 addresses:
 *
 *   - Inserting 192.168.0.0/16 into the empty trie places one node as root.
 *     Its data array contains [0xc0, 0xa8, 0x00, 0x00] and both of its
 *     child pointers are NULL.
 *
 *   - Adding 192.168.0.0/24 attaches a second node as child[0] of the root,
 *     because bit #16 of the new key - the first bit beyond the root's
 *     prefix - is 0.
 *
 *   - Adding 192.168.128.0/24 attaches a third node as child[1] of the
 *     root, because its bit #16 is 1.
 *
 *   - Adding 192.168.1.0/24 collides with the existing 192.168.0.0/24 node:
 *     the two keys share only their first 23 bits. An intermediate node for
 *     the common prefix 192.168.0.0/23 is allocated and takes over the
 *     child[0] slot of the root, with the two /24 nodes as its children.
 *
 * Intermediate nodes are flagged with LPM_TREE_NODE_FLAG_IM, hold no value
 * visible to userspace and are never returned by lookups; they only exist
 * so that every real node has an unambiguous position in the binary trie.
 *
 * A lookup walks the trie from the root, matching as many prefix bits as
 * possible at each node and remembering the last non-intermediate node it
 * fully matched. When no more specific child exists, the value of that
 * remembered node - the longest matching prefix - is returned.
 */
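
/* Purely as an illustration of how this map type is driven from userspace
 * (not part of this file): a minimal sketch, assuming the bpf(2) syscall
 * wrappers from tools/lib/bpf of this era, namely bpf_create_map() and
 * bpf_map_update_elem(), and an IPv4 key laid out like
 * struct bpf_lpm_trie_key:
 *
 *	struct {
 *		__u32 prefixlen;
 *		__u8  data[4];
 *	} key = { .prefixlen = 16, .data = { 192, 168, 0, 0 } };
 *	__u64 value = 1;
 *	int fd;
 *
 *	fd = bpf_create_map(BPF_MAP_TYPE_LPM_TRIE, sizeof(key),
 *			    sizeof(value), 128, BPF_F_NO_PREALLOC);
 *	if (fd >= 0)
 *		bpf_map_update_elem(fd, &key, &value, BPF_ANY);
 *
 * Lookups pass a key with prefixlen set to the trie's maximum (32 here)
 * and receive the value of the longest matching prefix, if one exists.
 * Note that trie_alloc() below rejects maps created without
 * BPF_F_NO_PREALLOC.
 */
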
static inline int extract_bit(const u8 *data, size_t index)
{
	return !!(data[index / 8] & (1 << (7 - (index % 8))));
}
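
/* For illustration: with data = { 0xc0, 0xa8 } (the first two octets of
 * 192.168.0.0), extract_bit(data, 0) and extract_bit(data, 1) return 1
 * (0xc0 == 0b11000000) and extract_bit(data, 2) returns 0. Index 0 is the
 * most significant bit of data[0], matching the big-endian interpretation
 * used throughout this file.
 */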

/**
 * longest_prefix_match() - determine the longest prefix
 * @trie:	The trie to get internal sizes from
 * @node:	The node to operate on
 * @key:	The key to compare to @node
 *
 * Determine the longest prefix of @node that matches the bits in @key.
 */
static size_t longest_prefix_match(const struct lpm_trie *trie,
				   const struct lpm_trie_node *node,
				   const struct bpf_lpm_trie_key *key)
{
	u32 limit = min(node->prefixlen, key->prefixlen);
	u32 prefixlen = 0, i = 0;

	BUILD_BUG_ON(offsetof(struct lpm_trie_node, data) % sizeof(u32));
	BUILD_BUG_ON(offsetof(struct bpf_lpm_trie_key, data) % sizeof(u32));

#if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) && defined(CONFIG_64BIT)

	/* data_size >= 16 has very small probability.
	 * We do not use a loop for optimal code generation.
	 */
	if (trie->data_size >= 8) {
		u64 diff = be64_to_cpu(*(__be64 *)node->data ^
				       *(__be64 *)key->data);

		prefixlen = 64 - fls64(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i = 8;
	}
#endif

	while (trie->data_size >= i + 4) {
		u32 diff = be32_to_cpu(*(__be32 *)&node->data[i] ^
				       *(__be32 *)&key->data[i]);

		prefixlen += 32 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 4;
	}

	if (trie->data_size >= i + 2) {
		u16 diff = be16_to_cpu(*(__be16 *)&node->data[i] ^
				       *(__be16 *)&key->data[i]);

		prefixlen += 16 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 2;
	}

	if (trie->data_size >= i + 1) {
		prefixlen += 8 - fls(node->data[i] ^ key->data[i]);

		if (prefixlen >= limit)
			return limit;
	}

	return prefixlen;
}
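
/* For illustration: with a node holding 192.168.0.0/16 and a key of
 * 192.168.1.0/24, the first differing bit is bit #23 (third octet 0x00 vs
 * 0x01), so the raw match length is 23; it is clamped to
 * limit = min(16, 24) = 16 and the function returns 16, i.e. the node's
 * entire prefix matches the key.
 */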

/* Called from syscall or from eBPF program */
static void *trie_lookup_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *found = NULL;
	struct bpf_lpm_trie_key *key = _key;

	/* Start walking the trie from the root node ... */

	for (node = rcu_dereference(trie->root); node;) {
		unsigned int next_bit;
		size_t matchlen;

		/* Determine the longest prefix of @node that has to
		 * match @key. If they match for the maximum prefix
		 * length of the trie, we have an exact match and can
		 * return it directly.
		 */
		matchlen = longest_prefix_match(trie, node, key);
		if (matchlen == trie->max_prefixlen) {
			found = node;
			break;
		}

		/* If the number of bits that match is smaller than the
		 * prefix length of @node, bail out and return the node
		 * we have seen last in the traversal (ie, the parent).
		 */
		if (matchlen < node->prefixlen)
			break;

		/* Consider this node as return candidate unless it is an
		 * artificially added intermediate one.
		 */
		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			found = node;

		/* If the node match is fully satisfied, let's see if we can
		 * become more specific. Determine the next bit in the key
		 * and traverse down.
		 */
		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}

	if (!found)
		return NULL;

	return found->data + trie->data_size;
}
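
/* For illustration: with 192.168.0.0/16 (value A) and 192.168.0.0/24
 * (value B) in the trie, looking up 192.168.0.12/32 fully matches the /16
 * node, descends into the /24 child and returns B; looking up
 * 192.168.1.1/32 only matches 23 bits of the /24 node, bails out and
 * returns A, the last fully-matched non-intermediate node on the path.
 */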

static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
						 const void *value)
{
	struct lpm_trie_node *node;
	size_t size = sizeof(struct lpm_trie_node) + trie->data_size;

	if (value)
		size += trie->map.value_size;

	node = kmalloc_node(size, GFP_ATOMIC | __GFP_NOWARN,
			    trie->map.numa_node);
	if (!node)
		return NULL;

	node->flags = 0;

	/* The value, if any, is stored immediately behind the prefix data
	 * in the same allocation.
	 */
	if (value)
		memcpy(node->data + trie->data_size, value,
		       trie->map.value_size);

	return node;
}

/* Called from syscall or from eBPF program */
static int trie_update_elem(struct bpf_map *map,
			    void *_key, void *value, u64 flags)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
	struct lpm_trie_node __rcu **slot;
	struct bpf_lpm_trie_key *key = _key;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Allocate and fill a new node */

	if (trie->n_entries == trie->map.max_entries) {
		ret = -ENOSPC;
		goto out;
	}

	new_node = lpm_trie_node_alloc(trie, value);
	if (!new_node) {
		ret = -ENOMEM;
		goto out;
	}

	trie->n_entries++;

	new_node->prefixlen = key->prefixlen;
	RCU_INIT_POINTER(new_node->child[0], NULL);
	RCU_INIT_POINTER(new_node->child[1], NULL);
	memcpy(new_node->data, key->data, trie->data_size);

	/* Now find a slot to attach the new node. To do that, walk the tree
	 * from the root and match as many bits as possible for each node
	 * until we either find an empty slot or a slot that needs to be
	 * replaced by an intermediate node.
	 */
	slot = &trie->root;

	while ((node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen ||
		    node->prefixlen == trie->max_prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		slot = &node->child[next_bit];
	}

	/* If the slot is empty (a free child pointer or an empty root),
	 * simply assign the @new_node to that slot and be done.
	 */
	if (!node) {
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* If the slot we picked already exists, replace it with @new_node
	 * which already has the correct data array set.
	 */
	if (node->prefixlen == matchlen) {
		new_node->child[0] = node->child[0];
		new_node->child[1] = node->child[1];

		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			trie->n_entries--;

		rcu_assign_pointer(*slot, new_node);
		kfree_rcu(node, rcu);

		goto out;
	}

	/* If the new node matches the prefix completely, it must be inserted
	 * as an ancestor. Simply insert it between @node and *@slot.
	 */
	if (matchlen == key->prefixlen) {
		next_bit = extract_bit(node->data, matchlen);
		rcu_assign_pointer(new_node->child[next_bit], node);
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	im_node = lpm_trie_node_alloc(trie, NULL);
	if (!im_node) {
		ret = -ENOMEM;
		goto out;
	}

	im_node->prefixlen = matchlen;
	im_node->flags |= LPM_TREE_NODE_FLAG_IM;
	memcpy(im_node->data, node->data, trie->data_size);

	/* Now determine which child to install in which slot */
	if (extract_bit(key->data, matchlen)) {
		rcu_assign_pointer(im_node->child[0], node);
		rcu_assign_pointer(im_node->child[1], new_node);
	} else {
		rcu_assign_pointer(im_node->child[0], new_node);
		rcu_assign_pointer(im_node->child[1], node);
	}

	/* Finally, assign the intermediate node to the determined spot */
	rcu_assign_pointer(*slot, im_node);

out:
	if (ret) {
		if (new_node)
			trie->n_entries--;

		kfree(new_node);
		kfree(im_node);
	}

	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
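
/* For illustration: inserting 192.168.0.0/24 and then 192.168.1.0/24 into
 * an empty trie takes the intermediate-node path above: the two keys share
 * only 23 bits, so an IM node for 192.168.0.0/23 is created with the first
 * key in child[0] (its bit #23 is 0) and the second in child[1]. Updating
 * an already present prefix instead hits the node->prefixlen == matchlen
 * case and replaces the old node wholesale via rcu_assign_pointer() plus
 * kfree_rcu().
 */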

/* Called from syscall or from eBPF program */
static int trie_delete_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key;
	struct lpm_trie_node __rcu **trim, **trim2;
	struct lpm_trie_node *node, *parent;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Walk the tree looking for an exact key/length match and keeping
	 * track of the path we traverse.  We will need to know the node
	 * we wish to delete, and the slot that points to the node we want
	 * to delete.  We may also need to know the node's parent and the
	 * slot that contains it.
	 */
	trim = &trie->root;
	trim2 = trim;
	parent = NULL;
	while ((node = rcu_dereference_protected(
		       *trim, lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		parent = node;
		trim2 = trim;
		next_bit = extract_bit(key->data, node->prefixlen);
		trim = &node->child[next_bit];
	}

	if (!node || node->prefixlen != key->prefixlen ||
	    node->prefixlen != matchlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM)) {
		ret = -ENOENT;
		goto out;
	}

	trie->n_entries--;

	/* If the node we are removing has two children, simply mark it
	 * as intermediate and we are done.
	 */
	if (rcu_access_pointer(node->child[0]) &&
	    rcu_access_pointer(node->child[1])) {
		node->flags |= LPM_TREE_NODE_FLAG_IM;
		goto out;
	}

	/* If the parent of the node we are about to delete is an intermediate
	 * node, and the deleted node doesn't have any children, we can delete
	 * the intermediate parent as well and promote its other child
	 * up the tree.  Doing this maintains the invariant that all
	 * intermediate nodes have exactly 2 children and that there are no
	 * unnecessary intermediate nodes in the tree.
	 */
	if (parent && (parent->flags & LPM_TREE_NODE_FLAG_IM) &&
	    !node->child[0] && !node->child[1]) {
		if (node == rcu_access_pointer(parent->child[0]))
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[1]));
		else
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[0]));
		kfree_rcu(parent, rcu);
		kfree_rcu(node, rcu);
		goto out;
	}

	/* The node we are removing has either zero or one child. If there
	 * is a child, move it into the removed node's slot then delete
	 * the node.  Otherwise just clear the slot and delete the node.
	 */
	if (node->child[0])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[0]));
	else if (node->child[1])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[1]));
	else
		RCU_INIT_POINTER(*trim, NULL);
	kfree_rcu(node, rcu);

out:
	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
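
/* For illustration: deleting 192.168.0.0/23 while both 192.168.0.0/24 and
 * 192.168.1.0/24 exist cannot unlink the node, as both children are in
 * use, so it is merely demoted to an intermediate node. Deleting one of
 * the /24 leaves afterwards then removes that intermediate parent as well
 * and promotes the surviving sibling into its slot.
 */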

#define LPM_DATA_SIZE_MAX	256
#define LPM_DATA_SIZE_MIN	1

#define LPM_VAL_SIZE_MAX	(KMALLOC_MAX_SIZE - LPM_DATA_SIZE_MAX - \
				 sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN	1

#define LPM_KEY_SIZE(X)		(sizeof(struct bpf_lpm_trie_key) + (X))
#define LPM_KEY_SIZE_MAX	LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN	LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)

#define LPM_CREATE_FLAG_MASK	(BPF_F_NO_PREALLOC | BPF_F_NUMA_NODE |	\
				 BPF_F_RDONLY | BPF_F_WRONLY)
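
/* For illustration: struct bpf_lpm_trie_key is a 4-byte prefixlen followed
 * by a flexible data array, so an IPv4 trie uses LPM_KEY_SIZE(4) == 8 byte
 * keys and an IPv6 trie LPM_KEY_SIZE(16) == 20 byte keys. trie_alloc()
 * below derives data_size and max_prefixlen from key_size accordingly
 * (4 bytes and 32 bits for IPv4).
 */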

static struct bpf_map *trie_alloc(union bpf_attr *attr)
{
	struct lpm_trie *trie;
	u64 cost = sizeof(*trie), cost_per_node;
	int ret;

	if (!capable(CAP_SYS_ADMIN))
		return ERR_PTR(-EPERM);

	/* check sanity of attributes */
	if (attr->max_entries == 0 ||
	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
	    attr->map_flags & ~LPM_CREATE_FLAG_MASK ||
	    attr->key_size < LPM_KEY_SIZE_MIN ||
	    attr->key_size > LPM_KEY_SIZE_MAX ||
	    attr->value_size < LPM_VAL_SIZE_MIN ||
	    attr->value_size > LPM_VAL_SIZE_MAX)
		return ERR_PTR(-EINVAL);

	trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN);
	if (!trie)
		return ERR_PTR(-ENOMEM);

	/* copy mandatory map attributes */
	bpf_map_init_from_attr(&trie->map, attr);
	trie->data_size = attr->key_size -
			  offsetof(struct bpf_lpm_trie_key, data);
	trie->max_prefixlen = trie->data_size * 8;

	cost_per_node = sizeof(struct lpm_trie_node) +
			attr->value_size + trie->data_size;
	cost += (u64) attr->max_entries * cost_per_node;
	if (cost >= U32_MAX - PAGE_SIZE) {
		ret = -E2BIG;
		goto out_err;
	}

	trie->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;

	ret = bpf_map_precharge_memlock(trie->map.pages);
	if (ret)
		goto out_err;

	raw_spin_lock_init(&trie->lock);

	return &trie->map;
out_err:
	kfree(trie);
	return ERR_PTR(ret);
}

static void trie_free(struct bpf_map *map)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node __rcu **slot;
	struct lpm_trie_node *node;

	/* Wait for outstanding programs to complete
	 * update/lookup/delete/get_next_key and free the trie.
	 */
	synchronize_rcu();

	/* Always start at the root and walk down to a node that has no
	 * children. Then free that node, nullify its reference in the parent
	 * and start over.
	 */

	for (;;) {
		slot = &trie->root;

		for (;;) {
			node = rcu_dereference_protected(*slot, 1);
			if (!node)
				goto out;

			if (rcu_access_pointer(node->child[0])) {
				slot = &node->child[0];
				continue;
			}

			if (rcu_access_pointer(node->child[1])) {
				slot = &node->child[1];
				continue;
			}

			kfree(node);
			RCU_INIT_POINTER(*slot, NULL);
			break;
		}
	}

out:
	kfree(trie);
}

static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
{
	struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
	struct lpm_trie_node **node_stack = NULL;
	int err = 0, stack_ptr = -1;
	unsigned int next_bit;
	size_t matchlen;

	/* The get_next_key follows postorder. For the four-node example
	 * trie described in the comment at the top of this file,
	 * get_next_key() returns the following one after another:
	 *   192.168.0.0/24
	 *   192.168.1.0/24
	 *   192.168.128.0/24
	 *   192.168.0.0/16
	 *
	 * The idea is to return more specific keys before less specific
	 * ones.
	 */
	search_root = rcu_dereference(trie->root);
	if (!search_root)
		return -ENOENT;

	/* For invalid key, find the leftmost node in the trie */
	if (!key || key->prefixlen > trie->max_prefixlen)
		goto find_leftmost;

	node_stack = kmalloc_array(trie->max_prefixlen,
				   sizeof(struct lpm_trie_node *),
				   GFP_ATOMIC | __GFP_NOWARN);
	if (!node_stack)
		return -ENOMEM;

	/* Try to find the exact node for the given key */
	for (node = search_root; node;) {
		node_stack[++stack_ptr] = node;
		matchlen = longest_prefix_match(trie, node, key);
		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}
	if (!node || node->prefixlen != key->prefixlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM))
		goto find_leftmost;

	/* The node with the exactly-matching key has been found,
	 * find the first node in postorder after it.
	 */
	node = node_stack[stack_ptr];
	while (stack_ptr > 0) {
		parent = node_stack[stack_ptr - 1];
		if (rcu_dereference(parent->child[0]) == node) {
			search_root = rcu_dereference(parent->child[1]);
			if (search_root)
				goto find_leftmost;
		}
		if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) {
			next_node = parent;
			goto do_copy;
		}

		node = parent;
		stack_ptr--;
	}

	/* did not find anything */
	err = -ENOENT;
	goto free_stack;

find_leftmost:
	/* Find the leftmost non-intermediate node; all intermediate nodes
	 * have exactly one child here.
	 */
	for (node = search_root; node;) {
		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			next_node = node;
		node = rcu_dereference(node->child[0]);
	}
do_copy:
	next_key->prefixlen = next_node->prefixlen;
	memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data),
	       next_node->data, trie->data_size);
free_stack:
	kfree(node_stack);
	return err;
}

static int trie_check_btf(const struct bpf_map *map,
			  const struct btf *btf,
			  const struct btf_type *key_type,
			  const struct btf_type *value_type)
{
	/* Keys must have struct bpf_lpm_trie_key embedded. */
	return BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ?
	       -EINVAL : 0;
}

const struct bpf_map_ops trie_map_ops = {
	.map_alloc = trie_alloc,
	.map_free = trie_free,
	.map_get_next_key = trie_get_next_key,
	.map_lookup_elem = trie_lookup_elem,
	.map_update_elem = trie_update_elem,
	.map_delete_elem = trie_delete_elem,
	.map_check_btf = trie_check_btf,
};