linux/drivers/net/wireguard/selftest/allowedips.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
   4 *
   5 * This contains some basic static unit tests for the allowedips data structure.
   6 * It also has two additional modes that are disabled and meant to be used by
   7 * folks directly playing with this file. If you define the macro
   8 * DEBUG_PRINT_TRIE_GRAPHVIZ to be 1, then every time there's a full tree in
   9 * memory, it will be printed out as KERN_DEBUG in a format that can be passed
  10 * to graphviz (the dot command) to visualize it. If you define the macro
  11 * DEBUG_RANDOM_TRIE to be 1, then there will be an extremely costly set of
  12 * randomized tests done against a trivial implementation, which may take
  13 * upwards of a half-hour to complete. There's no set of users who should be
  14 * enabling these, and the only developers that should go anywhere near these
   15 * knobs are the ones who are reading this comment.
  16 */
  17
  18#ifdef DEBUG
  19
  20#include <linux/siphash.h>
  21
  22static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits,
  23                                              u8 cidr)
  24{
  25        swap_endian(dst, src, bits);
  26        memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
  27        if (cidr)
  28                dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
  29}
  30
  31static __init void print_node(struct allowedips_node *node, u8 bits)
  32{
  33        char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
  34        char *fmt_declaration = KERN_DEBUG
  35                "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
  36        char *style = "dotted";
  37        u8 ip1[16], ip2[16];
  38        u32 color = 0;
  39
  40        if (bits == 32) {
  41                fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
  42                fmt_declaration = KERN_DEBUG
  43                        "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
  44        } else if (bits == 128) {
  45                fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
  46                fmt_declaration = KERN_DEBUG
  47                        "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
  48        }
  49        if (node->peer) {
  50                hsiphash_key_t key = { { 0 } };
  51
  52                memcpy(&key, &node->peer, sizeof(node->peer));
  53                color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 |
  54                        hsiphash_1u32(0xbabecafe, &key) % 200 << 8 |
  55                        hsiphash_1u32(0xabad1dea, &key) % 200;
  56                style = "bold";
  57        }
  58        swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr);
  59        printk(fmt_declaration, ip1, node->cidr, style, color);
  60        if (node->bit[0]) {
  61                swap_endian_and_apply_cidr(ip2,
  62                                rcu_dereference_raw(node->bit[0])->bits, bits,
  63                                node->cidr);
  64                printk(fmt_connection, ip1, node->cidr, ip2,
  65                       rcu_dereference_raw(node->bit[0])->cidr);
  66                print_node(rcu_dereference_raw(node->bit[0]), bits);
  67        }
  68        if (node->bit[1]) {
  69                swap_endian_and_apply_cidr(ip2,
  70                                rcu_dereference_raw(node->bit[1])->bits,
  71                                bits, node->cidr);
  72                printk(fmt_connection, ip1, node->cidr, ip2,
  73                       rcu_dereference_raw(node->bit[1])->cidr);
  74                print_node(rcu_dereference_raw(node->bit[1]), bits);
  75        }
  76}
  77
  78static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
  79{
  80        printk(KERN_DEBUG "digraph trie {\n");
  81        print_node(rcu_dereference_raw(top), bits);
  82        printk(KERN_DEBUG "}\n");
  83}
  84
  85enum {
  86        NUM_PEERS = 2000,
  87        NUM_RAND_ROUTES = 400,
  88        NUM_MUTATED_ROUTES = 100,
  89        NUM_QUERIES = NUM_RAND_ROUTES * NUM_MUTATED_ROUTES * 30
  90};
  91
  92struct horrible_allowedips {
  93        struct hlist_head head;
  94};
  95
  96struct horrible_allowedips_node {
  97        struct hlist_node table;
  98        union nf_inet_addr ip;
  99        union nf_inet_addr mask;
 100        u8 ip_version;
 101        void *value;
 102};
 103
 104static __init void horrible_allowedips_init(struct horrible_allowedips *table)
 105{
 106        INIT_HLIST_HEAD(&table->head);
 107}
 108
 109static __init void horrible_allowedips_free(struct horrible_allowedips *table)
 110{
 111        struct horrible_allowedips_node *node;
 112        struct hlist_node *h;
 113
 114        hlist_for_each_entry_safe(node, h, &table->head, table) {
 115                hlist_del(&node->table);
 116                kfree(node);
 117        }
 118}
 119
 120static __init inline union nf_inet_addr horrible_cidr_to_mask(u8 cidr)
 121{
 122        union nf_inet_addr mask;
 123
 124        memset(&mask, 0x00, 128 / 8);
 125        memset(&mask, 0xff, cidr / 8);
 126        if (cidr % 32)
 127                mask.all[cidr / 32] = (__force u32)htonl(
 128                        (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
 129        return mask;
 130}
 131
 132static __init inline u8 horrible_mask_to_cidr(union nf_inet_addr subnet)
 133{
 134        return hweight32(subnet.all[0]) + hweight32(subnet.all[1]) +
 135               hweight32(subnet.all[2]) + hweight32(subnet.all[3]);
 136}
 137
 138static __init inline void
 139horrible_mask_self(struct horrible_allowedips_node *node)
 140{
 141        if (node->ip_version == 4) {
 142                node->ip.ip &= node->mask.ip;
 143        } else if (node->ip_version == 6) {
 144                node->ip.ip6[0] &= node->mask.ip6[0];
 145                node->ip.ip6[1] &= node->mask.ip6[1];
 146                node->ip.ip6[2] &= node->mask.ip6[2];
 147                node->ip.ip6[3] &= node->mask.ip6[3];
 148        }
 149}
 150
 151static __init inline bool
 152horrible_match_v4(const struct horrible_allowedips_node *node,
 153                  struct in_addr *ip)
 154{
 155        return (ip->s_addr & node->mask.ip) == node->ip.ip;
 156}
 157
 158static __init inline bool
 159horrible_match_v6(const struct horrible_allowedips_node *node,
 160                  struct in6_addr *ip)
 161{
 162        return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) ==
 163                       node->ip.ip6[0] &&
 164               (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) ==
 165                       node->ip.ip6[1] &&
 166               (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) ==
 167                       node->ip.ip6[2] &&
 168               (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
 169}
 170
 171static __init void
 172horrible_insert_ordered(struct horrible_allowedips *table,
 173                        struct horrible_allowedips_node *node)
 174{
 175        struct horrible_allowedips_node *other = NULL, *where = NULL;
 176        u8 my_cidr = horrible_mask_to_cidr(node->mask);
 177
 178        hlist_for_each_entry(other, &table->head, table) {
 179                if (!memcmp(&other->mask, &node->mask,
 180                            sizeof(union nf_inet_addr)) &&
 181                    !memcmp(&other->ip, &node->ip,
 182                            sizeof(union nf_inet_addr)) &&
 183                    other->ip_version == node->ip_version) {
 184                        other->value = node->value;
 185                        kfree(node);
 186                        return;
 187                }
 188                where = other;
 189                if (horrible_mask_to_cidr(other->mask) <= my_cidr)
 190                        break;
 191        }
 192        if (!other && !where)
 193                hlist_add_head(&node->table, &table->head);
 194        else if (!other)
 195                hlist_add_behind(&node->table, &where->table);
 196        else
 197                hlist_add_before(&node->table, &where->table);
 198}
 199
 200static __init int
 201horrible_allowedips_insert_v4(struct horrible_allowedips *table,
 202                              struct in_addr *ip, u8 cidr, void *value)
 203{
 204        struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
 205                                                        GFP_KERNEL);
 206
 207        if (unlikely(!node))
 208                return -ENOMEM;
 209        node->ip.in = *ip;
 210        node->mask = horrible_cidr_to_mask(cidr);
 211        node->ip_version = 4;
 212        node->value = value;
 213        horrible_mask_self(node);
 214        horrible_insert_ordered(table, node);
 215        return 0;
 216}
 217
 218static __init int
 219horrible_allowedips_insert_v6(struct horrible_allowedips *table,
 220                              struct in6_addr *ip, u8 cidr, void *value)
 221{
 222        struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
 223                                                        GFP_KERNEL);
 224
 225        if (unlikely(!node))
 226                return -ENOMEM;
 227        node->ip.in6 = *ip;
 228        node->mask = horrible_cidr_to_mask(cidr);
 229        node->ip_version = 6;
 230        node->value = value;
 231        horrible_mask_self(node);
 232        horrible_insert_ordered(table, node);
 233        return 0;
 234}
 235
 236static __init void *
 237horrible_allowedips_lookup_v4(struct horrible_allowedips *table,
 238                              struct in_addr *ip)
 239{
 240        struct horrible_allowedips_node *node;
 241        void *ret = NULL;
 242
 243        hlist_for_each_entry(node, &table->head, table) {
 244                if (node->ip_version != 4)
 245                        continue;
 246                if (horrible_match_v4(node, ip)) {
 247                        ret = node->value;
 248                        break;
 249                }
 250        }
 251        return ret;
 252}
 253
 254static __init void *
 255horrible_allowedips_lookup_v6(struct horrible_allowedips *table,
 256                              struct in6_addr *ip)
 257{
 258        struct horrible_allowedips_node *node;
 259        void *ret = NULL;
 260
 261        hlist_for_each_entry(node, &table->head, table) {
 262                if (node->ip_version != 6)
 263                        continue;
 264                if (horrible_match_v6(node, ip)) {
 265                        ret = node->value;
 266                        break;
 267                }
 268        }
 269        return ret;
 270}
 271
 272static __init bool randomized_test(void)
 273{
 274        unsigned int i, j, k, mutate_amount, cidr;
 275        u8 ip[16], mutate_mask[16], mutated[16];
 276        struct wg_peer **peers, *peer;
 277        struct horrible_allowedips h;
 278        DEFINE_MUTEX(mutex);
 279        struct allowedips t;
 280        bool ret = false;
 281
 282        mutex_init(&mutex);
 283
 284        wg_allowedips_init(&t);
 285        horrible_allowedips_init(&h);
 286
 287        peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL);
 288        if (unlikely(!peers)) {
 289                pr_err("allowedips random self-test malloc: FAIL\n");
 290                goto free;
 291        }
 292        for (i = 0; i < NUM_PEERS; ++i) {
 293                peers[i] = kzalloc(sizeof(*peers[i]), GFP_KERNEL);
 294                if (unlikely(!peers[i])) {
 295                        pr_err("allowedips random self-test malloc: FAIL\n");
 296                        goto free;
 297                }
 298                kref_init(&peers[i]->refcount);
 299        }
 300
 301        mutex_lock(&mutex);
 302
 303        for (i = 0; i < NUM_RAND_ROUTES; ++i) {
 304                prandom_bytes(ip, 4);
 305                cidr = prandom_u32_max(32) + 1;
 306                peer = peers[prandom_u32_max(NUM_PEERS)];
 307                if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr,
 308                                            peer, &mutex) < 0) {
 309                        pr_err("allowedips random self-test malloc: FAIL\n");
 310                        goto free_locked;
 311                }
 312                if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip,
 313                                                  cidr, peer) < 0) {
 314                        pr_err("allowedips random self-test malloc: FAIL\n");
 315                        goto free_locked;
 316                }
 317                for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
 318                        memcpy(mutated, ip, 4);
 319                        prandom_bytes(mutate_mask, 4);
 320                        mutate_amount = prandom_u32_max(32);
 321                        for (k = 0; k < mutate_amount / 8; ++k)
 322                                mutate_mask[k] = 0xff;
 323                        mutate_mask[k] = 0xff
 324                                         << ((8 - (mutate_amount % 8)) % 8);
 325                        for (; k < 4; ++k)
 326                                mutate_mask[k] = 0;
 327                        for (k = 0; k < 4; ++k)
 328                                mutated[k] = (mutated[k] & mutate_mask[k]) |
 329                                             (~mutate_mask[k] &
 330                                              prandom_u32_max(256));
 331                        cidr = prandom_u32_max(32) + 1;
 332                        peer = peers[prandom_u32_max(NUM_PEERS)];
 333                        if (wg_allowedips_insert_v4(&t,
 334                                                    (struct in_addr *)mutated,
 335                                                    cidr, peer, &mutex) < 0) {
 336                                pr_err("allowedips random malloc: FAIL\n");
 337                                goto free_locked;
 338                        }
 339                        if (horrible_allowedips_insert_v4(&h,
 340                                (struct in_addr *)mutated, cidr, peer)) {
 341                                pr_err("allowedips random self-test malloc: FAIL\n");
 342                                goto free_locked;
 343                        }
 344                }
 345        }
 346
 347        for (i = 0; i < NUM_RAND_ROUTES; ++i) {
 348                prandom_bytes(ip, 16);
 349                cidr = prandom_u32_max(128) + 1;
 350                peer = peers[prandom_u32_max(NUM_PEERS)];
 351                if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr,
 352                                            peer, &mutex) < 0) {
 353                        pr_err("allowedips random self-test malloc: FAIL\n");
 354                        goto free_locked;
 355                }
 356                if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip,
 357                                                  cidr, peer) < 0) {
 358                        pr_err("allowedips random self-test malloc: FAIL\n");
 359                        goto free_locked;
 360                }
 361                for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
 362                        memcpy(mutated, ip, 16);
 363                        prandom_bytes(mutate_mask, 16);
 364                        mutate_amount = prandom_u32_max(128);
 365                        for (k = 0; k < mutate_amount / 8; ++k)
 366                                mutate_mask[k] = 0xff;
 367                        mutate_mask[k] = 0xff
 368                                         << ((8 - (mutate_amount % 8)) % 8);
 369                        for (; k < 4; ++k)
 370                                mutate_mask[k] = 0;
 371                        for (k = 0; k < 4; ++k)
 372                                mutated[k] = (mutated[k] & mutate_mask[k]) |
 373                                             (~mutate_mask[k] &
 374                                              prandom_u32_max(256));
 375                        cidr = prandom_u32_max(128) + 1;
 376                        peer = peers[prandom_u32_max(NUM_PEERS)];
 377                        if (wg_allowedips_insert_v6(&t,
 378                                                    (struct in6_addr *)mutated,
 379                                                    cidr, peer, &mutex) < 0) {
 380                                pr_err("allowedips random self-test malloc: FAIL\n");
 381                                goto free_locked;
 382                        }
 383                        if (horrible_allowedips_insert_v6(
 384                                    &h, (struct in6_addr *)mutated, cidr,
 385                                    peer)) {
 386                                pr_err("allowedips random self-test malloc: FAIL\n");
 387                                goto free_locked;
 388                        }
 389                }
 390        }
 391
 392        mutex_unlock(&mutex);
 393
 394        if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
 395                print_tree(t.root4, 32);
 396                print_tree(t.root6, 128);
 397        }
 398
 399        for (i = 0; i < NUM_QUERIES; ++i) {
 400                prandom_bytes(ip, 4);
 401                if (lookup(t.root4, 32, ip) !=
 402                    horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
 403                        pr_err("allowedips random self-test: FAIL\n");
 404                        goto free;
 405                }
 406        }
 407
 408        for (i = 0; i < NUM_QUERIES; ++i) {
 409                prandom_bytes(ip, 16);
 410                if (lookup(t.root6, 128, ip) !=
 411                    horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
 412                        pr_err("allowedips random self-test: FAIL\n");
 413                        goto free;
 414                }
 415        }
 416        ret = true;
 417
 418free:
 419        mutex_lock(&mutex);
 420free_locked:
 421        wg_allowedips_free(&t, &mutex);
 422        mutex_unlock(&mutex);
 423        horrible_allowedips_free(&h);
 424        if (peers) {
 425                for (i = 0; i < NUM_PEERS; ++i)
 426                        kfree(peers[i]);
 427        }
 428        kfree(peers);
 429        return ret;
 430}
 431
 432static __init inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
 433{
 434        static struct in_addr ip;
 435        u8 *split = (u8 *)&ip;
 436
 437        split[0] = a;
 438        split[1] = b;
 439        split[2] = c;
 440        split[3] = d;
 441        return &ip;
 442}
 443
 444static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
 445{
 446        static struct in6_addr ip;
 447        __be32 *split = (__be32 *)&ip;
 448
 449        split[0] = cpu_to_be32(a);
 450        split[1] = cpu_to_be32(b);
 451        split[2] = cpu_to_be32(c);
 452        split[3] = cpu_to_be32(d);
 453        return &ip;
 454}
 455
 456static __init struct wg_peer *init_peer(void)
 457{
 458        struct wg_peer *peer = kzalloc(sizeof(*peer), GFP_KERNEL);
 459
 460        if (!peer)
 461                return NULL;
 462        kref_init(&peer->refcount);
 463        INIT_LIST_HEAD(&peer->allowedips_list);
 464        return peer;
 465}
 466
 467#define insert(version, mem, ipa, ipb, ipc, ipd, cidr)                       \
 468        wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
 469                                        cidr, mem, &mutex)
 470
 471#define maybe_fail() do {                                               \
 472                ++i;                                                    \
 473                if (!_s) {                                              \
 474                        pr_info("allowedips self-test %zu: FAIL\n", i); \
 475                        success = false;                                \
 476                }                                                       \
 477        } while (0)
 478
 479#define test(version, mem, ipa, ipb, ipc, ipd) do {                          \
 480                bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
 481                                 ip##version(ipa, ipb, ipc, ipd)) == (mem);  \
 482                maybe_fail();                                                \
 483        } while (0)
 484
 485#define test_negative(version, mem, ipa, ipb, ipc, ipd) do {                 \
 486                bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
 487                                 ip##version(ipa, ipb, ipc, ipd)) != (mem);  \
 488                maybe_fail();                                                \
 489        } while (0)
 490
 491#define test_boolean(cond) do {   \
 492                bool _s = (cond); \
 493                maybe_fail();     \
 494        } while (0)
 495
 496bool __init wg_allowedips_selftest(void)
 497{
 498        bool found_a = false, found_b = false, found_c = false, found_d = false,
 499             found_e = false, found_other = false;
 500        struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(),
 501                       *d = init_peer(), *e = init_peer(), *f = init_peer(),
 502                       *g = init_peer(), *h = init_peer();
 503        struct allowedips_node *iter_node;
 504        bool success = false;
 505        struct allowedips t;
 506        DEFINE_MUTEX(mutex);
 507        struct in6_addr ip;
 508        size_t i = 0, count = 0;
 509        __be64 part;
 510
 511        mutex_init(&mutex);
 512        mutex_lock(&mutex);
 513        wg_allowedips_init(&t);
 514
 515        if (!a || !b || !c || !d || !e || !f || !g || !h) {
 516                pr_err("allowedips self-test malloc: FAIL\n");
 517                goto free;
 518        }
 519
 520        insert(4, a, 192, 168, 4, 0, 24);
 521        insert(4, b, 192, 168, 4, 4, 32);
 522        insert(4, c, 192, 168, 0, 0, 16);
 523        insert(4, d, 192, 95, 5, 64, 27);
 524        /* replaces previous entry, and maskself is required */
 525        insert(4, c, 192, 95, 5, 65, 27);
 526        insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
 527        insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
 528        insert(4, e, 0, 0, 0, 0, 0);
 529        insert(6, e, 0, 0, 0, 0, 0);
 530        /* replaces previous entry */
 531        insert(6, f, 0, 0, 0, 0, 0);
 532        insert(6, g, 0x24046800, 0, 0, 0, 32);
 533        /* maskself is required */
 534        insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64);
 535        insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
 536        insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
 537        insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
 538        insert(4, g, 64, 15, 112, 0, 20);
 539        /* maskself is required */
 540        insert(4, h, 64, 15, 123, 211, 25);
 541        insert(4, a, 10, 0, 0, 0, 25);
 542        insert(4, b, 10, 0, 0, 128, 25);
 543        insert(4, a, 10, 1, 0, 0, 30);
 544        insert(4, b, 10, 1, 0, 4, 30);
 545        insert(4, c, 10, 1, 0, 8, 29);
 546        insert(4, d, 10, 1, 0, 16, 29);
 547
 548        if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
 549                print_tree(t.root4, 32);
 550                print_tree(t.root6, 128);
 551        }
 552
 553        success = true;
 554
 555        test(4, a, 192, 168, 4, 20);
 556        test(4, a, 192, 168, 4, 0);
 557        test(4, b, 192, 168, 4, 4);
 558        test(4, c, 192, 168, 200, 182);
 559        test(4, c, 192, 95, 5, 68);
 560        test(4, e, 192, 95, 5, 96);
 561        test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
 562        test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
 563        test(6, f, 0x26075300, 0x60006b01, 0, 0);
 564        test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
 565        test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
 566        test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
 567        test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
 568        test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
 569        test(6, h, 0x24046800, 0x40040800, 0, 0);
 570        test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
 571        test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
 572        test(4, g, 64, 15, 116, 26);
 573        test(4, g, 64, 15, 127, 3);
 574        test(4, g, 64, 15, 123, 1);
 575        test(4, h, 64, 15, 123, 128);
 576        test(4, h, 64, 15, 123, 129);
 577        test(4, a, 10, 0, 0, 52);
 578        test(4, b, 10, 0, 0, 220);
 579        test(4, a, 10, 1, 0, 2);
 580        test(4, b, 10, 1, 0, 6);
 581        test(4, c, 10, 1, 0, 10);
 582        test(4, d, 10, 1, 0, 20);
 583
 584        insert(4, a, 1, 0, 0, 0, 32);
 585        insert(4, a, 64, 0, 0, 0, 32);
 586        insert(4, a, 128, 0, 0, 0, 32);
 587        insert(4, a, 192, 0, 0, 0, 32);
 588        insert(4, a, 255, 0, 0, 0, 32);
 589        wg_allowedips_remove_by_peer(&t, a, &mutex);
 590        test_negative(4, a, 1, 0, 0, 0);
 591        test_negative(4, a, 64, 0, 0, 0);
 592        test_negative(4, a, 128, 0, 0, 0);
 593        test_negative(4, a, 192, 0, 0, 0);
 594        test_negative(4, a, 255, 0, 0, 0);
 595
 596        wg_allowedips_free(&t, &mutex);
 597        wg_allowedips_init(&t);
 598        insert(4, a, 192, 168, 0, 0, 16);
 599        insert(4, a, 192, 168, 0, 0, 24);
 600        wg_allowedips_remove_by_peer(&t, a, &mutex);
 601        test_negative(4, a, 192, 168, 0, 1);
 602
 603        /* These will hit the WARN_ON(len >= 128) in free_node if something
 604         * goes wrong.
 605         */
 606        for (i = 0; i < 128; ++i) {
 607                part = cpu_to_be64(~(1LLU << (i % 64)));
 608                memset(&ip, 0xff, 16);
 609                memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
 610                wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
 611        }
 612
 613        wg_allowedips_free(&t, &mutex);
 614
 615        wg_allowedips_init(&t);
 616        insert(4, a, 192, 95, 5, 93, 27);
 617        insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
 618        insert(4, a, 10, 1, 0, 20, 29);
 619        insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83);
 620        insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21);
 621        list_for_each_entry(iter_node, &a->allowedips_list, peer_list) {
 622                u8 cidr, ip[16] __aligned(__alignof(u64));
 623                int family = wg_allowedips_read_node(iter_node, ip, &cidr);
 624
 625                count++;
 626
 627                if (cidr == 27 && family == AF_INET &&
 628                    !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
 629                        found_a = true;
 630                else if (cidr == 128 && family == AF_INET6 &&
 631                         !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
 632                                 sizeof(struct in6_addr)))
 633                        found_b = true;
 634                else if (cidr == 29 && family == AF_INET &&
 635                         !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
 636                        found_c = true;
 637                else if (cidr == 83 && family == AF_INET6 &&
 638                         !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
 639                                 sizeof(struct in6_addr)))
 640                        found_d = true;
 641                else if (cidr == 21 && family == AF_INET6 &&
 642                         !memcmp(ip, ip6(0x26075000, 0, 0, 0),
 643                                 sizeof(struct in6_addr)))
 644                        found_e = true;
 645                else
 646                        found_other = true;
 647        }
 648        test_boolean(count == 5);
 649        test_boolean(found_a);
 650        test_boolean(found_b);
 651        test_boolean(found_c);
 652        test_boolean(found_d);
 653        test_boolean(found_e);
 654        test_boolean(!found_other);
 655
 656        if (IS_ENABLED(DEBUG_RANDOM_TRIE) && success)
 657                success = randomized_test();
 658
 659        if (success)
 660                pr_info("allowedips self-tests: pass\n");
 661
 662free:
 663        wg_allowedips_free(&t, &mutex);
 664        kfree(a);
 665        kfree(b);
 666        kfree(c);
 667        kfree(d);
 668        kfree(e);
 669        kfree(f);
 670        kfree(g);
 671        kfree(h);
 672        mutex_unlock(&mutex);
 673
 674        return success;
 675}
 676
 677#undef test_negative
 678#undef test
 679#undef remove
 680#undef insert
 681#undef init_peer
 682
 683#endif
 684