1
2
3
4
5
6
7
8
9#include <linux/bpf.h>
10#include <linux/btf.h>
11#include <linux/err.h>
12#include <linux/slab.h>
13#include <linux/spinlock.h>
14#include <linux/vmalloc.h>
15#include <net/ipv6.h>
16#include <uapi/linux/btf.h>
17
18
19#define LPM_TREE_NODE_FLAG_IM BIT(0)
20
21struct lpm_trie_node;
22
23struct lpm_trie_node {
24 struct rcu_head rcu;
25 struct lpm_trie_node __rcu *child[2];
26 u32 prefixlen;
27 u32 flags;
28 u8 data[0];
29};
30
/*
 * Per-map trie state.  Structural mutations (update/delete) are
 * serialized by @lock with IRQs disabled; lookups walk the tree
 * locklessly via RCU.
 */
struct lpm_trie {
	struct bpf_map map;			/* generic map header (container_of target) */
	struct lpm_trie_node __rcu *root;	/* top of the trie, NULL when empty */
	size_t n_entries;			/* non-intermediate nodes currently stored */
	size_t max_prefixlen;			/* data_size * 8 */
	size_t data_size;			/* key bytes per node */
	raw_spinlock_t lock;			/* serializes all writers */
};
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151static inline int extract_bit(const u8 *data, size_t index)
152{
153 return !!(data[index / 8] & (1 << (7 - (index % 8))));
154}
155
156
157
158
159
160
161
162
163
/*
 * Return the length (in bits) of the common prefix shared by @node's key
 * and @key, capped at min(node->prefixlen, key->prefixlen).
 */
static size_t longest_prefix_match(const struct lpm_trie *trie,
				   const struct lpm_trie_node *node,
				   const struct bpf_lpm_trie_key *key)
{
	u32 limit = min(node->prefixlen, key->prefixlen);
	u32 prefixlen = 0, i = 0;

	/* The wide loads below require the key data in both structs to be
	 * at least 4-byte aligned.
	 */
	BUILD_BUG_ON(offsetof(struct lpm_trie_node, data) % sizeof(u32));
	BUILD_BUG_ON(offsetof(struct bpf_lpm_trie_key, data) % sizeof(u32));

#if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) && defined(CONFIG_64BIT)

	/* Fast path: compare the first 8 bytes in one u64 load when the
	 * architecture handles unaligned 64-bit accesses efficiently.
	 * XOR yields the differing bits; fls64 finds the highest one.
	 */
	if (trie->data_size >= 8) {
		u64 diff = be64_to_cpu(*(__be64 *)node->data ^
				       *(__be64 *)key->data);

		prefixlen = 64 - fls64(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i = 8;
	}
#endif

	/* Compare the remaining data 4 bytes at a time */
	while (trie->data_size >= i + 4) {
		u32 diff = be32_to_cpu(*(__be32 *)&node->data[i] ^
				       *(__be32 *)&key->data[i]);

		prefixlen += 32 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 4;
	}

	/* ... then 2 bytes ... */
	if (trie->data_size >= i + 2) {
		u16 diff = be16_to_cpu(*(__be16 *)&node->data[i] ^
				       *(__be16 *)&key->data[i]);

		prefixlen += 16 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 2;
	}

	/* ... and finally the last odd byte, if any */
	if (trie->data_size >= i + 1) {
		prefixlen += 8 - fls(node->data[i] ^ key->data[i]);

		if (prefixlen >= limit)
			return limit;
	}

	return prefixlen;
}
225
226
227static void *trie_lookup_elem(struct bpf_map *map, void *_key)
228{
229 struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
230 struct lpm_trie_node *node, *found = NULL;
231 struct bpf_lpm_trie_key *key = _key;
232
233
234
235 for (node = rcu_dereference(trie->root); node;) {
236 unsigned int next_bit;
237 size_t matchlen;
238
239
240
241
242
243 matchlen = longest_prefix_match(trie, node, key);
244 if (matchlen == trie->max_prefixlen) {
245 found = node;
246 break;
247 }
248
249
250
251
252
253 if (matchlen < node->prefixlen)
254 break;
255
256
257
258
259 if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
260 found = node;
261
262
263
264
265
266 next_bit = extract_bit(key->data, node->prefixlen);
267 node = rcu_dereference(node->child[next_bit]);
268 }
269
270 if (!found)
271 return NULL;
272
273 return found->data + trie->data_size;
274}
275
276static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
277 const void *value)
278{
279 struct lpm_trie_node *node;
280 size_t size = sizeof(struct lpm_trie_node) + trie->data_size;
281
282 if (value)
283 size += trie->map.value_size;
284
285 node = kmalloc_node(size, GFP_ATOMIC | __GFP_NOWARN,
286 trie->map.numa_node);
287 if (!node)
288 return NULL;
289
290 node->flags = 0;
291
292 if (value)
293 memcpy(node->data + trie->data_size, value,
294 trie->map.value_size);
295
296 return node;
297}
298
299
/* Insert or replace the entry for @key with @value.  All structural
 * changes happen under trie->lock with IRQs disabled; replaced nodes are
 * freed via kfree_rcu() so concurrent RCU lookups remain safe.
 */
