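/*
 * Longest prefix match (LPM) trie: the kernel-side implementation backing
 * BPF LPM trie maps.  Keys are struct bpf_lpm_trie_key (a prefix length
 * followed by the prefix data in network byte order); lookups return the
 * value of the longest stored prefix that covers the key.
 */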
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/err.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/vmalloc.h>
#include <net/ipv6.h>
#include <uapi/linux/btf.h>

#define LPM_TREE_NODE_FLAG_IM BIT(0)

struct lpm_trie_node;

struct lpm_trie_node {
	struct rcu_head rcu;
	struct lpm_trie_node __rcu	*child[2];
	u32				prefixlen;
	u32				flags;
	u8				data[];
};

struct lpm_trie {
	struct bpf_map			map;
	struct lpm_trie_node __rcu	*root;
	size_t				n_entries;
	size_t				max_prefixlen;
	size_t				data_size;
	spinlock_t			lock;
};
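/* Trie organization:
 *
 * Each node holds a prefix of up to trie->max_prefixlen bits
 * (trie->data_size * 8).  child[0] continues with a 0 bit and child[1]
 * with a 1 bit at position node->prefixlen.  Nodes that exist only to
 * fork two real prefixes carry LPM_TREE_NODE_FLAG_IM (intermediate) and
 * are never returned by lookups.  For real entries, the map value is
 * stored directly behind the prefix data (node->data + trie->data_size).
 *
 * For example, an IPv4 entry for 192.168.0.0/24 uses prefixlen = 24 and
 * data[0..3] = { 192, 168, 0, 0 } (network byte order).
 */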
static inline int extract_bit(const u8 *data, size_t index)
{
	return !!(data[index / 8] & (1 << (7 - (index % 8))));
}
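/**
 * longest_prefix_match() - determine the longest prefix
 * @trie:	The trie to get internal sizes from
 * @node:	The node to operate on
 * @key:	The key to compare to @node
 *
 * Determine the longest prefix of @node that matches the bits in @key,
 * capped at the shorter of the two prefix lengths.
 */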
static size_t longest_prefix_match(const struct lpm_trie *trie,
				   const struct lpm_trie_node *node,
				   const struct bpf_lpm_trie_key *key)
{
	u32 limit = min(node->prefixlen, key->prefixlen);
	u32 prefixlen = 0, i = 0;

	/* The word-wise compares below rely on 4-byte alignment of both
	 * data arrays.
	 */
	BUILD_BUG_ON(offsetof(struct lpm_trie_node, data) % sizeof(u32));
	BUILD_BUG_ON(offsetof(struct bpf_lpm_trie_key, data) % sizeof(u32));

#if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) && defined(CONFIG_64BIT)

	/* Where unaligned 64-bit loads are cheap, compare the first eight
	 * bytes in one go.
	 */
	if (trie->data_size >= 8) {
		u64 diff = be64_to_cpu(*(__be64 *)node->data ^
				       *(__be64 *)key->data);

		prefixlen = 64 - fls64(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i = 8;
	}
#endif

	while (trie->data_size >= i + 4) {
		u32 diff = be32_to_cpu(*(__be32 *)&node->data[i] ^
				       *(__be32 *)&key->data[i]);

		prefixlen += 32 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 4;
	}

	if (trie->data_size >= i + 2) {
		u16 diff = be16_to_cpu(*(__be16 *)&node->data[i] ^
				       *(__be16 *)&key->data[i]);

		prefixlen += 16 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 2;
	}

	if (trie->data_size >= i + 1) {
		prefixlen += 8 - fls(node->data[i] ^ key->data[i]);

		if (prefixlen >= limit)
			return limit;
	}

	return prefixlen;
}
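/* Called from syscall or from eBPF program: find the longest stored prefix
 * covering @_key and return a pointer to its value, or NULL if none matches.
 */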
static void *trie_lookup_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *found = NULL;
	struct bpf_lpm_trie_key *key = _key;

	/* Start walking the trie from the root node and remember the last
	 * real (non-intermediate) node whose prefix fully matched the key.
	 */
	for (node = rcu_dereference_check(trie->root, rcu_read_lock_bh_held());
	     node;) {
		unsigned int next_bit;
		size_t matchlen;

		/* Determine the longest prefix of @node that matches @key.
		 * If it's the maximum possible prefix for this trie, we have
		 * an exact match and can return it directly.
		 */
		matchlen = longest_prefix_match(trie, node, key);
		if (matchlen == trie->max_prefixlen) {
			found = node;
			break;
		}

		/* If fewer bits match than @node's own prefix length, no
		 * deeper node can match either; stop and fall back to the
		 * best candidate seen so far.
		 */
		if (matchlen < node->prefixlen)
			break;

		/* Consider this node as a return candidate unless it is an
		 * artificially added intermediate one.
		 */
		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			found = node;

		/* The node's prefix is fully contained in the key; descend
		 * into the child selected by the next bit of the key.
		 */
		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference_check(node->child[next_bit],
					     rcu_read_lock_bh_held());
	}

	if (!found)
		return NULL;

	return found->data + trie->data_size;
}
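/* Allocate a node sized for the prefix data plus, when @value is non-NULL,
 * a copy of the map value placed directly behind it.
 */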
static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
						 const void *value)
{
	struct lpm_trie_node *node;
	size_t size = sizeof(struct lpm_trie_node) + trie->data_size;

	if (value)
		size += trie->map.value_size;

	node = bpf_map_kmalloc_node(&trie->map, size, GFP_ATOMIC | __GFP_NOWARN,
				    trie->map.numa_node);
	if (!node)
		return NULL;

	node->flags = 0;

	if (value)
		memcpy(node->data + trie->data_size, value,
		       trie->map.value_size);

	return node;
}
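/* Called from syscall or from eBPF program: insert a new prefix or replace
 * an existing one, inserting intermediate nodes where two prefixes diverge.
 * Runs under the trie spinlock.
 */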
static int trie_update_elem(struct bpf_map *map,
			    void *_key, void *value, u64 flags)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
	struct lpm_trie_node __rcu **slot;
	struct bpf_lpm_trie_key *key = _key;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	spin_lock_irqsave(&trie->lock, irq_flags);

	/* Allocate and fill a new node */
	if (trie->n_entries == trie->map.max_entries) {
		ret = -ENOSPC;
		goto out;
	}

	new_node = lpm_trie_node_alloc(trie, value);
	if (!new_node) {
		ret = -ENOMEM;
		goto out;
	}

	trie->n_entries++;

	new_node->prefixlen = key->prefixlen;
	RCU_INIT_POINTER(new_node->child[0], NULL);
	RCU_INIT_POINTER(new_node->child[1], NULL);
	memcpy(new_node->data, key->data, trie->data_size);

	/* Now find a slot to attach the new node. To do that, walk the tree
	 * from the root and match as many bits as possible for each node
	 * until we either find an empty slot or a slot that needs to be
	 * replaced by an intermediate node.
	 */
	slot = &trie->root;

	while ((node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen ||
		    node->prefixlen == trie->max_prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		slot = &node->child[next_bit];
	}

	/* If the slot is empty (a free child pointer or an empty root),
	 * simply assign the @new_node to that slot and be done.
	 */
	if (!node) {
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* If the slot we picked already exists, replace it with @new_node
	 * which already has the correct data and child pointers.
	 */
	if (node->prefixlen == matchlen) {
		new_node->child[0] = node->child[0];
		new_node->child[1] = node->child[1];

		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			trie->n_entries--;

		rcu_assign_pointer(*slot, new_node);
		kfree_rcu(node, rcu);

		goto out;
	}

	/* If the new node matches the prefix completely, it must be inserted
	 * as an ancestor: attach @node below @new_node at the child slot
	 * selected by the first bit after the shared prefix.
	 */
	if (matchlen == key->prefixlen) {
		next_bit = extract_bit(node->data, matchlen);
		rcu_assign_pointer(new_node->child[next_bit], node);
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	im_node = lpm_trie_node_alloc(trie, NULL);
	if (!im_node) {
		ret = -ENOMEM;
		goto out;
	}

	im_node->prefixlen = matchlen;
	im_node->flags |= LPM_TREE_NODE_FLAG_IM;
	memcpy(im_node->data, node->data, trie->data_size);

	/* Now determine which child to install in which slot */
	if (extract_bit(key->data, matchlen)) {
		rcu_assign_pointer(im_node->child[0], node);
		rcu_assign_pointer(im_node->child[1], new_node);
	} else {
		rcu_assign_pointer(im_node->child[0], new_node);
		rcu_assign_pointer(im_node->child[1], node);
	}

	/* Finally, assign the intermediate node to the determined slot */
	rcu_assign_pointer(*slot, im_node);

out:
	if (ret) {
		if (new_node)
			trie->n_entries--;

		kfree(new_node);
		kfree(im_node);
	}

	spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
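/* Called from syscall or from eBPF program: remove the entry for @_key and
 * collapse intermediate nodes that are no longer needed.
 */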
static int trie_delete_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key;
	struct lpm_trie_node __rcu **trim, **trim2;
	struct lpm_trie_node *node, *parent;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	spin_lock_irqsave(&trie->lock, irq_flags);

	/* Walk the tree looking for an exact key/length match and keeping
	 * track of the path we traverse.  We will need to know the node
	 * we wish to delete and the slot that points to it, and we may also
	 * need to know the node's parent and the slot that contains it.
	 */
	trim = &trie->root;
	trim2 = trim;
	parent = NULL;
	while ((node = rcu_dereference_protected(
		       *trim, lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		parent = node;
		trim2 = trim;
		next_bit = extract_bit(key->data, node->prefixlen);
		trim = &node->child[next_bit];
	}

	if (!node || node->prefixlen != key->prefixlen ||
	    node->prefixlen != matchlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM)) {
		ret = -ENOENT;
		goto out;
	}

	trie->n_entries--;

	/* If the node we are removing has two children, simply mark it
	 * as intermediate and we are done.
	 */
	if (rcu_access_pointer(node->child[0]) &&
	    rcu_access_pointer(node->child[1])) {
		node->flags |= LPM_TREE_NODE_FLAG_IM;
		goto out;
	}

	/* If the parent of the node we are about to delete is an intermediate
	 * node and the deleted node does not have any children, we can delete
	 * the intermediate parent as well and promote its other child
	 * up the tree.  Doing this maintains the invariant that all
	 * intermediate nodes have exactly two children.
	 */
	if (parent && (parent->flags & LPM_TREE_NODE_FLAG_IM) &&
	    !node->child[0] && !node->child[1]) {
		if (node == rcu_access_pointer(parent->child[0]))
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[1]));
		else
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[0]));
		kfree_rcu(parent, rcu);
		kfree_rcu(node, rcu);
		goto out;
	}

	/* The node being removed has either zero or one child. If there is
	 * a child, move it into the removed node's slot, then delete the
	 * node.  Otherwise just clear the slot and delete the node.
	 */
	if (node->child[0])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[0]));
	else if (node->child[1])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[1]));
	else
		RCU_INIT_POINTER(*trim, NULL);
	kfree_rcu(node, rcu);

out:
	spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
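/* Limits enforced at map creation time: the prefix data and the value must
 * together fit into a single kmalloc'ed node.
 */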
#define LPM_DATA_SIZE_MAX	256
#define LPM_DATA_SIZE_MIN	1

#define LPM_VAL_SIZE_MAX	(KMALLOC_MAX_SIZE - LPM_DATA_SIZE_MAX - \
				 sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN	1

#define LPM_KEY_SIZE(X)		(sizeof(struct bpf_lpm_trie_key) + (X))
#define LPM_KEY_SIZE_MAX	LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN	LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)

#define LPM_CREATE_FLAG_MASK	(BPF_F_NO_PREALLOC | BPF_F_NUMA_NODE |	\
				 BPF_F_ACCESS_MASK)
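/* map_alloc callback: validate the creation attributes and set up an
 * empty trie.
 */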
static struct bpf_map *trie_alloc(union bpf_attr *attr)
{
	struct lpm_trie *trie;

	if (!bpf_capable())
		return ERR_PTR(-EPERM);

	/* check sanity of attributes */
	if (attr->max_entries == 0 ||
	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
	    attr->map_flags & ~LPM_CREATE_FLAG_MASK ||
	    !bpf_map_flags_access_ok(attr->map_flags) ||
	    attr->key_size < LPM_KEY_SIZE_MIN ||
	    attr->key_size > LPM_KEY_SIZE_MAX ||
	    attr->value_size < LPM_VAL_SIZE_MIN ||
	    attr->value_size > LPM_VAL_SIZE_MAX)
		return ERR_PTR(-EINVAL);

	trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN | __GFP_ACCOUNT);
	if (!trie)
		return ERR_PTR(-ENOMEM);

	/* copy mandatory map attributes */
	bpf_map_init_from_attr(&trie->map, attr);
	trie->data_size = attr->key_size -
			  offsetof(struct bpf_lpm_trie_key, data);
	trie->max_prefixlen = trie->data_size * 8;

	spin_lock_init(&trie->lock);

	return &trie->map;
}
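/* map_free callback: by the time this runs no other references to the map
 * exist, so nodes can be freed directly without RCU deferral.  Repeatedly
 * walk down to a childless node, free it and clear its slot in the parent.
 */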
static void trie_free(struct bpf_map *map)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node __rcu **slot;
	struct lpm_trie_node *node;

	for (;;) {
		slot = &trie->root;

		for (;;) {
			node = rcu_dereference_protected(*slot, 1);
			if (!node)
				goto out;

			if (rcu_access_pointer(node->child[0])) {
				slot = &node->child[0];
				continue;
			}

			if (rcu_access_pointer(node->child[1])) {
				slot = &node->child[1];
				continue;
			}

			kfree(node);
			RCU_INIT_POINTER(*slot, NULL);
			break;
		}
	}

out:
	kfree(trie);
}
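/* map_get_next_key callback: iteration follows a postorder walk over the
 * real (non-intermediate) nodes.  A NULL or unknown @_key restarts from the
 * leftmost entry; -ENOENT is returned once the last entry has been handed
 * out.
 */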
static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
{
	struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
	struct lpm_trie_node **node_stack = NULL;
	int err = 0, stack_ptr = -1;
	unsigned int next_bit;
	size_t matchlen;

	search_root = rcu_dereference(trie->root);
	if (!search_root)
		return -ENOENT;

	/* For an invalid key, fall back to the leftmost node in the trie */
	if (!key || key->prefixlen > trie->max_prefixlen)
		goto find_leftmost;

	node_stack = kmalloc_array(trie->max_prefixlen,
				   sizeof(struct lpm_trie_node *),
				   GFP_ATOMIC | __GFP_NOWARN);
	if (!node_stack)
		return -ENOMEM;

	/* Try to find the exact node for the given key */
	for (node = search_root; node;) {
		node_stack[++stack_ptr] = node;
		matchlen = longest_prefix_match(trie, node, key);
		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}
	if (!node || node->prefixlen != key->prefixlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM))
		goto find_leftmost;

	/* The node with the exactly-matching key has been found; walk back
	 * up the recorded path to find its postorder successor.
	 */
	node = node_stack[stack_ptr];
	while (stack_ptr > 0) {
		parent = node_stack[stack_ptr - 1];
		if (rcu_dereference(parent->child[0]) == node) {
			search_root = rcu_dereference(parent->child[1]);
			if (search_root)
				goto find_leftmost;
		}
		if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) {
			next_node = parent;
			goto do_copy;
		}

		node = parent;
		stack_ptr--;
	}

	/* Reached the root without finding a successor */
	err = -ENOENT;
	goto free_stack;

find_leftmost:
	/* Find the leftmost non-intermediate node; intermediate nodes
	 * always have exactly two children.
	 */
	for (node = search_root; node;) {
		if (node->flags & LPM_TREE_NODE_FLAG_IM) {
			node = rcu_dereference(node->child[0]);
		} else {
			next_node = node;
			node = rcu_dereference(node->child[0]);
			if (!node)
				node = rcu_dereference(next_node->child[1]);
		}
	}

do_copy:
	next_key->prefixlen = next_node->prefixlen;
	memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data),
	       next_node->data, trie->data_size);
free_stack:
	kfree(node_stack);
	return err;
}
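/* BTF check: keys must be a struct (struct bpf_lpm_trie_key with trailing
 * prefix data).
 */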
static int trie_check_btf(const struct bpf_map *map,
			  const struct btf *btf,
			  const struct btf_type *key_type,
			  const struct btf_type *value_type)
{
	return BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ?
	       -EINVAL : 0;
}
static int trie_map_btf_id;
const struct bpf_map_ops trie_map_ops = {
	.map_meta_equal = bpf_map_meta_equal,
	.map_alloc = trie_alloc,
	.map_free = trie_free,
	.map_get_next_key = trie_get_next_key,
	.map_lookup_elem = trie_lookup_elem,
	.map_update_elem = trie_update_elem,
	.map_delete_elem = trie_delete_elem,
	.map_lookup_batch = generic_map_lookup_batch,
	.map_update_batch = generic_map_update_batch,
	.map_delete_batch = generic_map_delete_batch,
	.map_check_btf = trie_check_btf,
	.map_btf_name = "lpm_trie",
	.map_btf_id = &trie_map_btf_id,
};