linux/lib/test_rhashtable.c
<<
>>
Prefs
   1/*
   2 * Resizable, Scalable, Concurrent Hash Table
   3 *
   4 * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch>
   5 * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
   6 *
   7 * This program is free software; you can redistribute it and/or modify
   8 * it under the terms of the GNU General Public License version 2 as
   9 * published by the Free Software Foundation.
  10 */
  11
  12/**************************************************************************
  13 * Self Test
  14 **************************************************************************/
  15
  16#include <linux/init.h>
  17#include <linux/jhash.h>
  18#include <linux/kernel.h>
  19#include <linux/kthread.h>
  20#include <linux/module.h>
  21#include <linux/rcupdate.h>
  22#include <linux/rhashtable.h>
  23#include <linux/semaphore.h>
  24#include <linux/slab.h>
  25#include <linux/sched.h>
  26#include <linux/vmalloc.h>
  27
  28#define MAX_ENTRIES     1000000
  29#define TEST_INSERT_FAIL INT_MAX
  30
  31static int entries = 50000;
  32module_param(entries, int, 0);
  33MODULE_PARM_DESC(entries, "Number of entries to add (default: 50000)");
  34
  35static int runs = 4;
  36module_param(runs, int, 0);
  37MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)");
  38
  39static int max_size = 0;
  40module_param(max_size, int, 0);
  41MODULE_PARM_DESC(max_size, "Maximum table size (default: calculated)");
  42
  43static bool shrinking = false;
  44module_param(shrinking, bool, 0);
  45MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)");
  46
  47static int size = 8;
  48module_param(size, int, 0);
  49MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)");
  50
  51static int tcount = 10;
  52module_param(tcount, int, 0);
  53MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)");
  54
  55static bool enomem_retry = false;
  56module_param(enomem_retry, bool, 0);
  57MODULE_PARM_DESC(enomem_retry, "Retry insert even if -ENOMEM was returned (default: off)");
  58
  59struct test_obj_val {
  60        int     id;
  61        int     tid;
  62};
  63
  64struct test_obj {
  65        struct test_obj_val     value;
  66        struct rhash_head       node;
  67};
  68
  69struct thread_data {
  70        int id;
  71        struct task_struct *task;
  72        struct test_obj *objs;
  73};
  74
  75static struct test_obj array[MAX_ENTRIES];
  76
  77static struct rhashtable_params test_rht_params = {
  78        .head_offset = offsetof(struct test_obj, node),
  79        .key_offset = offsetof(struct test_obj, value),
  80        .key_len = sizeof(struct test_obj_val),
  81        .hashfn = jhash,
  82        .nulls_base = (3U << RHT_BASE_SHIFT),
  83};
  84
  85static struct semaphore prestart_sem;
  86static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0);
  87
  88static int insert_retry(struct rhashtable *ht, struct rhash_head *obj,
  89                        const struct rhashtable_params params)
  90{
  91        int err, retries = -1, enomem_retries = 0;
  92
  93        do {
  94                retries++;
  95                cond_resched();
  96                err = rhashtable_insert_fast(ht, obj, params);
  97                if (err == -ENOMEM && enomem_retry) {
  98                        enomem_retries++;
  99                        err = -EBUSY;
 100                }
 101        } while (err == -EBUSY);
 102
 103        if (enomem_retries)
 104                pr_info(" %u insertions retried after -ENOMEM\n",
 105                        enomem_retries);
 106
 107        return err ? : retries;
 108}
 109
 110static int __init test_rht_lookup(struct rhashtable *ht)
 111{
 112        unsigned int i;
 113
 114        for (i = 0; i < entries * 2; i++) {
 115                struct test_obj *obj;
 116                bool expected = !(i % 2);
 117                struct test_obj_val key = {
 118                        .id = i,
 119                };
 120
 121                if (array[i / 2].value.id == TEST_INSERT_FAIL)
 122                        expected = false;
 123
 124                obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
 125
 126                if (expected && !obj) {
 127                        pr_warn("Test failed: Could not find key %u\n", key.id);
 128                        return -ENOENT;
 129                } else if (!expected && obj) {
 130                        pr_warn("Test failed: Unexpected entry found for key %u\n",
 131                                key.id);
 132                        return -EEXIST;
 133                } else if (expected && obj) {
 134                        if (obj->value.id != i) {
 135                                pr_warn("Test failed: Lookup value mismatch %u!=%u\n",
 136                                        obj->value.id, i);
 137                                return -EINVAL;
 138                        }
 139                }
 140
 141                cond_resched_rcu();
 142        }
 143
 144        return 0;
 145}
 146
 147static void test_bucket_stats(struct rhashtable *ht)
 148{
 149        unsigned int err, total = 0, chain_len = 0;
 150        struct rhashtable_iter hti;
 151        struct rhash_head *pos;
 152
 153        err = rhashtable_walk_init(ht, &hti, GFP_KERNEL);
 154        if (err) {
 155                pr_warn("Test failed: allocation error");
 156                return;
 157        }
 158
 159        err = rhashtable_walk_start(&hti);
 160        if (err && err != -EAGAIN) {
 161                pr_warn("Test failed: iterator failed: %d\n", err);
 162                return;
 163        }
 164
 165        while ((pos = rhashtable_walk_next(&hti))) {
 166                if (PTR_ERR(pos) == -EAGAIN) {
 167                        pr_info("Info: encountered resize\n");
 168                        chain_len++;
 169                        continue;
 170                } else if (IS_ERR(pos)) {
 171                        pr_warn("Test failed: rhashtable_walk_next() error: %ld\n",
 172                                PTR_ERR(pos));
 173                        break;
 174                }
 175
 176                total++;
 177        }
 178
 179        rhashtable_walk_stop(&hti);
 180        rhashtable_walk_exit(&hti);
 181
 182        pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n",
 183                total, atomic_read(&ht->nelems), entries, chain_len);
 184
 185        if (total != atomic_read(&ht->nelems) || total != entries)
 186                pr_warn("Test failed: Total count mismatch ^^^");
 187}
 188
 189static s64 __init test_rhashtable(struct rhashtable *ht)
 190{
 191        struct test_obj *obj;
 192        int err;
 193        unsigned int i, insert_retries = 0;
 194        s64 start, end;
 195
 196        /*
 197         * Insertion Test:
 198         * Insert entries into table with all keys even numbers
 199         */
 200        pr_info("  Adding %d keys\n", entries);
 201        start = ktime_get_ns();
 202        for (i = 0; i < entries; i++) {
 203                struct test_obj *obj = &array[i];
 204
 205                obj->value.id = i * 2;
 206                err = insert_retry(ht, &obj->node, test_rht_params);
 207                if (err > 0)
 208                        insert_retries += err;
 209                else if (err)
 210                        return err;
 211        }
 212
 213        if (insert_retries)
 214                pr_info("  %u insertions retried due to memory pressure\n",
 215                        insert_retries);
 216
 217        test_bucket_stats(ht);
 218        rcu_read_lock();
 219        test_rht_lookup(ht);
 220        rcu_read_unlock();
 221
 222        test_bucket_stats(ht);
 223
 224        pr_info("  Deleting %d keys\n", entries);
 225        for (i = 0; i < entries; i++) {
 226                struct test_obj_val key = {
 227                        .id = i * 2,
 228                };
 229
 230                if (array[i].value.id != TEST_INSERT_FAIL) {
 231                        obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
 232                        BUG_ON(!obj);
 233
 234                        rhashtable_remove_fast(ht, &obj->node, test_rht_params);
 235                }
 236
 237                cond_resched();
 238        }
 239
 240        end = ktime_get_ns();
 241        pr_info("  Duration of test: %lld ns\n", end - start);
 242
 243        return end - start;
 244}
 245
 246static struct rhashtable ht;
 247
 248static int thread_lookup_test(struct thread_data *tdata)
 249{
 250        int i, err = 0;
 251
 252        for (i = 0; i < entries; i++) {
 253                struct test_obj *obj;
 254                struct test_obj_val key = {
 255                        .id = i,
 256                        .tid = tdata->id,
 257                };
 258
 259                obj = rhashtable_lookup_fast(&ht, &key, test_rht_params);
 260                if (obj && (tdata->objs[i].value.id == TEST_INSERT_FAIL)) {
 261                        pr_err("  found unexpected object %d-%d\n", key.tid, key.id);
 262                        err++;
 263                } else if (!obj && (tdata->objs[i].value.id != TEST_INSERT_FAIL)) {
 264                        pr_err("  object %d-%d not found!\n", key.tid, key.id);
 265                        err++;
 266                } else if (obj && memcmp(&obj->value, &key, sizeof(key))) {
 267                        pr_err("  wrong object returned (got %d-%d, expected %d-%d)\n",
 268                               obj->value.tid, obj->value.id, key.tid, key.id);
 269                        err++;
 270                }
 271
 272                cond_resched();
 273        }
 274        return err;
 275}
 276
 277static int threadfunc(void *data)
 278{
 279        int i, step, err = 0, insert_retries = 0;
 280        struct thread_data *tdata = data;
 281
 282        up(&prestart_sem);
 283        if (down_interruptible(&startup_sem))
 284                pr_err("  thread[%d]: down_interruptible failed\n", tdata->id);
 285
 286        for (i = 0; i < entries; i++) {
 287                tdata->objs[i].value.id = i;
 288                tdata->objs[i].value.tid = tdata->id;
 289                err = insert_retry(&ht, &tdata->objs[i].node, test_rht_params);
 290                if (err > 0) {
 291                        insert_retries += err;
 292                } else if (err) {
 293                        pr_err("  thread[%d]: rhashtable_insert_fast failed\n",
 294                               tdata->id);
 295                        goto out;
 296                }
 297        }
 298        if (insert_retries)
 299                pr_info("  thread[%d]: %u insertions retried due to memory pressure\n",
 300                        tdata->id, insert_retries);
 301
 302        err = thread_lookup_test(tdata);
 303        if (err) {
 304                pr_err("  thread[%d]: rhashtable_lookup_test failed\n",
 305                       tdata->id);
 306                goto out;
 307        }
 308
 309        for (step = 10; step > 0; step--) {
 310                for (i = 0; i < entries; i += step) {
 311                        if (tdata->objs[i].value.id == TEST_INSERT_FAIL)
 312                                continue;
 313                        err = rhashtable_remove_fast(&ht, &tdata->objs[i].node,
 314                                                     test_rht_params);
 315                        if (err) {
 316                                pr_err("  thread[%d]: rhashtable_remove_fast failed\n",
 317                                       tdata->id);
 318                                goto out;
 319                        }
 320                        tdata->objs[i].value.id = TEST_INSERT_FAIL;
 321
 322                        cond_resched();
 323                }
 324                err = thread_lookup_test(tdata);
 325                if (err) {
 326                        pr_err("  thread[%d]: rhashtable_lookup_test (2) failed\n",
 327                               tdata->id);
 328                        goto out;
 329                }
 330        }
 331out:
 332        while (!kthread_should_stop()) {
 333                set_current_state(TASK_INTERRUPTIBLE);
 334                schedule();
 335        }
 336        return err;
 337}
 338
 339static int __init test_rht_init(void)
 340{
 341        int i, err, started_threads = 0, failed_threads = 0;
 342        u64 total_time = 0;
 343        struct thread_data *tdata;
 344        struct test_obj *objs;
 345
 346        entries = min(entries, MAX_ENTRIES);
 347
 348        test_rht_params.automatic_shrinking = shrinking;
 349        test_rht_params.max_size = max_size ? : roundup_pow_of_two(entries);
 350        test_rht_params.nelem_hint = size;
 351
 352        pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n",
 353                size, max_size, shrinking);
 354
 355        for (i = 0; i < runs; i++) {
 356                s64 time;
 357
 358                pr_info("Test %02d:\n", i);
 359                memset(&array, 0, sizeof(array));
 360                err = rhashtable_init(&ht, &test_rht_params);
 361                if (err < 0) {
 362                        pr_warn("Test failed: Unable to initialize hashtable: %d\n",
 363                                err);
 364                        continue;
 365                }
 366
 367                time = test_rhashtable(&ht);
 368                rhashtable_destroy(&ht);
 369                if (time < 0) {
 370                        pr_warn("Test failed: return code %lld\n", time);
 371                        return -EINVAL;
 372                }
 373
 374                total_time += time;
 375        }
 376
 377        do_div(total_time, runs);
 378        pr_info("Average test time: %llu\n", total_time);
 379
 380        if (!tcount)
 381                return 0;
 382
 383        pr_info("Testing concurrent rhashtable access from %d threads\n",
 384                tcount);
 385        sema_init(&prestart_sem, 1 - tcount);
 386        tdata = vzalloc(tcount * sizeof(struct thread_data));
 387        if (!tdata)
 388                return -ENOMEM;
 389        objs  = vzalloc(tcount * entries * sizeof(struct test_obj));
 390        if (!objs) {
 391                vfree(tdata);
 392                return -ENOMEM;
 393        }
 394
 395        test_rht_params.max_size = max_size ? :
 396                                   roundup_pow_of_two(tcount * entries);
 397        err = rhashtable_init(&ht, &test_rht_params);
 398        if (err < 0) {
 399                pr_warn("Test failed: Unable to initialize hashtable: %d\n",
 400                        err);
 401                vfree(tdata);
 402                vfree(objs);
 403                return -EINVAL;
 404        }
 405        for (i = 0; i < tcount; i++) {
 406                tdata[i].id = i;
 407                tdata[i].objs = objs + i * entries;
 408                tdata[i].task = kthread_run(threadfunc, &tdata[i],
 409                                            "rhashtable_thrad[%d]", i);
 410                if (IS_ERR(tdata[i].task))
 411                        pr_err(" kthread_run failed for thread %d\n", i);
 412                else
 413                        started_threads++;
 414        }
 415        if (down_interruptible(&prestart_sem))
 416                pr_err("  down interruptible failed\n");
 417        for (i = 0; i < tcount; i++)
 418                up(&startup_sem);
 419        for (i = 0; i < tcount; i++) {
 420                if (IS_ERR(tdata[i].task))
 421                        continue;
 422                if ((err = kthread_stop(tdata[i].task))) {
 423                        pr_warn("Test failed: thread %d returned: %d\n",
 424                                i, err);
 425                        failed_threads++;
 426                }
 427        }
 428        pr_info("Started %d threads, %d failed\n",
 429                started_threads, failed_threads);
 430        rhashtable_destroy(&ht);
 431        vfree(tdata);
 432        vfree(objs);
 433        return 0;
 434}
 435
 436static void __exit test_rht_exit(void)
 437{
 438}
 439
 440module_init(test_rht_init);
 441module_exit(test_rht_exit);
 442
 443MODULE_LICENSE("GPL v2");
 444