static int trie_update_elem(struct bpf_map *map,
			    void *_key, void *value, u64 flags)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
	struct lpm_trie_node __rcu **slot;
	struct bpf_lpm_trie_key *key = _key;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	/* NOTE(review): flags are only range-checked here; BPF_NOEXIST /
	 * BPF_EXIST semantics are not enforced below — confirm intended.
	 */
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Allocate and fill a new node.
	 * NOTE(review): this capacity check also rejects replacement of an
	 * already-present key once the map is full — confirm intended.
	 */
	if (trie->n_entries == trie->map.max_entries) {
		ret = -ENOSPC;
		goto out;
	}

	new_node = lpm_trie_node_alloc(trie, value);
	if (!new_node) {
		ret = -ENOMEM;
		goto out;
	}

	/* Optimistically counted; rolled back in the error path below */
	trie->n_entries++;

	new_node->prefixlen = key->prefixlen;
	RCU_INIT_POINTER(new_node->child[0], NULL);
	RCU_INIT_POINTER(new_node->child[1], NULL);
	memcpy(new_node->data, key->data, trie->data_size);

	/* Now find a slot to attach the new node. To do that, walk the tree
	 * from the root and match as many bits as possible for each node
	 * until we either find an empty slot or a slot that needs to be
	 * replaced by an intermediate node.
	 */
	slot = &trie->root;

	while ((node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen ||
		    node->prefixlen == trie->max_prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		slot = &node->child[next_bit];
	}

	/* If the slot is empty (a free child pointer or an empty root),
	 * simply assign the @new_node to that slot and be done.
	 */
	if (!node) {
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* If the slot we picked already exists, replace it with @new_node
	 * which already has the correct data and child pointers.  Dropping
	 * a real (non-IM) node means the entry count stays unchanged.
	 */
	if (node->prefixlen == matchlen) {
		new_node->child[0] = node->child[0];
		new_node->child[1] = node->child[1];

		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			trie->n_entries--;

		rcu_assign_pointer(*slot, new_node);
		kfree_rcu(node, rcu);

		goto out;
	}

	/* If the new node matches the prefix completely, it must be inserted
	 * as an ancestor.  Simply insert it between @node and *@slot.
	 */
	if (matchlen == key->prefixlen) {
		next_bit = extract_bit(node->data, matchlen);
		rcu_assign_pointer(new_node->child[next_bit], node);
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* Neither is an ancestor of the other: fork with an intermediate
	 * node that carries the common prefix and both as children.
	 */
	im_node = lpm_trie_node_alloc(trie, NULL);
	if (!im_node) {
		ret = -ENOMEM;
		goto out;
	}

	im_node->prefixlen = matchlen;
	im_node->flags |= LPM_TREE_NODE_FLAG_IM;
	memcpy(im_node->data, node->data, trie->data_size);

	/* Now determine which child to install in which slot */
	if (extract_bit(key->data, matchlen)) {
		rcu_assign_pointer(im_node->child[0], node);
		rcu_assign_pointer(im_node->child[1], new_node);
	} else {
		rcu_assign_pointer(im_node->child[0], new_node);
		rcu_assign_pointer(im_node->child[1], node);
	}

	/* Finally, publish the intermediate node in the determined slot */
	rcu_assign_pointer(*slot, im_node);

out:
	if (ret) {
		/* Roll back the optimistic entry count and free anything
		 * that was never published to the tree.
		 */
		if (new_node)
			trie->n_entries--;

		kfree(new_node);
		kfree(im_node);
	}

	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
429
430
/* Remove the entry that exactly matches @key (prefix bits and length).
 * Runs under trie->lock; removed nodes are freed via kfree_rcu().
 */
static int trie_delete_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key;
	struct lpm_trie_node __rcu **trim, **trim2;
	struct lpm_trie_node *node, *parent;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Walk the tree looking for an exact key/length match and keeping
	 * track of the path we traverse.  We will need to know the node
	 * we wish to delete, and the slot that points to the node we want
	 * to delete (@trim).  We may also need to know the node's parent
	 * and the slot that contains it (@trim2).
	 */
	trim = &trie->root;
	trim2 = trim;
	parent = NULL;
	while ((node = rcu_dereference_protected(
		       *trim, lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		parent = node;
		trim2 = trim;
		next_bit = extract_bit(key->data, node->prefixlen);
		trim = &node->child[next_bit];
	}

	/* Not found, inexact match, or an intermediate placeholder only */
	if (!node || node->prefixlen != key->prefixlen ||
	    node->prefixlen != matchlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM)) {
		ret = -ENOENT;
		goto out;
	}

	trie->n_entries--;

	/* If the node we are removing has two children, simply mark it
	 * as intermediate: it stays in the tree as a fork point.
	 */
	if (rcu_access_pointer(node->child[0]) &&
	    rcu_access_pointer(node->child[1])) {
		node->flags |= LPM_TREE_NODE_FLAG_IM;
		goto out;
	}

	/* If the parent of the node we are about to delete is an
	 * intermediate node, and the deleted node doesn't have any
	 * children, we can delete the intermediate parent as well and
	 * promote its other child (if any) into the parent's slot.
	 */
	if (parent && (parent->flags & LPM_TREE_NODE_FLAG_IM) &&
	    !node->child[0] && !node->child[1]) {
		if (node == rcu_access_pointer(parent->child[0]))
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[1]));
		else
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[0]));
		kfree_rcu(parent, rcu);
		kfree_rcu(node, rcu);
		goto out;
	}

	/* The node has at most one child.  Splice that child (or NULL)
	 * into the node's slot, then free the node.
	 */
	if (node->child[0])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[0]));
	else if (node->child[1])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[1]));
	else
		RCU_INIT_POINTER(*trim, NULL);
	kfree_rcu(node, rcu);

