/* linux/lib/rbtree_test.c - red-black tree test module */
   1#include <linux/module.h>
   2#include <linux/moduleparam.h>
   3#include <linux/rbtree_augmented.h>
   4#include <linux/random.h>
   5#include <linux/slab.h>
   6#include <asm/timex.h>
   7
   8#define __param(type, name, init, msg)          \
   9        static type name = init;                \
  10        module_param(name, type, 0444);         \
  11        MODULE_PARM_DESC(name, msg);
  12
  13__param(int, nnodes, 100, "Number of nodes in the rb-tree");
  14__param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree");
  15__param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree");
  16
  17struct test_node {
  18        u32 key;
  19        struct rb_node rb;
  20
  21        /* following fields used for testing augmented rbtree functionality */
  22        u32 val;
  23        u32 augmented;
  24};
  25
  26static struct rb_root_cached root = RB_ROOT_CACHED;
  27static struct test_node *nodes = NULL;
  28
  29static struct rnd_state rnd;
  30
  31static void insert(struct test_node *node, struct rb_root_cached *root)
  32{
  33        struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
  34        u32 key = node->key;
  35
  36        while (*new) {
  37                parent = *new;
  38                if (key < rb_entry(parent, struct test_node, rb)->key)
  39                        new = &parent->rb_left;
  40                else
  41                        new = &parent->rb_right;
  42        }
  43
  44        rb_link_node(&node->rb, parent, new);
  45        rb_insert_color(&node->rb, &root->rb_root);
  46}
  47
  48static void insert_cached(struct test_node *node, struct rb_root_cached *root)
  49{
  50        struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
  51        u32 key = node->key;
  52        bool leftmost = true;
  53
  54        while (*new) {
  55                parent = *new;
  56                if (key < rb_entry(parent, struct test_node, rb)->key)
  57                        new = &parent->rb_left;
  58                else {
  59                        new = &parent->rb_right;
  60                        leftmost = false;
  61                }
  62        }
  63
  64        rb_link_node(&node->rb, parent, new);
  65        rb_insert_color_cached(&node->rb, root, leftmost);
  66}
  67
  68static inline void erase(struct test_node *node, struct rb_root_cached *root)
  69{
  70        rb_erase(&node->rb, &root->rb_root);
  71}
  72
  73static inline void erase_cached(struct test_node *node, struct rb_root_cached *root)
  74{
  75        rb_erase_cached(&node->rb, root);
  76}
  77
  78
  79static inline u32 augment_recompute(struct test_node *node)
  80{
  81        u32 max = node->val, child_augmented;
  82        if (node->rb.rb_left) {
  83                child_augmented = rb_entry(node->rb.rb_left, struct test_node,
  84                                           rb)->augmented;
  85                if (max < child_augmented)
  86                        max = child_augmented;
  87        }
  88        if (node->rb.rb_right) {
  89                child_augmented = rb_entry(node->rb.rb_right, struct test_node,
  90                                           rb)->augmented;
  91                if (max < child_augmented)
  92                        max = child_augmented;
  93        }
  94        return max;
  95}
  96
  97RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb,
  98                     u32, augmented, augment_recompute)
  99
 100static void insert_augmented(struct test_node *node,
 101                             struct rb_root_cached *root)
 102{
 103        struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
 104        u32 key = node->key;
 105        u32 val = node->val;
 106        struct test_node *parent;
 107
 108        while (*new) {
 109                rb_parent = *new;
 110                parent = rb_entry(rb_parent, struct test_node, rb);
 111                if (parent->augmented < val)
 112                        parent->augmented = val;
 113                if (key < parent->key)
 114                        new = &parent->rb.rb_left;
 115                else
 116                        new = &parent->rb.rb_right;
 117        }
 118
 119        node->augmented = val;
 120        rb_link_node(&node->rb, rb_parent, new);
 121        rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks);
 122}
 123
 124static void insert_augmented_cached(struct test_node *node,
 125                                    struct rb_root_cached *root)
 126{
 127        struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
 128        u32 key = node->key;
 129        u32 val = node->val;
 130        struct test_node *parent;
 131        bool leftmost = true;
 132
 133        while (*new) {
 134                rb_parent = *new;
 135                parent = rb_entry(rb_parent, struct test_node, rb);
 136                if (parent->augmented < val)
 137                        parent->augmented = val;
 138                if (key < parent->key)
 139                        new = &parent->rb.rb_left;
 140                else {
 141                        new = &parent->rb.rb_right;
 142                        leftmost = false;
 143                }
 144        }
 145
 146        node->augmented = val;
 147        rb_link_node(&node->rb, rb_parent, new);
 148        rb_insert_augmented_cached(&node->rb, root,
 149                                   leftmost, &augment_callbacks);
 150}
 151
 152
 153static void erase_augmented(struct test_node *node, struct rb_root_cached *root)
 154{
 155        rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks);
 156}
 157
 158static void erase_augmented_cached(struct test_node *node,
 159                                   struct rb_root_cached *root)
 160{
 161        rb_erase_augmented_cached(&node->rb, root, &augment_callbacks);
 162}
 163
 164static void init(void)
 165{
 166        int i;
 167        for (i = 0; i < nnodes; i++) {
 168                nodes[i].key = prandom_u32_state(&rnd);
 169                nodes[i].val = prandom_u32_state(&rnd);
 170        }
 171}
 172
 173static bool is_red(struct rb_node *rb)
 174{
 175        return !(rb->__rb_parent_color & 1);
 176}
 177
 178static int black_path_count(struct rb_node *rb)
 179{
 180        int count;
 181        for (count = 0; rb; rb = rb_parent(rb))
 182                count += !is_red(rb);
 183        return count;
 184}
 185
 186static void check_postorder_foreach(int nr_nodes)
 187{
 188        struct test_node *cur, *n;
 189        int count = 0;
 190        rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb)
 191                count++;
 192
 193        WARN_ON_ONCE(count != nr_nodes);
 194}
 195
 196static void check_postorder(int nr_nodes)
 197{
 198        struct rb_node *rb;
 199        int count = 0;
 200        for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb))
 201                count++;
 202
 203        WARN_ON_ONCE(count != nr_nodes);
 204}
 205
 206static void check(int nr_nodes)
 207{
 208        struct rb_node *rb;
 209        int count = 0, blacks = 0;
 210        u32 prev_key = 0;
 211
 212        for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
 213                struct test_node *node = rb_entry(rb, struct test_node, rb);
 214                WARN_ON_ONCE(node->key < prev_key);
 215                WARN_ON_ONCE(is_red(rb) &&
 216                             (!rb_parent(rb) || is_red(rb_parent(rb))));
 217                if (!count)
 218                        blacks = black_path_count(rb);
 219                else
 220                        WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
 221                                     blacks != black_path_count(rb));
 222                prev_key = node->key;
 223                count++;
 224        }
 225
 226        WARN_ON_ONCE(count != nr_nodes);
 227        WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1);
 228
 229        check_postorder(nr_nodes);
 230        check_postorder_foreach(nr_nodes);
 231}
 232
 233static void check_augmented(int nr_nodes)
 234{
 235        struct rb_node *rb;
 236
 237        check(nr_nodes);
 238        for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
 239                struct test_node *node = rb_entry(rb, struct test_node, rb);
 240                WARN_ON_ONCE(node->augmented != augment_recompute(node));
 241        }
 242}
 243
 244static int __init rbtree_test_init(void)
 245{
 246        int i, j;
 247        cycles_t time1, time2, time;
 248        struct rb_node *node;
 249
 250        nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL);
 251        if (!nodes)
 252                return -ENOMEM;
 253
 254        printk(KERN_ALERT "rbtree testing");
 255
 256        prandom_seed_state(&rnd, 3141592653589793238ULL);
 257        init();
 258
 259        time1 = get_cycles();
 260
 261        for (i = 0; i < perf_loops; i++) {
 262                for (j = 0; j < nnodes; j++)
 263                        insert(nodes + j, &root);
 264                for (j = 0; j < nnodes; j++)
 265                        erase(nodes + j, &root);
 266        }
 267
 268        time2 = get_cycles();
 269        time = time2 - time1;
 270
 271        time = div_u64(time, perf_loops);
 272        printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n",
 273               (unsigned long long)time);
 274
 275        time1 = get_cycles();
 276
 277        for (i = 0; i < perf_loops; i++) {
 278                for (j = 0; j < nnodes; j++)
 279                        insert_cached(nodes + j, &root);
 280                for (j = 0; j < nnodes; j++)
 281                        erase_cached(nodes + j, &root);
 282        }
 283
 284        time2 = get_cycles();
 285        time = time2 - time1;
 286
 287        time = div_u64(time, perf_loops);
 288        printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n",
 289               (unsigned long long)time);
 290
 291        for (i = 0; i < nnodes; i++)
 292                insert(nodes + i, &root);
 293
 294        time1 = get_cycles();
 295
 296        for (i = 0; i < perf_loops; i++) {
 297                for (node = rb_first(&root.rb_root); node; node = rb_next(node))
 298                        ;
 299        }
 300
 301        time2 = get_cycles();
 302        time = time2 - time1;
 303
 304        time = div_u64(time, perf_loops);
 305        printk(" -> test 3 (latency of inorder traversal): %llu cycles\n",
 306               (unsigned long long)time);
 307
 308        time1 = get_cycles();
 309
 310        for (i = 0; i < perf_loops; i++)
 311                node = rb_first(&root.rb_root);
 312
 313        time2 = get_cycles();
 314        time = time2 - time1;
 315
 316        time = div_u64(time, perf_loops);
 317        printk(" -> test 4 (latency to fetch first node)\n");
 318        printk("        non-cached: %llu cycles\n", (unsigned long long)time);
 319
 320        time1 = get_cycles();
 321
 322        for (i = 0; i < perf_loops; i++)
 323                node = rb_first_cached(&root);
 324
 325        time2 = get_cycles();
 326        time = time2 - time1;
 327
 328        time = div_u64(time, perf_loops);
 329        printk("        cached: %llu cycles\n", (unsigned long long)time);
 330
 331        for (i = 0; i < nnodes; i++)
 332                erase(nodes + i, &root);
 333
 334        /* run checks */
 335        for (i = 0; i < check_loops; i++) {
 336                init();
 337                for (j = 0; j < nnodes; j++) {
 338                        check(j);
 339                        insert(nodes + j, &root);
 340                }
 341                for (j = 0; j < nnodes; j++) {
 342                        check(nnodes - j);
 343                        erase(nodes + j, &root);
 344                }
 345                check(0);
 346        }
 347
 348        printk(KERN_ALERT "augmented rbtree testing");
 349
 350        init();
 351
 352        time1 = get_cycles();
 353
 354        for (i = 0; i < perf_loops; i++) {
 355                for (j = 0; j < nnodes; j++)
 356                        insert_augmented(nodes + j, &root);
 357                for (j = 0; j < nnodes; j++)
 358                        erase_augmented(nodes + j, &root);
 359        }
 360
 361        time2 = get_cycles();
 362        time = time2 - time1;
 363
 364        time = div_u64(time, perf_loops);
 365        printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time);
 366
 367        time1 = get_cycles();
 368
 369        for (i = 0; i < perf_loops; i++) {
 370                for (j = 0; j < nnodes; j++)
 371                        insert_augmented_cached(nodes + j, &root);
 372                for (j = 0; j < nnodes; j++)
 373                        erase_augmented_cached(nodes + j, &root);
 374        }
 375
 376        time2 = get_cycles();
 377        time = time2 - time1;
 378
 379        time = div_u64(time, perf_loops);
 380        printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time);
 381
 382        for (i = 0; i < check_loops; i++) {
 383                init();
 384                for (j = 0; j < nnodes; j++) {
 385                        check_augmented(j);
 386                        insert_augmented(nodes + j, &root);
 387                }
 388                for (j = 0; j < nnodes; j++) {
 389                        check_augmented(nnodes - j);
 390                        erase_augmented(nodes + j, &root);
 391                }
 392                check_augmented(0);
 393        }
 394
 395        kfree(nodes);
 396
 397        return -EAGAIN; /* Fail will directly unload the module */
 398}
 399
 400static void __exit rbtree_test_exit(void)
 401{
 402        printk(KERN_ALERT "test exit\n");
 403}
 404
 405module_init(rbtree_test_init)
 406module_exit(rbtree_test_exit)
 407
 408MODULE_LICENSE("GPL");
 409MODULE_AUTHOR("Michel Lespinasse");
 410MODULE_DESCRIPTION("Red Black Tree test");
 411