linux/lib/rbtree_test.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2#include <linux/module.h>
   3#include <linux/moduleparam.h>
   4#include <linux/rbtree_augmented.h>
   5#include <linux/random.h>
   6#include <linux/slab.h>
   7#include <asm/timex.h>
   8
   9#define __param(type, name, init, msg)          \
  10        static type name = init;                \
  11        module_param(name, type, 0444);         \
  12        MODULE_PARM_DESC(name, msg);
  13
  14__param(int, nnodes, 100, "Number of nodes in the rb-tree");
  15__param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree");
  16__param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree");
  17
  18struct test_node {
  19        u32 key;
  20        struct rb_node rb;
  21
  22        /* following fields used for testing augmented rbtree functionality */
  23        u32 val;
  24        u32 augmented;
  25};
  26
  27static struct rb_root_cached root = RB_ROOT_CACHED;
  28static struct test_node *nodes = NULL;
  29
  30static struct rnd_state rnd;
  31
  32static void insert(struct test_node *node, struct rb_root_cached *root)
  33{
  34        struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
  35        u32 key = node->key;
  36
  37        while (*new) {
  38                parent = *new;
  39                if (key < rb_entry(parent, struct test_node, rb)->key)
  40                        new = &parent->rb_left;
  41                else
  42                        new = &parent->rb_right;
  43        }
  44
  45        rb_link_node(&node->rb, parent, new);
  46        rb_insert_color(&node->rb, &root->rb_root);
  47}
  48
  49static void insert_cached(struct test_node *node, struct rb_root_cached *root)
  50{
  51        struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
  52        u32 key = node->key;
  53        bool leftmost = true;
  54
  55        while (*new) {
  56                parent = *new;
  57                if (key < rb_entry(parent, struct test_node, rb)->key)
  58                        new = &parent->rb_left;
  59                else {
  60                        new = &parent->rb_right;
  61                        leftmost = false;
  62                }
  63        }
  64
  65        rb_link_node(&node->rb, parent, new);
  66        rb_insert_color_cached(&node->rb, root, leftmost);
  67}
  68
  69static inline void erase(struct test_node *node, struct rb_root_cached *root)
  70{
  71        rb_erase(&node->rb, &root->rb_root);
  72}
  73
  74static inline void erase_cached(struct test_node *node, struct rb_root_cached *root)
  75{
  76        rb_erase_cached(&node->rb, root);
  77}
  78
  79
  80#define NODE_VAL(node) ((node)->val)
  81
  82RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
  83                         struct test_node, rb, u32, augmented, NODE_VAL)
  84
  85static void insert_augmented(struct test_node *node,
  86                             struct rb_root_cached *root)
  87{
  88        struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
  89        u32 key = node->key;
  90        u32 val = node->val;
  91        struct test_node *parent;
  92
  93        while (*new) {
  94                rb_parent = *new;
  95                parent = rb_entry(rb_parent, struct test_node, rb);
  96                if (parent->augmented < val)
  97                        parent->augmented = val;
  98                if (key < parent->key)
  99                        new = &parent->rb.rb_left;
 100                else
 101                        new = &parent->rb.rb_right;
 102        }
 103
 104        node->augmented = val;
 105        rb_link_node(&node->rb, rb_parent, new);
 106        rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks);
 107}
 108
 109static void insert_augmented_cached(struct test_node *node,
 110                                    struct rb_root_cached *root)
 111{
 112        struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
 113        u32 key = node->key;
 114        u32 val = node->val;
 115        struct test_node *parent;
 116        bool leftmost = true;
 117
 118        while (*new) {
 119                rb_parent = *new;
 120                parent = rb_entry(rb_parent, struct test_node, rb);
 121                if (parent->augmented < val)
 122                        parent->augmented = val;
 123                if (key < parent->key)
 124                        new = &parent->rb.rb_left;
 125                else {
 126                        new = &parent->rb.rb_right;
 127                        leftmost = false;
 128                }
 129        }
 130
 131        node->augmented = val;
 132        rb_link_node(&node->rb, rb_parent, new);
 133        rb_insert_augmented_cached(&node->rb, root,
 134                                   leftmost, &augment_callbacks);
 135}
 136
 137
 138static void erase_augmented(struct test_node *node, struct rb_root_cached *root)
 139{
 140        rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks);
 141}
 142
 143static void erase_augmented_cached(struct test_node *node,
 144                                   struct rb_root_cached *root)
 145{
 146        rb_erase_augmented_cached(&node->rb, root, &augment_callbacks);
 147}
 148
 149static void init(void)
 150{
 151        int i;
 152        for (i = 0; i < nnodes; i++) {
 153                nodes[i].key = prandom_u32_state(&rnd);
 154                nodes[i].val = prandom_u32_state(&rnd);
 155        }
 156}
 157
 158static bool is_red(struct rb_node *rb)
 159{
 160        return !(rb->__rb_parent_color & 1);
 161}
 162
 163static int black_path_count(struct rb_node *rb)
 164{
 165        int count;
 166        for (count = 0; rb; rb = rb_parent(rb))
 167                count += !is_red(rb);
 168        return count;
 169}
 170
 171static void check_postorder_foreach(int nr_nodes)
 172{
 173        struct test_node *cur, *n;
 174        int count = 0;
 175        rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb)
 176                count++;
 177
 178        WARN_ON_ONCE(count != nr_nodes);
 179}
 180
 181static void check_postorder(int nr_nodes)
 182{
 183        struct rb_node *rb;
 184        int count = 0;
 185        for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb))
 186                count++;
 187
 188        WARN_ON_ONCE(count != nr_nodes);
 189}
 190
 191static void check(int nr_nodes)
 192{
 193        struct rb_node *rb;
 194        int count = 0, blacks = 0;
 195        u32 prev_key = 0;
 196
 197        for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
 198                struct test_node *node = rb_entry(rb, struct test_node, rb);
 199                WARN_ON_ONCE(node->key < prev_key);
 200                WARN_ON_ONCE(is_red(rb) &&
 201                             (!rb_parent(rb) || is_red(rb_parent(rb))));
 202                if (!count)
 203                        blacks = black_path_count(rb);
 204                else
 205                        WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
 206                                     blacks != black_path_count(rb));
 207                prev_key = node->key;
 208                count++;
 209        }
 210
 211        WARN_ON_ONCE(count != nr_nodes);
 212        WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1);
 213
 214        check_postorder(nr_nodes);
 215        check_postorder_foreach(nr_nodes);
 216}
 217
 218static void check_augmented(int nr_nodes)
 219{
 220        struct rb_node *rb;
 221
 222        check(nr_nodes);
 223        for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
 224                struct test_node *node = rb_entry(rb, struct test_node, rb);
 225                u32 subtree, max = node->val;
 226                if (node->rb.rb_left) {
 227                        subtree = rb_entry(node->rb.rb_left, struct test_node,
 228                                           rb)->augmented;
 229                        if (max < subtree)
 230                                max = subtree;
 231                }
 232                if (node->rb.rb_right) {
 233                        subtree = rb_entry(node->rb.rb_right, struct test_node,
 234                                           rb)->augmented;
 235                        if (max < subtree)
 236                                max = subtree;
 237                }
 238                WARN_ON_ONCE(node->augmented != max);
 239        }
 240}
 241
 242static int __init rbtree_test_init(void)
 243{
 244        int i, j;
 245        cycles_t time1, time2, time;
 246        struct rb_node *node;
 247
 248        nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL);
 249        if (!nodes)
 250                return -ENOMEM;
 251
 252        printk(KERN_ALERT "rbtree testing");
 253
 254        prandom_seed_state(&rnd, 3141592653589793238ULL);
 255        init();
 256
 257        time1 = get_cycles();
 258
 259        for (i = 0; i < perf_loops; i++) {
 260                for (j = 0; j < nnodes; j++)
 261                        insert(nodes + j, &root);
 262                for (j = 0; j < nnodes; j++)
 263                        erase(nodes + j, &root);
 264        }
 265
 266        time2 = get_cycles();
 267        time = time2 - time1;
 268
 269        time = div_u64(time, perf_loops);
 270        printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n",
 271               (unsigned long long)time);
 272
 273        time1 = get_cycles();
 274
 275        for (i = 0; i < perf_loops; i++) {
 276                for (j = 0; j < nnodes; j++)
 277                        insert_cached(nodes + j, &root);
 278                for (j = 0; j < nnodes; j++)
 279                        erase_cached(nodes + j, &root);
 280        }
 281
 282        time2 = get_cycles();
 283        time = time2 - time1;
 284
 285        time = div_u64(time, perf_loops);
 286        printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n",
 287               (unsigned long long)time);
 288
 289        for (i = 0; i < nnodes; i++)
 290                insert(nodes + i, &root);
 291
 292        time1 = get_cycles();
 293
 294        for (i = 0; i < perf_loops; i++) {
 295                for (node = rb_first(&root.rb_root); node; node = rb_next(node))
 296                        ;
 297        }
 298
 299        time2 = get_cycles();
 300        time = time2 - time1;
 301
 302        time = div_u64(time, perf_loops);
 303        printk(" -> test 3 (latency of inorder traversal): %llu cycles\n",
 304               (unsigned long long)time);
 305
 306        time1 = get_cycles();
 307
 308        for (i = 0; i < perf_loops; i++)
 309                node = rb_first(&root.rb_root);
 310
 311        time2 = get_cycles();
 312        time = time2 - time1;
 313
 314        time = div_u64(time, perf_loops);
 315        printk(" -> test 4 (latency to fetch first node)\n");
 316        printk("        non-cached: %llu cycles\n", (unsigned long long)time);
 317
 318        time1 = get_cycles();
 319
 320        for (i = 0; i < perf_loops; i++)
 321                node = rb_first_cached(&root);
 322
 323        time2 = get_cycles();
 324        time = time2 - time1;
 325
 326        time = div_u64(time, perf_loops);
 327        printk("        cached: %llu cycles\n", (unsigned long long)time);
 328
 329        for (i = 0; i < nnodes; i++)
 330                erase(nodes + i, &root);
 331
 332        /* run checks */
 333        for (i = 0; i < check_loops; i++) {
 334                init();
 335                for (j = 0; j < nnodes; j++) {
 336                        check(j);
 337                        insert(nodes + j, &root);
 338                }
 339                for (j = 0; j < nnodes; j++) {
 340                        check(nnodes - j);
 341                        erase(nodes + j, &root);
 342                }
 343                check(0);
 344        }
 345
 346        printk(KERN_ALERT "augmented rbtree testing");
 347
 348        init();
 349
 350        time1 = get_cycles();
 351
 352        for (i = 0; i < perf_loops; i++) {
 353                for (j = 0; j < nnodes; j++)
 354                        insert_augmented(nodes + j, &root);
 355                for (j = 0; j < nnodes; j++)
 356                        erase_augmented(nodes + j, &root);
 357        }
 358
 359        time2 = get_cycles();
 360        time = time2 - time1;
 361
 362        time = div_u64(time, perf_loops);
 363        printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time);
 364
 365        time1 = get_cycles();
 366
 367        for (i = 0; i < perf_loops; i++) {
 368                for (j = 0; j < nnodes; j++)
 369                        insert_augmented_cached(nodes + j, &root);
 370                for (j = 0; j < nnodes; j++)
 371                        erase_augmented_cached(nodes + j, &root);
 372        }
 373
 374        time2 = get_cycles();
 375        time = time2 - time1;
 376
 377        time = div_u64(time, perf_loops);
 378        printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time);
 379
 380        for (i = 0; i < check_loops; i++) {
 381                init();
 382                for (j = 0; j < nnodes; j++) {
 383                        check_augmented(j);
 384                        insert_augmented(nodes + j, &root);
 385                }
 386                for (j = 0; j < nnodes; j++) {
 387                        check_augmented(nnodes - j);
 388                        erase_augmented(nodes + j, &root);
 389                }
 390                check_augmented(0);
 391        }
 392
 393        kfree(nodes);
 394
 395        return -EAGAIN; /* Fail will directly unload the module */
 396}
 397
 398static void __exit rbtree_test_exit(void)
 399{
 400        printk(KERN_ALERT "test exit\n");
 401}
 402
 403module_init(rbtree_test_init)
 404module_exit(rbtree_test_exit)
 405
 406MODULE_LICENSE("GPL");
 407MODULE_AUTHOR("Michel Lespinasse");
 408MODULE_DESCRIPTION("Red Black Tree test");
 409