linux/lib/rbtree_test.c
// SPDX-License-Identifier: GPL-2.0-only
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/rbtree_augmented.h>
#include <linux/random.h>
#include <linux/slab.h>
#include <asm/timex.h>

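/*
 * Helper to declare a read-only (0444) module parameter together with its
 * description, so the test sizes below can be tuned at module load time.
 */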
#define __param(type, name, init, msg)		\
	static type name = init;		\
	module_param(name, type, 0444);		\
	MODULE_PARM_DESC(name, msg);

__param(int, nnodes, 100, "Number of nodes in the rb-tree");
__param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree");
__param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree");

struct test_node {
	u32 key;
	struct rb_node rb;

	/* following fields used for testing augmented rbtree functionality */
	u32 val;
	u32 augmented;
};

static struct rb_root_cached root = RB_ROOT_CACHED;
static struct test_node *nodes = NULL;

static struct rnd_state rnd;

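/* Plain rbtree insertion: descend to a leaf slot, link the node, rebalance. */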
static void insert(struct test_node *node, struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
	u32 key = node->key;

	while (*new) {
		parent = *new;
		if (key < rb_entry(parent, struct test_node, rb)->key)
			new = &parent->rb_left;
		else
			new = &parent->rb_right;
	}

	rb_link_node(&node->rb, parent, new);
	rb_insert_color(&node->rb, &root->rb_root);
}

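/*
 * Cached insertion additionally tracks whether the new node becomes the
 * leftmost one: any step to the right means it cannot be the new minimum.
 */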
static void insert_cached(struct test_node *node, struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
	u32 key = node->key;
	bool leftmost = true;

	while (*new) {
		parent = *new;
		if (key < rb_entry(parent, struct test_node, rb)->key)
			new = &parent->rb_left;
		else {
			new = &parent->rb_right;
			leftmost = false;
		}
	}

	rb_link_node(&node->rb, parent, new);
	rb_insert_color_cached(&node->rb, root, leftmost);
}

static inline void erase(struct test_node *node, struct rb_root_cached *root)
{
	rb_erase(&node->rb, &root->rb_root);
}

static inline void erase_cached(struct test_node *node, struct rb_root_cached *root)
{
	rb_erase_cached(&node->rb, root);
}


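/*
 * Recompute the augmented value of @node from scratch: the maximum of its
 * own val and its children's augmented values, i.e. the largest val in the
 * subtree rooted at @node.
 */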
static inline u32 augment_recompute(struct test_node *node)
{
	u32 max = node->val, child_augmented;
	if (node->rb.rb_left) {
		child_augmented = rb_entry(node->rb.rb_left, struct test_node,
					   rb)->augmented;
		if (max < child_augmented)
			max = child_augmented;
	}
	if (node->rb.rb_right) {
		child_augmented = rb_entry(node->rb.rb_right, struct test_node,
					   rb)->augmented;
		if (max < child_augmented)
			max = child_augmented;
	}
	return max;
}

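/* Generate the callbacks that keep ->augmented up to date across rotations. */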
RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb,
		     u32, augmented, augment_recompute)

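/*
 * Augmented insertion: each ancestor visited on the way down has its
 * augmented value fixed up, then rb_insert_augmented() rebalances while
 * maintaining the metadata through the declared callbacks.
 */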
static void insert_augmented(struct test_node *node,
			     struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
	u32 key = node->key;
	u32 val = node->val;
	struct test_node *parent;

	while (*new) {
		rb_parent = *new;
		parent = rb_entry(rb_parent, struct test_node, rb);
		if (parent->augmented < val)
			parent->augmented = val;
		if (key < parent->key)
			new = &parent->rb.rb_left;
		else
			new = &parent->rb.rb_right;
	}

	node->augmented = val;
	rb_link_node(&node->rb, rb_parent, new);
	rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks);
}

static void insert_augmented_cached(struct test_node *node,
				    struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
	u32 key = node->key;
	u32 val = node->val;
	struct test_node *parent;
	bool leftmost = true;

	while (*new) {
		rb_parent = *new;
		parent = rb_entry(rb_parent, struct test_node, rb);
		if (parent->augmented < val)
			parent->augmented = val;
		if (key < parent->key)
			new = &parent->rb.rb_left;
		else {
			new = &parent->rb.rb_right;
			leftmost = false;
		}
	}

	node->augmented = val;
	rb_link_node(&node->rb, rb_parent, new);
	rb_insert_augmented_cached(&node->rb, root,
				   leftmost, &augment_callbacks);
}


static void erase_augmented(struct test_node *node, struct rb_root_cached *root)
{
	rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks);
}

static void erase_augmented_cached(struct test_node *node,
				   struct rb_root_cached *root)
{
	rb_erase_augmented_cached(&node->rb, root, &augment_callbacks);
}

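/* (Re)fill all test nodes with pseudo-random keys and values. */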
static void init(void)
{
	int i;
	for (i = 0; i < nnodes; i++) {
		nodes[i].key = prandom_u32_state(&rnd);
		nodes[i].val = prandom_u32_state(&rnd);
	}
}

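/* The low bit of __rb_parent_color encodes the color; red is 0, black is 1. */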
static bool is_red(struct rb_node *rb)
{
	return !(rb->__rb_parent_color & 1);
}

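/* Count the black nodes on the path from @rb up to and including the root. */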
static int black_path_count(struct rb_node *rb)
{
	int count;
	for (count = 0; rb; rb = rb_parent(rb))
		count += !is_red(rb);
	return count;
}

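/* The two postorder checks verify that iteration visits exactly nr_nodes nodes. */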
static void check_postorder_foreach(int nr_nodes)
{
	struct test_node *cur, *n;
	int count = 0;
	rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb)
		count++;

	WARN_ON_ONCE(count != nr_nodes);
}

static void check_postorder(int nr_nodes)
{
	struct rb_node *rb;
	int count = 0;
	for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb))
		count++;

	WARN_ON_ONCE(count != nr_nodes);
}

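/*
 * Walk the tree in order and verify the red-black invariants: keys are
 * sorted, no red node has a red parent, every path from a leaf to the root
 * carries the same number of black nodes, and the node count is at least
 * 2^(black height) - 1.
 */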
static void check(int nr_nodes)
{
	struct rb_node *rb;
	int count = 0, blacks = 0;
	u32 prev_key = 0;

	for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
		struct test_node *node = rb_entry(rb, struct test_node, rb);
		WARN_ON_ONCE(node->key < prev_key);
		WARN_ON_ONCE(is_red(rb) &&
			     (!rb_parent(rb) || is_red(rb_parent(rb))));
		if (!count)
			blacks = black_path_count(rb);
		else
			WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
				     blacks != black_path_count(rb));
		prev_key = node->key;
		count++;
	}

	WARN_ON_ONCE(count != nr_nodes);
	WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1);

	check_postorder(nr_nodes);
	check_postorder_foreach(nr_nodes);
}

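/* As check(), plus validate each node's stored augmented value from scratch. */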
static void check_augmented(int nr_nodes)
{
	struct rb_node *rb;

	check(nr_nodes);
	for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
		struct test_node *node = rb_entry(rb, struct test_node, rb);
		WARN_ON_ONCE(node->augmented != augment_recompute(node));
	}
}

static int __init rbtree_test_init(void)
{
	int i, j;
	cycles_t time1, time2, time;
	struct rb_node *node;

	nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL);
	if (!nodes)
		return -ENOMEM;

	printk(KERN_ALERT "rbtree testing");

	prandom_seed_state(&rnd, 3141592653589793238ULL);
	init();

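	/* Test 1: insert and then erase all nnodes nodes, averaged over perf_loops. */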
	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n",
	       (unsigned long long)time);

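	/* Test 2: the same insert+erase workload through the cached-leftmost API. */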
	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert_cached(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase_cached(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n",
	       (unsigned long long)time);

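	/* Test 3: build the tree once, then time full inorder traversals. */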
	for (i = 0; i < nnodes; i++)
		insert(nodes + i, &root);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (node = rb_first(&root.rb_root); node; node = rb_next(node))
			;
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 3 (latency of inorder traversal): %llu cycles\n",
	       (unsigned long long)time);

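	/* Test 4: fetch the first node, without and then with the leftmost cache. */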
	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++)
		node = rb_first(&root.rb_root);

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 4 (latency to fetch first node)\n");
	printk("        non-cached: %llu cycles\n", (unsigned long long)time);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++)
		node = rb_first_cached(&root);

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk("        cached: %llu cycles\n", (unsigned long long)time);

	for (i = 0; i < nnodes; i++)
		erase(nodes + i, &root);

	/* run checks */
	for (i = 0; i < check_loops; i++) {
		init();
		for (j = 0; j < nnodes; j++) {
			check(j);
			insert(nodes + j, &root);
		}
		for (j = 0; j < nnodes; j++) {
			check(nnodes - j);
			erase(nodes + j, &root);
		}
		check(0);
	}

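	/* Repeat the latency and verification runs for the augmented rbtree. */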
	printk(KERN_ALERT "augmented rbtree testing");

	init();

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert_augmented(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase_augmented(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n",
	       (unsigned long long)time);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert_augmented_cached(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase_augmented_cached(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n",
	       (unsigned long long)time);

	for (i = 0; i < check_loops; i++) {
		init();
		for (j = 0; j < nnodes; j++) {
			check_augmented(j);
			insert_augmented(nodes + j, &root);
		}
		for (j = 0; j < nnodes; j++) {
			check_augmented(nnodes - j);
			erase_augmented(nodes + j, &root);
		}
		check_augmented(0);
	}

	kfree(nodes);

	return -EAGAIN; /* failing init directly unloads the module */
}

static void __exit rbtree_test_exit(void)
{
	printk(KERN_ALERT "test exit\n");
}

module_init(rbtree_test_init)
module_exit(rbtree_test_exit)

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Michel Lespinasse");
MODULE_DESCRIPTION("Red Black Tree test");