linux/drivers/net/wireguard/selftest/allowedips.c
<<
>>
Prefs
// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 *
 * This contains some basic static unit tests for the allowedips data structure.
 * It also has two additional modes that are disabled and meant to be used by
 * folks directly playing with this file. If you define the macro
 * DEBUG_PRINT_TRIE_GRAPHVIZ to be 1, then every time there's a full tree in
 * memory, it will be printed out as KERN_DEBUG in a format that can be passed
 * to graphviz (the dot command) to visualize it. If you define the macro
 * DEBUG_RANDOM_TRIE to be 1, then there will be an extremely costly set of
 * randomized tests done against a trivial implementation, which may take
 * upwards of a half-hour to complete. There's no set of users who should be
 * enabling these, and the only developers that should go anywhere near these
 * knobs are the ones who are reading this comment.
 */
  17
  18#ifdef DEBUG
  19
  20#include <linux/siphash.h>
  21
  22static __init void print_node(struct allowedips_node *node, u8 bits)
  23{
  24        char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
  25        char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
  26        u8 ip1[16], ip2[16], cidr1, cidr2;
  27        char *style = "dotted";
  28        u32 color = 0;
  29
  30        if (node == NULL)
  31                return;
  32        if (bits == 32) {
  33                fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
  34                fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
  35        } else if (bits == 128) {
  36                fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
  37                fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
  38        }
  39        if (node->peer) {
  40                hsiphash_key_t key = { { 0 } };
  41
  42                memcpy(&key, &node->peer, sizeof(node->peer));
  43                color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 |
  44                        hsiphash_1u32(0xbabecafe, &key) % 200 << 8 |
  45                        hsiphash_1u32(0xabad1dea, &key) % 200;
  46                style = "bold";
  47        }
  48        wg_allowedips_read_node(node, ip1, &cidr1);
  49        printk(fmt_declaration, ip1, cidr1, style, color);
  50        if (node->bit[0]) {
  51                wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2);
  52                printk(fmt_connection, ip1, cidr1, ip2, cidr2);
  53        }
  54        if (node->bit[1]) {
  55                wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2);
  56                printk(fmt_connection, ip1, cidr1, ip2, cidr2);
  57        }
  58        if (node->bit[0])
  59                print_node(rcu_dereference_raw(node->bit[0]), bits);
  60        if (node->bit[1])
  61                print_node(rcu_dereference_raw(node->bit[1]), bits);
  62}
  63
  64static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
  65{
  66        printk(KERN_DEBUG "digraph trie {\n");
  67        print_node(rcu_dereference_raw(top), bits);
  68        printk(KERN_DEBUG "}\n");
  69}
  70
  71enum {
  72        NUM_PEERS = 2000,
  73        NUM_RAND_ROUTES = 400,
  74        NUM_MUTATED_ROUTES = 100,
  75        NUM_QUERIES = NUM_RAND_ROUTES * NUM_MUTATED_ROUTES * 30
  76};
  77
  78struct horrible_allowedips {
  79        struct hlist_head head;
  80};
  81
  82struct horrible_allowedips_node {
  83        struct hlist_node table;
  84        union nf_inet_addr ip;
  85        union nf_inet_addr mask;
  86        u8 ip_version;
  87        void *value;
  88};
  89
  90static __init void horrible_allowedips_init(struct horrible_allowedips *table)
  91{
  92        INIT_HLIST_HEAD(&table->head);
  93}
  94
  95static __init void horrible_allowedips_free(struct horrible_allowedips *table)
  96{
  97        struct horrible_allowedips_node *node;
  98        struct hlist_node *h;
  99
 100        hlist_for_each_entry_safe(node, h, &table->head, table) {
 101                hlist_del(&node->table);
 102                kfree(node);
 103        }
 104}
 105
 106static __init inline union nf_inet_addr horrible_cidr_to_mask(u8 cidr)
 107{
 108        union nf_inet_addr mask;
 109
 110        memset(&mask, 0, sizeof(mask));
 111        memset(&mask.all, 0xff, cidr / 8);
 112        if (cidr % 32)
 113                mask.all[cidr / 32] = (__force u32)htonl(
 114                        (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
 115        return mask;
 116}
 117
 118static __init inline u8 horrible_mask_to_cidr(union nf_inet_addr subnet)
 119{
 120        return hweight32(subnet.all[0]) + hweight32(subnet.all[1]) +
 121               hweight32(subnet.all[2]) + hweight32(subnet.all[3]);
 122}
 123
 124static __init inline void
 125horrible_mask_self(struct horrible_allowedips_node *node)
 126{
 127        if (node->ip_version == 4) {
 128                node->ip.ip &= node->mask.ip;
 129        } else if (node->ip_version == 6) {
 130                node->ip.ip6[0] &= node->mask.ip6[0];
 131                node->ip.ip6[1] &= node->mask.ip6[1];
 132                node->ip.ip6[2] &= node->mask.ip6[2];
 133                node->ip.ip6[3] &= node->mask.ip6[3];
 134        }
 135}
 136
 137static __init inline bool
 138horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip)
 139{
 140        return (ip->s_addr & node->mask.ip) == node->ip.ip;
 141}
 142
 143static __init inline bool
 144horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip)
 145{
 146        return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
 147               (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
 148               (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
 149               (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
 150}
 151
 152static __init void
 153horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node)
 154{
 155        struct horrible_allowedips_node *other = NULL, *where = NULL;
 156        u8 my_cidr = horrible_mask_to_cidr(node->mask);
 157
 158        hlist_for_each_entry(other, &table->head, table) {
 159                if (other->ip_version == node->ip_version &&
 160                    !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
 161                    !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) {
 162                        other->value = node->value;
 163                        kfree(node);
 164                        return;
 165                }
 166        }
 167        hlist_for_each_entry(other, &table->head, table) {
 168                where = other;
 169                if (horrible_mask_to_cidr(other->mask) <= my_cidr)
 170                        break;
 171        }
 172        if (!other && !where)
 173                hlist_add_head(&node->table, &table->head);
 174        else if (!other)
 175                hlist_add_behind(&node->table, &where->table);
 176        else
 177                hlist_add_before(&node->table, &where->table);
 178}
 179
 180static __init int
 181horrible_allowedips_insert_v4(struct horrible_allowedips *table,
 182                              struct in_addr *ip, u8 cidr, void *value)
 183{
 184        struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
 185
 186        if (unlikely(!node))
 187                return -ENOMEM;
 188        node->ip.in = *ip;
 189        node->mask = horrible_cidr_to_mask(cidr);
 190        node->ip_version = 4;
 191        node->value = value;
 192        horrible_mask_self(node);
 193        horrible_insert_ordered(table, node);
 194        return 0;
 195}
 196
 197static __init int
 198horrible_allowedips_insert_v6(struct horrible_allowedips *table,
 199                              struct in6_addr *ip, u8 cidr, void *value)
 200{
 201        struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
 202
 203        if (unlikely(!node))
 204                return -ENOMEM;
 205        node->ip.in6 = *ip;
 206        node->mask = horrible_cidr_to_mask(cidr);
 207        node->ip_version = 6;
 208        node->value = value;
 209        horrible_mask_self(node);
 210        horrible_insert_ordered(table, node);
 211        return 0;
 212}
 213
 214static __init void *
 215horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip)
 216{
 217        struct horrible_allowedips_node *node;
 218
 219        hlist_for_each_entry(node, &table->head, table) {
 220                if (node->ip_version == 4 && horrible_match_v4(node, ip))
 221                        return node->value;
 222        }
 223        return NULL;
 224}
 225
 226static __init void *
 227horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip)
 228{
 229        struct horrible_allowedips_node *node;
 230
 231        hlist_for_each_entry(node, &table->head, table) {
 232                if (node->ip_version == 6 && horrible_match_v6(node, ip))
 233                        return node->value;
 234        }
 235        return NULL;
 236}
 237
 238
 239static __init void
 240horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value)
 241{
 242        struct horrible_allowedips_node *node;
 243        struct hlist_node *h;
 244
 245        hlist_for_each_entry_safe(node, h, &table->head, table) {
 246                if (node->value != value)
 247                        continue;
 248                hlist_del(&node->table);
 249                kfree(node);
 250        }
 251
 252}
 253
 254static __init bool randomized_test(void)
 255{
 256        unsigned int i, j, k, mutate_amount, cidr;
 257        u8 ip[16], mutate_mask[16], mutated[16];
 258        struct wg_peer **peers, *peer;
 259        struct horrible_allowedips h;
 260        DEFINE_MUTEX(mutex);
 261        struct allowedips t;
 262        bool ret = false;
 263
 264        mutex_init(&mutex);
 265
 266        wg_allowedips_init(&t);
 267        horrible_allowedips_init(&h);
 268
 269        peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL);
 270        if (unlikely(!peers)) {
 271                pr_err("allowedips random self-test malloc: FAIL\n");
 272                goto free;
 273        }
 274        for (i = 0; i < NUM_PEERS; ++i) {
 275                peers[i] = kzalloc(sizeof(*peers[i]), GFP_KERNEL);
 276                if (unlikely(!peers[i])) {
 277                        pr_err("allowedips random self-test malloc: FAIL\n");
 278                        goto free;
 279                }
 280                kref_init(&peers[i]->refcount);
 281                INIT_LIST_HEAD(&peers[i]->allowedips_list);
 282        }
 283
 284        mutex_lock(&mutex);
 285
 286        for (i = 0; i < NUM_RAND_ROUTES; ++i) {
 287                prandom_bytes(ip, 4);
 288                cidr = prandom_u32_max(32) + 1;
 289                peer = peers[prandom_u32_max(NUM_PEERS)];
 290                if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr,
 291                                            peer, &mutex) < 0) {
 292                        pr_err("allowedips random self-test malloc: FAIL\n");
 293                        goto free_locked;
 294                }
 295                if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip,
 296                                                  cidr, peer) < 0) {
 297                        pr_err("allowedips random self-test malloc: FAIL\n");
 298                        goto free_locked;
 299                }
 300                for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
 301                        memcpy(mutated, ip, 4);
 302                        prandom_bytes(mutate_mask, 4);
 303                        mutate_amount = prandom_u32_max(32);
 304                        for (k = 0; k < mutate_amount / 8; ++k)
 305                                mutate_mask[k] = 0xff;
 306                        mutate_mask[k] = 0xff
 307                                         << ((8 - (mutate_amount % 8)) % 8);
 308                        for (; k < 4; ++k)
 309                                mutate_mask[k] = 0;
 310                        for (k = 0; k < 4; ++k)
 311                                mutated[k] = (mutated[k] & mutate_mask[k]) |
 312                                             (~mutate_mask[k] &
 313                                              prandom_u32_max(256));
 314                        cidr = prandom_u32_max(32) + 1;
 315                        peer = peers[prandom_u32_max(NUM_PEERS)];
 316                        if (wg_allowedips_insert_v4(&t,
 317                                                    (struct in_addr *)mutated,
 318                                                    cidr, peer, &mutex) < 0) {
 319                                pr_err("allowedips random self-test malloc: FAIL\n");
 320                                goto free_locked;
 321                        }
 322                        if (horrible_allowedips_insert_v4(&h,
 323                                (struct in_addr *)mutated, cidr, peer)) {
 324                                pr_err("allowedips random self-test malloc: FAIL\n");
 325                                goto free_locked;
 326                        }
 327                }
 328        }
 329
 330        for (i = 0; i < NUM_RAND_ROUTES; ++i) {
 331                prandom_bytes(ip, 16);
 332                cidr = prandom_u32_max(128) + 1;
 333                peer = peers[prandom_u32_max(NUM_PEERS)];
 334                if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr,
 335                                            peer, &mutex) < 0) {
 336                        pr_err("allowedips random self-test malloc: FAIL\n");
 337                        goto free_locked;
 338                }
 339                if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip,
 340                                                  cidr, peer) < 0) {
 341                        pr_err("allowedips random self-test malloc: FAIL\n");
 342                        goto free_locked;
 343                }
 344                for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
 345                        memcpy(mutated, ip, 16);
 346                        prandom_bytes(mutate_mask, 16);
 347                        mutate_amount = prandom_u32_max(128);
 348                        for (k = 0; k < mutate_amount / 8; ++k)
 349                                mutate_mask[k] = 0xff;
 350                        mutate_mask[k] = 0xff
 351                                         << ((8 - (mutate_amount % 8)) % 8);
 352                        for (; k < 4; ++k)
 353                                mutate_mask[k] = 0;
 354                        for (k = 0; k < 4; ++k)
 355                                mutated[k] = (mutated[k] & mutate_mask[k]) |
 356                                             (~mutate_mask[k] &
 357                                              prandom_u32_max(256));
 358                        cidr = prandom_u32_max(128) + 1;
 359                        peer = peers[prandom_u32_max(NUM_PEERS)];
 360                        if (wg_allowedips_insert_v6(&t,
 361                                                    (struct in6_addr *)mutated,
 362                                                    cidr, peer, &mutex) < 0) {
 363                                pr_err("allowedips random self-test malloc: FAIL\n");
 364                                goto free_locked;
 365                        }
 366                        if (horrible_allowedips_insert_v6(
 367                                    &h, (struct in6_addr *)mutated, cidr,
 368                                    peer)) {
 369                                pr_err("allowedips random self-test malloc: FAIL\n");
 370                                goto free_locked;
 371                        }
 372                }
 373        }
 374
 375        mutex_unlock(&mutex);
 376
 377        if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
 378                print_tree(t.root4, 32);
 379                print_tree(t.root6, 128);
 380        }
 381
 382        for (j = 0;; ++j) {
 383                for (i = 0; i < NUM_QUERIES; ++i) {
 384                        prandom_bytes(ip, 4);
 385                        if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
 386                                horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
 387                                pr_err("allowedips random v4 self-test: FAIL\n");
 388                                goto free;
 389                        }
 390                        prandom_bytes(ip, 16);
 391                        if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
 392                                pr_err("allowedips random v6 self-test: FAIL\n");
 393                                goto free;
 394                        }
 395                }
 396                if (j >= NUM_PEERS)
 397                        break;
 398                mutex_lock(&mutex);
 399                wg_allowedips_remove_by_peer(&t, peers[j], &mutex);
 400                mutex_unlock(&mutex);
 401                horrible_allowedips_remove_by_value(&h, peers[j]);
 402        }
 403
 404        if (t.root4 || t.root6) {
 405                pr_err("allowedips random self-test removal: FAIL\n");
 406                goto free;
 407        }
 408
 409        ret = true;
 410
 411free:
 412        mutex_lock(&mutex);
 413free_locked:
 414        wg_allowedips_free(&t, &mutex);
 415        mutex_unlock(&mutex);
 416        horrible_allowedips_free(&h);
 417        if (peers) {
 418                for (i = 0; i < NUM_PEERS; ++i)
 419                        kfree(peers[i]);
 420        }
 421        kfree(peers);
 422        return ret;
 423}
 424
 425static __init inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
 426{
 427        static struct in_addr ip;
 428        u8 *split = (u8 *)&ip;
 429
 430        split[0] = a;
 431        split[1] = b;
 432        split[2] = c;
 433        split[3] = d;
 434        return &ip;
 435}
 436
 437static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
 438{
 439        static struct in6_addr ip;
 440        __be32 *split = (__be32 *)&ip;
 441
 442        split[0] = cpu_to_be32(a);
 443        split[1] = cpu_to_be32(b);
 444        split[2] = cpu_to_be32(c);
 445        split[3] = cpu_to_be32(d);
 446        return &ip;
 447}
 448
 449static __init struct wg_peer *init_peer(void)
 450{
 451        struct wg_peer *peer = kzalloc(sizeof(*peer), GFP_KERNEL);
 452
 453        if (!peer)
 454                return NULL;
 455        kref_init(&peer->refcount);
 456        INIT_LIST_HEAD(&peer->allowedips_list);
 457        return peer;
 458}
 459
 460#define insert(version, mem, ipa, ipb, ipc, ipd, cidr)                       \
 461        wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
 462                                        cidr, mem, &mutex)
 463
 464#define maybe_fail() do {                                               \
 465                ++i;                                                    \
 466                if (!_s) {                                              \
 467                        pr_info("allowedips self-test %zu: FAIL\n", i); \
 468                        success = false;                                \
 469                }                                                       \
 470        } while (0)
 471
 472#define test(version, mem, ipa, ipb, ipc, ipd) do {                          \
 473                bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
 474                                 ip##version(ipa, ipb, ipc, ipd)) == (mem);  \
 475                maybe_fail();                                                \
 476        } while (0)
 477
 478#define test_negative(version, mem, ipa, ipb, ipc, ipd) do {                 \
 479                bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
 480                                 ip##version(ipa, ipb, ipc, ipd)) != (mem);  \
 481                maybe_fail();                                                \
 482        } while (0)
 483
 484#define test_boolean(cond) do {   \
 485                bool _s = (cond); \
 486                maybe_fail();     \
 487        } while (0)
 488
 489bool __init wg_allowedips_selftest(void)
 490{
 491        bool found_a = false, found_b = false, found_c = false, found_d = false,
 492             found_e = false, found_other = false;
 493        struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(),
 494                       *d = init_peer(), *e = init_peer(), *f = init_peer(),
 495                       *g = init_peer(), *h = init_peer();
 496        struct allowedips_node *iter_node;
 497        bool success = false;
 498        struct allowedips t;
 499        DEFINE_MUTEX(mutex);
 500        struct in6_addr ip;
 501        size_t i = 0, count = 0;
 502        __be64 part;
 503
 504        mutex_init(&mutex);
 505        mutex_lock(&mutex);
 506        wg_allowedips_init(&t);
 507
 508        if (!a || !b || !c || !d || !e || !f || !g || !h) {
 509                pr_err("allowedips self-test malloc: FAIL\n");
 510                goto free;
 511        }
 512
 513        insert(4, a, 192, 168, 4, 0, 24);
 514        insert(4, b, 192, 168, 4, 4, 32);
 515        insert(4, c, 192, 168, 0, 0, 16);
 516        insert(4, d, 192, 95, 5, 64, 27);
 517        /* replaces previous entry, and maskself is required */
 518        insert(4, c, 192, 95, 5, 65, 27);
 519        insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
 520        insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
 521        insert(4, e, 0, 0, 0, 0, 0);
 522        insert(6, e, 0, 0, 0, 0, 0);
 523        /* replaces previous entry */
 524        insert(6, f, 0, 0, 0, 0, 0);
 525        insert(6, g, 0x24046800, 0, 0, 0, 32);
 526        /* maskself is required */
 527        insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64);
 528        insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
 529        insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
 530        insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
 531        insert(4, g, 64, 15, 112, 0, 20);
 532        /* maskself is required */
 533        insert(4, h, 64, 15, 123, 211, 25);
 534        insert(4, a, 10, 0, 0, 0, 25);
 535        insert(4, b, 10, 0, 0, 128, 25);
 536        insert(4, a, 10, 1, 0, 0, 30);
 537        insert(4, b, 10, 1, 0, 4, 30);
 538        insert(4, c, 10, 1, 0, 8, 29);
 539        insert(4, d, 10, 1, 0, 16, 29);
 540
 541        if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
 542                print_tree(t.root4, 32);
 543                print_tree(t.root6, 128);
 544        }
 545
 546        success = true;
 547
 548        test(4, a, 192, 168, 4, 20);
 549        test(4, a, 192, 168, 4, 0);
 550        test(4, b, 192, 168, 4, 4);
 551        test(4, c, 192, 168, 200, 182);
 552        test(4, c, 192, 95, 5, 68);
 553        test(4, e, 192, 95, 5, 96);
 554        test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
 555        test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
 556        test(6, f, 0x26075300, 0x60006b01, 0, 0);
 557        test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
 558        test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
 559        test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
 560        test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
 561        test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
 562        test(6, h, 0x24046800, 0x40040800, 0, 0);
 563        test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
 564        test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
 565        test(4, g, 64, 15, 116, 26);
 566        test(4, g, 64, 15, 127, 3);
 567        test(4, g, 64, 15, 123, 1);
 568        test(4, h, 64, 15, 123, 128);
 569        test(4, h, 64, 15, 123, 129);
 570        test(4, a, 10, 0, 0, 52);
 571        test(4, b, 10, 0, 0, 220);
 572        test(4, a, 10, 1, 0, 2);
 573        test(4, b, 10, 1, 0, 6);
 574        test(4, c, 10, 1, 0, 10);
 575        test(4, d, 10, 1, 0, 20);
 576
 577        insert(4, a, 1, 0, 0, 0, 32);
 578        insert(4, a, 64, 0, 0, 0, 32);
 579        insert(4, a, 128, 0, 0, 0, 32);
 580        insert(4, a, 192, 0, 0, 0, 32);
 581        insert(4, a, 255, 0, 0, 0, 32);
 582        wg_allowedips_remove_by_peer(&t, a, &mutex);
 583        test_negative(4, a, 1, 0, 0, 0);
 584        test_negative(4, a, 64, 0, 0, 0);
 585        test_negative(4, a, 128, 0, 0, 0);
 586        test_negative(4, a, 192, 0, 0, 0);
 587        test_negative(4, a, 255, 0, 0, 0);
 588
 589        wg_allowedips_free(&t, &mutex);
 590        wg_allowedips_init(&t);
 591        insert(4, a, 192, 168, 0, 0, 16);
 592        insert(4, a, 192, 168, 0, 0, 24);
 593        wg_allowedips_remove_by_peer(&t, a, &mutex);
 594        test_negative(4, a, 192, 168, 0, 1);
 595
 596        /* These will hit the WARN_ON(len >= 128) in free_node if something
 597         * goes wrong.
 598         */
 599        for (i = 0; i < 128; ++i) {
 600                part = cpu_to_be64(~(1LLU << (i % 64)));
 601                memset(&ip, 0xff, 16);
 602                memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
 603                wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
 604        }
 605
 606        wg_allowedips_free(&t, &mutex);
 607
 608        wg_allowedips_init(&t);
 609        insert(4, a, 192, 95, 5, 93, 27);
 610        insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
 611        insert(4, a, 10, 1, 0, 20, 29);
 612        insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83);
 613        insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21);
 614        list_for_each_entry(iter_node, &a->allowedips_list, peer_list) {
 615                u8 cidr, ip[16] __aligned(__alignof(u64));
 616                int family = wg_allowedips_read_node(iter_node, ip, &cidr);
 617
 618                count++;
 619
 620                if (cidr == 27 && family == AF_INET &&
 621                    !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
 622                        found_a = true;
 623                else if (cidr == 128 && family == AF_INET6 &&
 624                         !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
 625                                 sizeof(struct in6_addr)))
 626                        found_b = true;
 627                else if (cidr == 29 && family == AF_INET &&
 628                         !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
 629                        found_c = true;
 630                else if (cidr == 83 && family == AF_INET6 &&
 631                         !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
 632                                 sizeof(struct in6_addr)))
 633                        found_d = true;
 634                else if (cidr == 21 && family == AF_INET6 &&
 635                         !memcmp(ip, ip6(0x26075000, 0, 0, 0),
 636                                 sizeof(struct in6_addr)))
 637                        found_e = true;
 638                else
 639                        found_other = true;
 640        }
 641        test_boolean(count == 5);
 642        test_boolean(found_a);
 643        test_boolean(found_b);
 644        test_boolean(found_c);
 645        test_boolean(found_d);
 646        test_boolean(found_e);
 647        test_boolean(!found_other);
 648
 649        if (IS_ENABLED(DEBUG_RANDOM_TRIE) && success)
 650                success = randomized_test();
 651
 652        if (success)
 653                pr_info("allowedips self-tests: pass\n");
 654
 655free:
 656        wg_allowedips_free(&t, &mutex);
 657        kfree(a);
 658        kfree(b);
 659        kfree(c);
 660        kfree(d);
 661        kfree(e);
 662        kfree(f);
 663        kfree(g);
 664        kfree(h);
 665        mutex_unlock(&mutex);
 666
 667        return success;
 668}
 669
 670#undef test_negative
 671#undef test
 672#undef remove
 673#undef insert
 674#undef init_peer
 675
 676#endif
 677