linux/lib/test_rhashtable.c
<<
>>
Prefs
   1/*
   2 * Resizable, Scalable, Concurrent Hash Table
   3 *
   4 * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch>
   5 * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
   6 *
   7 * This program is free software; you can redistribute it and/or modify
   8 * it under the terms of the GNU General Public License version 2 as
   9 * published by the Free Software Foundation.
  10 */
  11
  12/**************************************************************************
  13 * Self Test
  14 **************************************************************************/
  15
  16#include <linux/init.h>
  17#include <linux/jhash.h>
  18#include <linux/kernel.h>
  19#include <linux/kthread.h>
  20#include <linux/module.h>
  21#include <linux/rcupdate.h>
  22#include <linux/rhashtable.h>
  23#include <linux/semaphore.h>
  24#include <linux/slab.h>
  25#include <linux/sched.h>
  26#include <linux/random.h>
  27#include <linux/vmalloc.h>
  28
  29#define MAX_ENTRIES     1000000
  30#define TEST_INSERT_FAIL INT_MAX
  31
  32static int parm_entries = 50000;
  33module_param(parm_entries, int, 0);
  34MODULE_PARM_DESC(parm_entries, "Number of entries to add (default: 50000)");
  35
  36static int runs = 4;
  37module_param(runs, int, 0);
  38MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)");
  39
  40static int max_size = 0;
  41module_param(max_size, int, 0);
  42MODULE_PARM_DESC(max_size, "Maximum table size (default: calculated)");
  43
  44static bool shrinking = false;
  45module_param(shrinking, bool, 0);
  46MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)");
  47
  48static int size = 8;
  49module_param(size, int, 0);
  50MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)");
  51
  52static int tcount = 10;
  53module_param(tcount, int, 0);
  54MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)");
  55
  56static bool enomem_retry = false;
  57module_param(enomem_retry, bool, 0);
  58MODULE_PARM_DESC(enomem_retry, "Retry insert even if -ENOMEM was returned (default: off)");
  59
  60struct test_obj_val {
  61        int     id;
  62        int     tid;
  63};
  64
  65struct test_obj {
  66        struct test_obj_val     value;
  67        struct rhash_head       node;
  68};
  69
  70struct test_obj_rhl {
  71        struct test_obj_val     value;
  72        struct rhlist_head      list_node;
  73};
  74
  75struct thread_data {
  76        unsigned int entries;
  77        int id;
  78        struct task_struct *task;
  79        struct test_obj *objs;
  80};
  81
  82static struct rhashtable_params test_rht_params = {
  83        .head_offset = offsetof(struct test_obj, node),
  84        .key_offset = offsetof(struct test_obj, value),
  85        .key_len = sizeof(struct test_obj_val),
  86        .hashfn = jhash,
  87        .nulls_base = (3U << RHT_BASE_SHIFT),
  88};
  89
  90static struct semaphore prestart_sem;
  91static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0);
  92
  93static int insert_retry(struct rhashtable *ht, struct test_obj *obj,
  94                        const struct rhashtable_params params)
  95{
  96        int err, retries = -1, enomem_retries = 0;
  97
  98        do {
  99                retries++;
 100                cond_resched();
 101                err = rhashtable_insert_fast(ht, &obj->node, params);
 102                if (err == -ENOMEM && enomem_retry) {
 103                        enomem_retries++;
 104                        err = -EBUSY;
 105                }
 106        } while (err == -EBUSY);
 107
 108        if (enomem_retries)
 109                pr_info(" %u insertions retried after -ENOMEM\n",
 110                        enomem_retries);
 111
 112        return err ? : retries;
 113}
 114
 115static int __init test_rht_lookup(struct rhashtable *ht, struct test_obj *array,
 116                                  unsigned int entries)
 117{
 118        unsigned int i;
 119
 120        for (i = 0; i < entries; i++) {
 121                struct test_obj *obj;
 122                bool expected = !(i % 2);
 123                struct test_obj_val key = {
 124                        .id = i,
 125                };
 126
 127                if (array[i / 2].value.id == TEST_INSERT_FAIL)
 128                        expected = false;
 129
 130                obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
 131
 132                if (expected && !obj) {
 133                        pr_warn("Test failed: Could not find key %u\n", key.id);
 134                        return -ENOENT;
 135                } else if (!expected && obj) {
 136                        pr_warn("Test failed: Unexpected entry found for key %u\n",
 137                                key.id);
 138                        return -EEXIST;
 139                } else if (expected && obj) {
 140                        if (obj->value.id != i) {
 141                                pr_warn("Test failed: Lookup value mismatch %u!=%u\n",
 142                                        obj->value.id, i);
 143                                return -EINVAL;
 144                        }
 145                }
 146
 147                cond_resched_rcu();
 148        }
 149
 150        return 0;
 151}
 152
 153static void test_bucket_stats(struct rhashtable *ht, unsigned int entries)
 154{
 155        unsigned int err, total = 0, chain_len = 0;
 156        struct rhashtable_iter hti;
 157        struct rhash_head *pos;
 158
 159        err = rhashtable_walk_init(ht, &hti, GFP_KERNEL);
 160        if (err) {
 161                pr_warn("Test failed: allocation error");
 162                return;
 163        }
 164
 165        err = rhashtable_walk_start(&hti);
 166        if (err && err != -EAGAIN) {
 167                pr_warn("Test failed: iterator failed: %d\n", err);
 168                return;
 169        }
 170
 171        while ((pos = rhashtable_walk_next(&hti))) {
 172                if (PTR_ERR(pos) == -EAGAIN) {
 173                        pr_info("Info: encountered resize\n");
 174                        chain_len++;
 175                        continue;
 176                } else if (IS_ERR(pos)) {
 177                        pr_warn("Test failed: rhashtable_walk_next() error: %ld\n",
 178                                PTR_ERR(pos));
 179                        break;
 180                }
 181
 182                total++;
 183        }
 184
 185        rhashtable_walk_stop(&hti);
 186        rhashtable_walk_exit(&hti);
 187
 188        pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n",
 189                total, atomic_read(&ht->nelems), entries, chain_len);
 190
 191        if (total != atomic_read(&ht->nelems) || total != entries)
 192                pr_warn("Test failed: Total count mismatch ^^^");
 193}
 194
 195static s64 __init test_rhashtable(struct rhashtable *ht, struct test_obj *array,
 196                                  unsigned int entries)
 197{
 198        struct test_obj *obj;
 199        int err;
 200        unsigned int i, insert_retries = 0;
 201        s64 start, end;
 202
 203        /*
 204         * Insertion Test:
 205         * Insert entries into table with all keys even numbers
 206         */
 207        pr_info("  Adding %d keys\n", entries);
 208        start = ktime_get_ns();
 209        for (i = 0; i < entries; i++) {
 210                struct test_obj *obj = &array[i];
 211
 212                obj->value.id = i * 2;
 213                err = insert_retry(ht, obj, test_rht_params);
 214                if (err > 0)
 215                        insert_retries += err;
 216                else if (err)
 217                        return err;
 218        }
 219
 220        if (insert_retries)
 221                pr_info("  %u insertions retried due to memory pressure\n",
 222                        insert_retries);
 223
 224        test_bucket_stats(ht, entries);
 225        rcu_read_lock();
 226        test_rht_lookup(ht, array, entries);
 227        rcu_read_unlock();
 228
 229        test_bucket_stats(ht, entries);
 230
 231        pr_info("  Deleting %d keys\n", entries);
 232        for (i = 0; i < entries; i++) {
 233                struct test_obj_val key = {
 234                        .id = i * 2,
 235                };
 236
 237                if (array[i].value.id != TEST_INSERT_FAIL) {
 238                        obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
 239                        BUG_ON(!obj);
 240
 241                        rhashtable_remove_fast(ht, &obj->node, test_rht_params);
 242                }
 243
 244                cond_resched();
 245        }
 246
 247        end = ktime_get_ns();
 248        pr_info("  Duration of test: %lld ns\n", end - start);
 249
 250        return end - start;
 251}
 252
 253static struct rhashtable ht;
 254static struct rhltable rhlt;
 255
 256static int __init test_rhltable(unsigned int entries)
 257{
 258        struct test_obj_rhl *rhl_test_objects;
 259        unsigned long *obj_in_table;
 260        unsigned int i, j, k;
 261        int ret, err;
 262
 263        if (entries == 0)
 264                entries = 1;
 265
 266        rhl_test_objects = vzalloc(sizeof(*rhl_test_objects) * entries);
 267        if (!rhl_test_objects)
 268                return -ENOMEM;
 269
 270        ret = -ENOMEM;
 271        obj_in_table = vzalloc(BITS_TO_LONGS(entries) * sizeof(unsigned long));
 272        if (!obj_in_table)
 273                goto out_free;
 274
 275        /* nulls_base not supported in rhlist interface */
 276        test_rht_params.nulls_base = 0;
 277        err = rhltable_init(&rhlt, &test_rht_params);
 278        if (WARN_ON(err))
 279                goto out_free;
 280
 281        k = prandom_u32();
 282        ret = 0;
 283        for (i = 0; i < entries; i++) {
 284                rhl_test_objects[i].value.id = k;
 285                err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node,
 286                                      test_rht_params);
 287                if (WARN(err, "error %d on element %d\n", err, i))
 288                        break;
 289                if (err == 0)
 290                        set_bit(i, obj_in_table);
 291        }
 292
 293        if (err)
 294                ret = err;
 295
 296        pr_info("test %d add/delete pairs into rhlist\n", entries);
 297        for (i = 0; i < entries; i++) {
 298                struct rhlist_head *h, *pos;
 299                struct test_obj_rhl *obj;
 300                struct test_obj_val key = {
 301                        .id = k,
 302                };
 303                bool found;
 304
 305                rcu_read_lock();
 306                h = rhltable_lookup(&rhlt, &key, test_rht_params);
 307                if (WARN(!h, "key not found during iteration %d of %d", i, entries)) {
 308                        rcu_read_unlock();
 309                        break;
 310                }
 311
 312                if (i) {
 313                        j = i - 1;
 314                        rhl_for_each_entry_rcu(obj, pos, h, list_node) {
 315                                if (WARN(pos == &rhl_test_objects[j].list_node, "old element found, should be gone"))
 316                                        break;
 317                        }
 318                }
 319
 320                cond_resched_rcu();
 321
 322                found = false;
 323
 324                rhl_for_each_entry_rcu(obj, pos, h, list_node) {
 325                        if (pos == &rhl_test_objects[i].list_node) {
 326                                found = true;
 327                                break;
 328                        }
 329                }
 330
 331                rcu_read_unlock();
 332
 333                if (WARN(!found, "element %d not found", i))
 334                        break;
 335
 336                err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
 337                WARN(err, "rhltable_remove: err %d for iteration %d\n", err, i);
 338                if (err == 0)
 339                        clear_bit(i, obj_in_table);
 340        }
 341
 342        if (ret == 0 && err)
 343                ret = err;
 344
 345        for (i = 0; i < entries; i++) {
 346                WARN(test_bit(i, obj_in_table), "elem %d allegedly still present", i);
 347
 348                err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node,
 349                                      test_rht_params);
 350                if (WARN(err, "error %d on element %d\n", err, i))
 351                        break;
 352                if (err == 0)
 353                        set_bit(i, obj_in_table);
 354        }
 355
 356        pr_info("test %d random rhlist add/delete operations\n", entries);
 357        for (j = 0; j < entries; j++) {
 358                u32 i = prandom_u32_max(entries);
 359                u32 prand = prandom_u32();
 360
 361                cond_resched();
 362
 363                if (prand == 0)
 364                        prand = prandom_u32();
 365
 366                if (prand & 1) {
 367                        prand >>= 1;
 368                        continue;
 369                }
 370
 371                err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
 372                if (test_bit(i, obj_in_table)) {
 373                        clear_bit(i, obj_in_table);
 374                        if (WARN(err, "cannot remove element at slot %d", i))
 375                                continue;
 376                } else {
 377                        if (WARN(err != -ENOENT, "removed non-existant element %d, error %d not %d",
 378                             i, err, -ENOENT))
 379                                continue;
 380                }
 381
 382                if (prand & 1) {
 383                        prand >>= 1;
 384                        continue;
 385                }
 386
 387                err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
 388                if (err == 0) {
 389                        if (WARN(test_and_set_bit(i, obj_in_table), "succeeded to insert same object %d", i))
 390                                continue;
 391                } else {
 392                        if (WARN(!test_bit(i, obj_in_table), "failed to insert object %d", i))
 393                                continue;
 394                }
 395
 396                if (prand & 1) {
 397                        prand >>= 1;
 398                        continue;
 399                }
 400
 401                i = prandom_u32_max(entries);
 402                if (test_bit(i, obj_in_table)) {
 403                        err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
 404                        WARN(err, "cannot remove element at slot %d", i);
 405                        if (err == 0)
 406                                clear_bit(i, obj_in_table);
 407                } else {
 408                        err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
 409                        WARN(err, "failed to insert object %d", i);
 410                        if (err == 0)
 411                                set_bit(i, obj_in_table);
 412                }
 413        }
 414
 415        for (i = 0; i < entries; i++) {
 416                cond_resched();
 417                err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
 418                if (test_bit(i, obj_in_table)) {
 419                        if (WARN(err, "cannot remove element at slot %d", i))
 420                                continue;
 421                } else {
 422                        if (WARN(err != -ENOENT, "removed non-existant element, error %d not %d",
 423                                 err, -ENOENT))
 424                        continue;
 425                }
 426        }
 427
 428        rhltable_destroy(&rhlt);
 429out_free:
 430        vfree(rhl_test_objects);
 431        vfree(obj_in_table);
 432        return ret;
 433}
 434
 435static int __init test_rhashtable_max(struct test_obj *array,
 436                                      unsigned int entries)
 437{
 438        unsigned int i, insert_retries = 0;
 439        int err;
 440
 441        test_rht_params.max_size = roundup_pow_of_two(entries / 8);
 442        err = rhashtable_init(&ht, &test_rht_params);
 443        if (err)
 444                return err;
 445
 446        for (i = 0; i < ht.max_elems; i++) {
 447                struct test_obj *obj = &array[i];
 448
 449                obj->value.id = i * 2;
 450                err = insert_retry(&ht, obj, test_rht_params);
 451                if (err > 0)
 452                        insert_retries += err;
 453                else if (err)
 454                        return err;
 455        }
 456
 457        err = insert_retry(&ht, &array[ht.max_elems], test_rht_params);
 458        if (err == -E2BIG) {
 459                err = 0;
 460        } else {
 461                pr_info("insert element %u should have failed with %d, got %d\n",
 462                                ht.max_elems, -E2BIG, err);
 463                if (err == 0)
 464                        err = -1;
 465        }
 466
 467        rhashtable_destroy(&ht);
 468
 469        return err;
 470}
 471
 472static int thread_lookup_test(struct thread_data *tdata)
 473{
 474        unsigned int entries = tdata->entries;
 475        int i, err = 0;
 476
 477        for (i = 0; i < entries; i++) {
 478                struct test_obj *obj;
 479                struct test_obj_val key = {
 480                        .id = i,
 481                        .tid = tdata->id,
 482                };
 483
 484                obj = rhashtable_lookup_fast(&ht, &key, test_rht_params);
 485                if (obj && (tdata->objs[i].value.id == TEST_INSERT_FAIL)) {
 486                        pr_err("  found unexpected object %d-%d\n", key.tid, key.id);
 487                        err++;
 488                } else if (!obj && (tdata->objs[i].value.id != TEST_INSERT_FAIL)) {
 489                        pr_err("  object %d-%d not found!\n", key.tid, key.id);
 490                        err++;
 491                } else if (obj && memcmp(&obj->value, &key, sizeof(key))) {
 492                        pr_err("  wrong object returned (got %d-%d, expected %d-%d)\n",
 493                               obj->value.tid, obj->value.id, key.tid, key.id);
 494                        err++;
 495                }
 496
 497                cond_resched();
 498        }
 499        return err;
 500}
 501
 502static int threadfunc(void *data)
 503{
 504        int i, step, err = 0, insert_retries = 0;
 505        struct thread_data *tdata = data;
 506
 507        up(&prestart_sem);
 508        if (down_interruptible(&startup_sem))
 509                pr_err("  thread[%d]: down_interruptible failed\n", tdata->id);
 510
 511        for (i = 0; i < tdata->entries; i++) {
 512                tdata->objs[i].value.id = i;
 513                tdata->objs[i].value.tid = tdata->id;
 514                err = insert_retry(&ht, &tdata->objs[i], test_rht_params);
 515                if (err > 0) {
 516                        insert_retries += err;
 517                } else if (err) {
 518                        pr_err("  thread[%d]: rhashtable_insert_fast failed\n",
 519                               tdata->id);
 520                        goto out;
 521                }
 522        }
 523        if (insert_retries)
 524                pr_info("  thread[%d]: %u insertions retried due to memory pressure\n",
 525                        tdata->id, insert_retries);
 526
 527        err = thread_lookup_test(tdata);
 528        if (err) {
 529                pr_err("  thread[%d]: rhashtable_lookup_test failed\n",
 530                       tdata->id);
 531                goto out;
 532        }
 533
 534        for (step = 10; step > 0; step--) {
 535                for (i = 0; i < tdata->entries; i += step) {
 536                        if (tdata->objs[i].value.id == TEST_INSERT_FAIL)
 537                                continue;
 538                        err = rhashtable_remove_fast(&ht, &tdata->objs[i].node,
 539                                                     test_rht_params);
 540                        if (err) {
 541                                pr_err("  thread[%d]: rhashtable_remove_fast failed\n",
 542                                       tdata->id);
 543                                goto out;
 544                        }
 545                        tdata->objs[i].value.id = TEST_INSERT_FAIL;
 546
 547                        cond_resched();
 548                }
 549                err = thread_lookup_test(tdata);
 550                if (err) {
 551                        pr_err("  thread[%d]: rhashtable_lookup_test (2) failed\n",
 552                               tdata->id);
 553                        goto out;
 554                }
 555        }
 556out:
 557        while (!kthread_should_stop()) {
 558                set_current_state(TASK_INTERRUPTIBLE);
 559                schedule();
 560        }
 561        return err;
 562}
 563
 564static int __init test_rht_init(void)
 565{
 566        unsigned int entries;
 567        int i, err, started_threads = 0, failed_threads = 0;
 568        u64 total_time = 0;
 569        struct thread_data *tdata;
 570        struct test_obj *objs;
 571
 572        if (parm_entries < 0)
 573                parm_entries = 1;
 574
 575        entries = min(parm_entries, MAX_ENTRIES);
 576
 577        test_rht_params.automatic_shrinking = shrinking;
 578        test_rht_params.max_size = max_size ? : roundup_pow_of_two(entries);
 579        test_rht_params.nelem_hint = size;
 580
 581        objs = vzalloc((test_rht_params.max_size + 1) * sizeof(struct test_obj));
 582        if (!objs)
 583                return -ENOMEM;
 584
 585        pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n",
 586                size, max_size, shrinking);
 587
 588        for (i = 0; i < runs; i++) {
 589                s64 time;
 590
 591                pr_info("Test %02d:\n", i);
 592                memset(objs, 0, test_rht_params.max_size * sizeof(struct test_obj));
 593
 594                err = rhashtable_init(&ht, &test_rht_params);
 595                if (err < 0) {
 596                        pr_warn("Test failed: Unable to initialize hashtable: %d\n",
 597                                err);
 598                        continue;
 599                }
 600
 601                time = test_rhashtable(&ht, objs, entries);
 602                rhashtable_destroy(&ht);
 603                if (time < 0) {
 604                        vfree(objs);
 605                        pr_warn("Test failed: return code %lld\n", time);
 606                        return -EINVAL;
 607                }
 608
 609                total_time += time;
 610        }
 611
 612        pr_info("test if its possible to exceed max_size %d: %s\n",
 613                        test_rht_params.max_size, test_rhashtable_max(objs, entries) == 0 ?
 614                        "no, ok" : "YES, failed");
 615        vfree(objs);
 616
 617        do_div(total_time, runs);
 618        pr_info("Average test time: %llu\n", total_time);
 619
 620        if (!tcount)
 621                return 0;
 622
 623        pr_info("Testing concurrent rhashtable access from %d threads\n",
 624                tcount);
 625        sema_init(&prestart_sem, 1 - tcount);
 626        tdata = vzalloc(tcount * sizeof(struct thread_data));
 627        if (!tdata)
 628                return -ENOMEM;
 629        objs  = vzalloc(tcount * entries * sizeof(struct test_obj));
 630        if (!objs) {
 631                vfree(tdata);
 632                return -ENOMEM;
 633        }
 634
 635        test_rht_params.max_size = max_size ? :
 636                                   roundup_pow_of_two(tcount * entries);
 637        err = rhashtable_init(&ht, &test_rht_params);
 638        if (err < 0) {
 639                pr_warn("Test failed: Unable to initialize hashtable: %d\n",
 640                        err);
 641                vfree(tdata);
 642                vfree(objs);
 643                return -EINVAL;
 644        }
 645        for (i = 0; i < tcount; i++) {
 646                tdata[i].id = i;
 647                tdata[i].entries = entries;
 648                tdata[i].objs = objs + i * entries;
 649                tdata[i].task = kthread_run(threadfunc, &tdata[i],
 650                                            "rhashtable_thrad[%d]", i);
 651                if (IS_ERR(tdata[i].task))
 652                        pr_err(" kthread_run failed for thread %d\n", i);
 653                else
 654                        started_threads++;
 655        }
 656        if (down_interruptible(&prestart_sem))
 657                pr_err("  down interruptible failed\n");
 658        for (i = 0; i < tcount; i++)
 659                up(&startup_sem);
 660        for (i = 0; i < tcount; i++) {
 661                if (IS_ERR(tdata[i].task))
 662                        continue;
 663                if ((err = kthread_stop(tdata[i].task))) {
 664                        pr_warn("Test failed: thread %d returned: %d\n",
 665                                i, err);
 666                        failed_threads++;
 667                }
 668        }
 669        rhashtable_destroy(&ht);
 670        vfree(tdata);
 671        vfree(objs);
 672
 673        /*
 674         * rhltable_remove is very expensive, default values can cause test
 675         * to run for 2 minutes or more,  use a smaller number instead.
 676         */
 677        err = test_rhltable(entries / 16);
 678        pr_info("Started %d threads, %d failed, rhltable test returns %d\n",
 679                started_threads, failed_threads, err);
 680        return 0;
 681}
 682
 683static void __exit test_rht_exit(void)
 684{
 685}
 686
 687module_init(test_rht_init);
 688module_exit(test_rht_exit);
 689
 690MODULE_LICENSE("GPL v2");
 691