linux/lib/test_rhashtable.c
<<
>>
Prefs
   1/*
   2 * Resizable, Scalable, Concurrent Hash Table
   3 *
   4 * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch>
   5 * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
   6 *
   7 * This program is free software; you can redistribute it and/or modify
   8 * it under the terms of the GNU General Public License version 2 as
   9 * published by the Free Software Foundation.
  10 */
  11
  12/**************************************************************************
  13 * Self Test
  14 **************************************************************************/
  15
  16#include <linux/init.h>
  17#include <linux/jhash.h>
  18#include <linux/kernel.h>
  19#include <linux/kthread.h>
  20#include <linux/module.h>
  21#include <linux/rcupdate.h>
  22#include <linux/rhashtable.h>
  23#include <linux/semaphore.h>
  24#include <linux/slab.h>
  25#include <linux/sched.h>
  26#include <linux/vmalloc.h>
  27
  28#define MAX_ENTRIES     1000000
  29#define TEST_INSERT_FAIL INT_MAX
  30
  31static int entries = 50000;
  32module_param(entries, int, 0);
  33MODULE_PARM_DESC(entries, "Number of entries to add (default: 50000)");
  34
  35static int runs = 4;
  36module_param(runs, int, 0);
  37MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)");
  38
  39static int max_size = 0;
  40module_param(max_size, int, 0);
  41MODULE_PARM_DESC(max_size, "Maximum table size (default: calculated)");
  42
  43static bool shrinking = false;
  44module_param(shrinking, bool, 0);
  45MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)");
  46
  47static int size = 8;
  48module_param(size, int, 0);
  49MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)");
  50
  51static int tcount = 10;
  52module_param(tcount, int, 0);
  53MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)");
  54
  55static bool enomem_retry = false;
  56module_param(enomem_retry, bool, 0);
  57MODULE_PARM_DESC(enomem_retry, "Retry insert even if -ENOMEM was returned (default: off)");
  58
  59struct test_obj {
  60        int                     value;
  61        struct rhash_head       node;
  62};
  63
  64struct thread_data {
  65        int id;
  66        struct task_struct *task;
  67        struct test_obj *objs;
  68};
  69
  70static struct test_obj array[MAX_ENTRIES];
  71
  72static struct rhashtable_params test_rht_params = {
  73        .head_offset = offsetof(struct test_obj, node),
  74        .key_offset = offsetof(struct test_obj, value),
  75        .key_len = sizeof(int),
  76        .hashfn = jhash,
  77        .nulls_base = (3U << RHT_BASE_SHIFT),
  78};
  79
  80static struct semaphore prestart_sem;
  81static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0);
  82
  83static int insert_retry(struct rhashtable *ht, struct rhash_head *obj,
  84                        const struct rhashtable_params params)
  85{
  86        int err, retries = -1, enomem_retries = 0;
  87
  88        do {
  89                retries++;
  90                cond_resched();
  91                err = rhashtable_insert_fast(ht, obj, params);
  92                if (err == -ENOMEM && enomem_retry) {
  93                        enomem_retries++;
  94                        err = -EBUSY;
  95                }
  96        } while (err == -EBUSY);
  97
  98        if (enomem_retries)
  99                pr_info(" %u insertions retried after -ENOMEM\n",
 100                        enomem_retries);
 101
 102        return err ? : retries;
 103}
 104
 105static int __init test_rht_lookup(struct rhashtable *ht)
 106{
 107        unsigned int i;
 108
 109        for (i = 0; i < entries * 2; i++) {
 110                struct test_obj *obj;
 111                bool expected = !(i % 2);
 112                u32 key = i;
 113
 114                if (array[i / 2].value == TEST_INSERT_FAIL)
 115                        expected = false;
 116
 117                obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
 118
 119                if (expected && !obj) {
 120                        pr_warn("Test failed: Could not find key %u\n", key);
 121                        return -ENOENT;
 122                } else if (!expected && obj) {
 123                        pr_warn("Test failed: Unexpected entry found for key %u\n",
 124                                key);
 125                        return -EEXIST;
 126                } else if (expected && obj) {
 127                        if (obj->value != i) {
 128                                pr_warn("Test failed: Lookup value mismatch %u!=%u\n",
 129                                        obj->value, i);
 130                                return -EINVAL;
 131                        }
 132                }
 133
 134                cond_resched_rcu();
 135        }
 136
 137        return 0;
 138}
 139
 140static void test_bucket_stats(struct rhashtable *ht)
 141{
 142        unsigned int err, total = 0, chain_len = 0;
 143        struct rhashtable_iter hti;
 144        struct rhash_head *pos;
 145
 146        err = rhashtable_walk_init(ht, &hti, GFP_KERNEL);
 147        if (err) {
 148                pr_warn("Test failed: allocation error");
 149                return;
 150        }
 151
 152        err = rhashtable_walk_start(&hti);
 153        if (err && err != -EAGAIN) {
 154                pr_warn("Test failed: iterator failed: %d\n", err);
 155                return;
 156        }
 157
 158        while ((pos = rhashtable_walk_next(&hti))) {
 159                if (PTR_ERR(pos) == -EAGAIN) {
 160                        pr_info("Info: encountered resize\n");
 161                        chain_len++;
 162                        continue;
 163                } else if (IS_ERR(pos)) {
 164                        pr_warn("Test failed: rhashtable_walk_next() error: %ld\n",
 165                                PTR_ERR(pos));
 166                        break;
 167                }
 168
 169                total++;
 170        }
 171
 172        rhashtable_walk_stop(&hti);
 173        rhashtable_walk_exit(&hti);
 174
 175        pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n",
 176                total, atomic_read(&ht->nelems), entries, chain_len);
 177
 178        if (total != atomic_read(&ht->nelems) || total != entries)
 179                pr_warn("Test failed: Total count mismatch ^^^");
 180}
 181
 182static s64 __init test_rhashtable(struct rhashtable *ht)
 183{
 184        struct test_obj *obj;
 185        int err;
 186        unsigned int i, insert_retries = 0;
 187        s64 start, end;
 188
 189        /*
 190         * Insertion Test:
 191         * Insert entries into table with all keys even numbers
 192         */
 193        pr_info("  Adding %d keys\n", entries);
 194        start = ktime_get_ns();
 195        for (i = 0; i < entries; i++) {
 196                struct test_obj *obj = &array[i];
 197
 198                obj->value = i * 2;
 199                err = insert_retry(ht, &obj->node, test_rht_params);
 200                if (err > 0)
 201                        insert_retries += err;
 202                else if (err)
 203                        return err;
 204        }
 205
 206        if (insert_retries)
 207                pr_info("  %u insertions retried due to memory pressure\n",
 208                        insert_retries);
 209
 210        test_bucket_stats(ht);
 211        rcu_read_lock();
 212        test_rht_lookup(ht);
 213        rcu_read_unlock();
 214
 215        test_bucket_stats(ht);
 216
 217        pr_info("  Deleting %d keys\n", entries);
 218        for (i = 0; i < entries; i++) {
 219                u32 key = i * 2;
 220
 221                if (array[i].value != TEST_INSERT_FAIL) {
 222                        obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
 223                        BUG_ON(!obj);
 224
 225                        rhashtable_remove_fast(ht, &obj->node, test_rht_params);
 226                }
 227
 228                cond_resched();
 229        }
 230
 231        end = ktime_get_ns();
 232        pr_info("  Duration of test: %lld ns\n", end - start);
 233
 234        return end - start;
 235}
 236
 237static struct rhashtable ht;
 238
 239static int thread_lookup_test(struct thread_data *tdata)
 240{
 241        int i, err = 0;
 242
 243        for (i = 0; i < entries; i++) {
 244                struct test_obj *obj;
 245                int key = (tdata->id << 16) | i;
 246
 247                obj = rhashtable_lookup_fast(&ht, &key, test_rht_params);
 248                if (obj && (tdata->objs[i].value == TEST_INSERT_FAIL)) {
 249                        pr_err("  found unexpected object %d\n", key);
 250                        err++;
 251                } else if (!obj && (tdata->objs[i].value != TEST_INSERT_FAIL)) {
 252                        pr_err("  object %d not found!\n", key);
 253                        err++;
 254                } else if (obj && (obj->value != key)) {
 255                        pr_err("  wrong object returned (got %d, expected %d)\n",
 256                               obj->value, key);
 257                        err++;
 258                }
 259
 260                cond_resched();
 261        }
 262        return err;
 263}
 264
 265static int threadfunc(void *data)
 266{
 267        int i, step, err = 0, insert_retries = 0;
 268        struct thread_data *tdata = data;
 269
 270        up(&prestart_sem);
 271        if (down_interruptible(&startup_sem))
 272                pr_err("  thread[%d]: down_interruptible failed\n", tdata->id);
 273
 274        for (i = 0; i < entries; i++) {
 275                tdata->objs[i].value = (tdata->id << 16) | i;
 276                err = insert_retry(&ht, &tdata->objs[i].node, test_rht_params);
 277                if (err > 0) {
 278                        insert_retries += err;
 279                } else if (err) {
 280                        pr_err("  thread[%d]: rhashtable_insert_fast failed\n",
 281                               tdata->id);
 282                        goto out;
 283                }
 284        }
 285        if (insert_retries)
 286                pr_info("  thread[%d]: %u insertions retried due to memory pressure\n",
 287                        tdata->id, insert_retries);
 288
 289        err = thread_lookup_test(tdata);
 290        if (err) {
 291                pr_err("  thread[%d]: rhashtable_lookup_test failed\n",
 292                       tdata->id);
 293                goto out;
 294        }
 295
 296        for (step = 10; step > 0; step--) {
 297                for (i = 0; i < entries; i += step) {
 298                        if (tdata->objs[i].value == TEST_INSERT_FAIL)
 299                                continue;
 300                        err = rhashtable_remove_fast(&ht, &tdata->objs[i].node,
 301                                                     test_rht_params);
 302                        if (err) {
 303                                pr_err("  thread[%d]: rhashtable_remove_fast failed\n",
 304                                       tdata->id);
 305                                goto out;
 306                        }
 307                        tdata->objs[i].value = TEST_INSERT_FAIL;
 308
 309                        cond_resched();
 310                }
 311                err = thread_lookup_test(tdata);
 312                if (err) {
 313                        pr_err("  thread[%d]: rhashtable_lookup_test (2) failed\n",
 314                               tdata->id);
 315                        goto out;
 316                }
 317        }
 318out:
 319        while (!kthread_should_stop()) {
 320                set_current_state(TASK_INTERRUPTIBLE);
 321                schedule();
 322        }
 323        return err;
 324}
 325
 326static int __init test_rht_init(void)
 327{
 328        int i, err, started_threads = 0, failed_threads = 0;
 329        u64 total_time = 0;
 330        struct thread_data *tdata;
 331        struct test_obj *objs;
 332
 333        entries = min(entries, MAX_ENTRIES);
 334
 335        test_rht_params.automatic_shrinking = shrinking;
 336        test_rht_params.max_size = max_size ? : roundup_pow_of_two(entries);
 337        test_rht_params.nelem_hint = size;
 338
 339        pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n",
 340                size, max_size, shrinking);
 341
 342        for (i = 0; i < runs; i++) {
 343                s64 time;
 344
 345                pr_info("Test %02d:\n", i);
 346                memset(&array, 0, sizeof(array));
 347                err = rhashtable_init(&ht, &test_rht_params);
 348                if (err < 0) {
 349                        pr_warn("Test failed: Unable to initialize hashtable: %d\n",
 350                                err);
 351                        continue;
 352                }
 353
 354                time = test_rhashtable(&ht);
 355                rhashtable_destroy(&ht);
 356                if (time < 0) {
 357                        pr_warn("Test failed: return code %lld\n", time);
 358                        return -EINVAL;
 359                }
 360
 361                total_time += time;
 362        }
 363
 364        do_div(total_time, runs);
 365        pr_info("Average test time: %llu\n", total_time);
 366
 367        if (!tcount)
 368                return 0;
 369
 370        pr_info("Testing concurrent rhashtable access from %d threads\n",
 371                tcount);
 372        sema_init(&prestart_sem, 1 - tcount);
 373        tdata = vzalloc(tcount * sizeof(struct thread_data));
 374        if (!tdata)
 375                return -ENOMEM;
 376        objs  = vzalloc(tcount * entries * sizeof(struct test_obj));
 377        if (!objs) {
 378                vfree(tdata);
 379                return -ENOMEM;
 380        }
 381
 382        test_rht_params.max_size = max_size ? :
 383                                   roundup_pow_of_two(tcount * entries);
 384        err = rhashtable_init(&ht, &test_rht_params);
 385        if (err < 0) {
 386                pr_warn("Test failed: Unable to initialize hashtable: %d\n",
 387                        err);
 388                vfree(tdata);
 389                vfree(objs);
 390                return -EINVAL;
 391        }
 392        for (i = 0; i < tcount; i++) {
 393                tdata[i].id = i;
 394                tdata[i].objs = objs + i * entries;
 395                tdata[i].task = kthread_run(threadfunc, &tdata[i],
 396                                            "rhashtable_thrad[%d]", i);
 397                if (IS_ERR(tdata[i].task))
 398                        pr_err(" kthread_run failed for thread %d\n", i);
 399                else
 400                        started_threads++;
 401        }
 402        if (down_interruptible(&prestart_sem))
 403                pr_err("  down interruptible failed\n");
 404        for (i = 0; i < tcount; i++)
 405                up(&startup_sem);
 406        for (i = 0; i < tcount; i++) {
 407                if (IS_ERR(tdata[i].task))
 408                        continue;
 409                if ((err = kthread_stop(tdata[i].task))) {
 410                        pr_warn("Test failed: thread %d returned: %d\n",
 411                                i, err);
 412                        failed_threads++;
 413                }
 414        }
 415        pr_info("Started %d threads, %d failed\n",
 416                started_threads, failed_threads);
 417        rhashtable_destroy(&ht);
 418        vfree(tdata);
 419        vfree(objs);
 420        return 0;
 421}
 422
 423static void __exit test_rht_exit(void)
 424{
 425}
 426
 427module_init(test_rht_init);
 428module_exit(test_rht_exit);
 429
 430MODULE_LICENSE("GPL v2");
 431