out:
	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
525
/* Bounds on the raw key bytes that follow struct bpf_lpm_trie_key */
#define LPM_DATA_SIZE_MAX 256
#define LPM_DATA_SIZE_MIN 1

/* Value size is bounded so a full node allocation fits in KMALLOC_MAX_SIZE */
#define LPM_VAL_SIZE_MAX (KMALLOC_MAX_SIZE - LPM_DATA_SIZE_MAX - \
			  sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN 1

/* Total key size: the bpf_lpm_trie_key header plus X data bytes */
#define LPM_KEY_SIZE(X) (sizeof(struct bpf_lpm_trie_key) + (X))
#define LPM_KEY_SIZE_MAX LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)

/* Map flags accepted at creation time */
#define LPM_CREATE_FLAG_MASK (BPF_F_NO_PREALLOC | BPF_F_NUMA_NODE | \
			      BPF_F_ACCESS_MASK)
539
540static struct bpf_map *trie_alloc(union bpf_attr *attr)
541{
542 struct lpm_trie *trie;
543 u64 cost = sizeof(*trie), cost_per_node;
544 int ret;
545
546 if (!capable(CAP_SYS_ADMIN))
547 return ERR_PTR(-EPERM);
548
549
550 if (attr->max_entries == 0 ||
551 !(attr->map_flags & BPF_F_NO_PREALLOC) ||
552 attr->map_flags & ~LPM_CREATE_FLAG_MASK ||
553 !bpf_map_flags_access_ok(attr->map_flags) ||
554 attr->key_size < LPM_KEY_SIZE_MIN ||
555 attr->key_size > LPM_KEY_SIZE_MAX ||
556 attr->value_size < LPM_VAL_SIZE_MIN ||
557 attr->value_size > LPM_VAL_SIZE_MAX)
558 return ERR_PTR(-EINVAL);
559
560 trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN);
561 if (!trie)
562 return ERR_PTR(-ENOMEM);
563
564
565 bpf_map_init_from_attr(&trie->map, attr);
566 trie->data_size = attr->key_size -
567 offsetof(struct bpf_lpm_trie_key, data);
568 trie->max_prefixlen = trie->data_size * 8;
569
570 cost_per_node = sizeof(struct lpm_trie_node) +
571 attr->value_size + trie->data_size;
572 cost += (u64) attr->max_entries * cost_per_node;
573 if (cost >= U32_MAX - PAGE_SIZE) {
574 ret = -E2BIG;
575 goto out_err;
576 }
577
578 trie->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
579
580 ret = bpf_map_precharge_memlock(trie->map.pages);
581 if (ret)
582 goto out_err;
583
584 raw_spin_lock_init(&trie->lock);
585
586 return &trie->map;
587out_err:
588 kfree(trie);
589 return ERR_PTR(ret);
590}
591
592static void trie_free(struct bpf_map *map)
593{
594 struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
595 struct lpm_trie_node __rcu **slot;
596 struct lpm_trie_node *node;
597
598
599
600
601 synchronize_rcu();
602
603
604
605
606
607
608 for (;;) {
609 slot = &trie->root;
610
611 for (;;) {
612 node = rcu_dereference_protected(*slot, 1);
613 if (!node)
614 goto out;
615
616 if (rcu_access_pointer(node->child[0])) {
617 slot = &node->child[0];
618 continue;
619 }
620
621 if (rcu_access_pointer(node->child[1])) {
622 slot = &node->child[1];
623 continue;
624 }
625
626 kfree(node);
627 RCU_INIT_POINTER(*slot, NULL);
628 break;
629 }
630 }
631
632out:
633 kfree(trie);
634}
635
/* Return in @_next_key the key that follows @_key in a postorder walk,
 * so more specific prefixes are returned before less specific ones.
 * When @_key is NULL, invalid, or not present, the first (leftmost
 * non-intermediate) key is returned; -ENOENT when iteration is done or
 * the trie is empty.
 */
static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
{
	struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
	struct lpm_trie_node **node_stack = NULL;
	int err = 0, stack_ptr = -1;
	unsigned int next_bit;
	size_t matchlen;

	/* Empty trie */
	search_root = rcu_dereference(trie->root);
	if (!search_root)
		return -ENOENT;

	/* For an invalid key, fall back to the leftmost node */
	if (!key || key->prefixlen > trie->max_prefixlen)
		goto find_leftmost;

	/* The stack records the path from the root down to the node that
	 * matches @key; max_prefixlen bounds the path length.
	 */
	node_stack = kmalloc_array(trie->max_prefixlen,
				   sizeof(struct lpm_trie_node *),
				   GFP_ATOMIC | __GFP_NOWARN);
	if (!node_stack)
		return -ENOMEM;

	/* Try to find the exact node for the given key */
	for (node = search_root; node;) {
		node_stack[++stack_ptr] = node;
		matchlen = longest_prefix_match(trie, node, key);
		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}
	/* No exact match (or only an intermediate node): restart from the
	 * leftmost key.
	 */
	if (!node || node->prefixlen != key->prefixlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM))
		goto find_leftmost;

	/* The exactly-matching node was found; walk back up the recorded
	 * path to find the first node in postorder after it.  If the
	 * matched node was a left child with a right sibling, the
	 * successor is the leftmost key of that right subtree.
	 */
	node = node_stack[stack_ptr];
	while (stack_ptr > 0) {
		parent = node_stack[stack_ptr - 1];
		if (rcu_dereference(parent->child[0]) == node) {
			search_root = rcu_dereference(parent->child[1]);
			if (search_root)
				goto find_leftmost;
		}
		/* Otherwise the parent itself is next — unless it is an
		 * intermediate node, in which case keep ascending.
		 */
		if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) {
			next_node = parent;
			goto do_copy;
		}

		node = parent;
		stack_ptr--;
	}

	/* Reached the root without finding a successor */
	err = -ENOENT;
	goto free_stack;

find_leftmost:
	/* Find the leftmost non-intermediate node; intermediate nodes
	 * always have two children, so the descent cannot dead-end on one.
	 */
	for (node = search_root; node;) {
		if (node->flags & LPM_TREE_NODE_FLAG_IM) {
			node = rcu_dereference(node->child[0]);
		} else {
			next_node = node;
			node = rcu_dereference(node->child[0]);
			if (!node)
				node = rcu_dereference(next_node->child[1]);
		}
	}
do_copy:
	next_key->prefixlen = next_node->prefixlen;
	memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data),
	       next_node->data, trie->data_size);
free_stack:
	kfree(node_stack);
	return err;
}
733
734static int trie_check_btf(const struct bpf_map *map,
735 const struct btf *btf,
736 const struct btf_type *key_type,
737 const struct btf_type *value_type)
738{
739
740 return BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ?
741 -EINVAL : 0;
742}
743
/* Operations table registered for BPF_MAP_TYPE_LPM_TRIE maps */
const struct bpf_map_ops trie_map_ops = {
	.map_alloc = trie_alloc,
	.map_free = trie_free,
	.map_get_next_key = trie_get_next_key,
	.map_lookup_elem = trie_lookup_elem,
	.map_update_elem = trie_update_elem,
	.map_delete_elem = trie_delete_elem,
	.map_check_btf = trie_check_btf,
};
753