linux/tools/testing/selftests/bpf/test_hashmap.c
<<
>>
Prefs
   1// SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause)
   2
   3/*
   4 * Tests for libbpf's hashmap.
   5 *
   6 * Copyright (c) 2019 Facebook
   7 */
   8#include <stdio.h>
   9#include <errno.h>
  10#include <linux/err.h>
  11#include "hashmap.h"
  12
  13#define CHECK(condition, format...) ({                                  \
  14        int __ret = !!(condition);                                      \
  15        if (__ret) {                                                    \
  16                fprintf(stderr, "%s:%d:FAIL ", __func__, __LINE__);     \
  17                fprintf(stderr, format);                                \
  18        }                                                               \
  19        __ret;                                                          \
  20})
  21
  22size_t hash_fn(const void *k, void *ctx)
  23{
  24        return (long)k;
  25}
  26
  27bool equal_fn(const void *a, const void *b, void *ctx)
  28{
  29        return (long)a == (long)b;
  30}
  31
  32static inline size_t next_pow_2(size_t n)
  33{
  34        size_t r = 1;
  35
  36        while (r < n)
  37                r <<= 1;
  38        return r;
  39}
  40
  41static inline size_t exp_cap(size_t sz)
  42{
  43        size_t r = next_pow_2(sz);
  44
  45        if (sz * 4 / 3 > r)
  46                r <<= 1;
  47        return r;
  48}
  49
  50#define ELEM_CNT 62
  51
  52int test_hashmap_generic(void)
  53{
  54        struct hashmap_entry *entry, *tmp;
  55        int err, bkt, found_cnt, i;
  56        long long found_msk;
  57        struct hashmap *map;
  58
  59        fprintf(stderr, "%s: ", __func__);
  60
  61        map = hashmap__new(hash_fn, equal_fn, NULL);
  62        if (CHECK(IS_ERR(map), "failed to create map: %ld\n", PTR_ERR(map)))
  63                return 1;
  64
  65        for (i = 0; i < ELEM_CNT; i++) {
  66                const void *oldk, *k = (const void *)(long)i;
  67                void *oldv, *v = (void *)(long)(1024 + i);
  68
  69                err = hashmap__update(map, k, v, &oldk, &oldv);
  70                if (CHECK(err != -ENOENT, "unexpected result: %d\n", err))
  71                        return 1;
  72
  73                if (i % 2) {
  74                        err = hashmap__add(map, k, v);
  75                } else {
  76                        err = hashmap__set(map, k, v, &oldk, &oldv);
  77                        if (CHECK(oldk != NULL || oldv != NULL,
  78                                  "unexpected k/v: %p=%p\n", oldk, oldv))
  79                                return 1;
  80                }
  81
  82                if (CHECK(err, "failed to add k/v %ld = %ld: %d\n",
  83                               (long)k, (long)v, err))
  84                        return 1;
  85
  86                if (CHECK(!hashmap__find(map, k, &oldv),
  87                          "failed to find key %ld\n", (long)k))
  88                        return 1;
  89                if (CHECK(oldv != v, "found value is wrong: %ld\n", (long)oldv))
  90                        return 1;
  91        }
  92
  93        if (CHECK(hashmap__size(map) != ELEM_CNT,
  94                  "invalid map size: %zu\n", hashmap__size(map)))
  95                return 1;
  96        if (CHECK(hashmap__capacity(map) != exp_cap(hashmap__size(map)),
  97                  "unexpected map capacity: %zu\n", hashmap__capacity(map)))
  98                return 1;
  99
 100        found_msk = 0;
 101        hashmap__for_each_entry(map, entry, bkt) {
 102                long k = (long)entry->key;
 103                long v = (long)entry->value;
 104
 105                found_msk |= 1ULL << k;
 106                if (CHECK(v - k != 1024, "invalid k/v pair: %ld = %ld\n", k, v))
 107                        return 1;
 108        }
 109        if (CHECK(found_msk != (1ULL << ELEM_CNT) - 1,
 110                  "not all keys iterated: %llx\n", found_msk))
 111                return 1;
 112
 113        for (i = 0; i < ELEM_CNT; i++) {
 114                const void *oldk, *k = (const void *)(long)i;
 115                void *oldv, *v = (void *)(long)(256 + i);
 116
 117                err = hashmap__add(map, k, v);
 118                if (CHECK(err != -EEXIST, "unexpected add result: %d\n", err))
 119                        return 1;
 120
 121                if (i % 2)
 122                        err = hashmap__update(map, k, v, &oldk, &oldv);
 123                else
 124                        err = hashmap__set(map, k, v, &oldk, &oldv);
 125
 126                if (CHECK(err, "failed to update k/v %ld = %ld: %d\n",
 127                               (long)k, (long)v, err))
 128                        return 1;
 129                if (CHECK(!hashmap__find(map, k, &oldv),
 130                          "failed to find key %ld\n", (long)k))
 131                        return 1;
 132                if (CHECK(oldv != v, "found value is wrong: %ld\n", (long)oldv))
 133                        return 1;
 134        }
 135
 136        if (CHECK(hashmap__size(map) != ELEM_CNT,
 137                  "invalid updated map size: %zu\n", hashmap__size(map)))
 138                return 1;
 139        if (CHECK(hashmap__capacity(map) != exp_cap(hashmap__size(map)),
 140                  "unexpected map capacity: %zu\n", hashmap__capacity(map)))
 141                return 1;
 142
 143        found_msk = 0;
 144        hashmap__for_each_entry_safe(map, entry, tmp, bkt) {
 145                long k = (long)entry->key;
 146                long v = (long)entry->value;
 147
 148                found_msk |= 1ULL << k;
 149                if (CHECK(v - k != 256,
 150                          "invalid updated k/v pair: %ld = %ld\n", k, v))
 151                        return 1;
 152        }
 153        if (CHECK(found_msk != (1ULL << ELEM_CNT) - 1,
 154                  "not all keys iterated after update: %llx\n", found_msk))
 155                return 1;
 156
 157        found_cnt = 0;
 158        hashmap__for_each_key_entry(map, entry, (void *)0) {
 159                found_cnt++;
 160        }
 161        if (CHECK(!found_cnt, "didn't find any entries for key 0\n"))
 162                return 1;
 163
 164        found_msk = 0;
 165        found_cnt = 0;
 166        hashmap__for_each_key_entry_safe(map, entry, tmp, (void *)0) {
 167                const void *oldk, *k;
 168                void *oldv, *v;
 169
 170                k = entry->key;
 171                v = entry->value;
 172
 173                found_cnt++;
 174                found_msk |= 1ULL << (long)k;
 175
 176                if (CHECK(!hashmap__delete(map, k, &oldk, &oldv),
 177                          "failed to delete k/v %ld = %ld\n",
 178                          (long)k, (long)v))
 179                        return 1;
 180                if (CHECK(oldk != k || oldv != v,
 181                          "invalid deleted k/v: expected %ld = %ld, got %ld = %ld\n",
 182                          (long)k, (long)v, (long)oldk, (long)oldv))
 183                        return 1;
 184                if (CHECK(hashmap__delete(map, k, &oldk, &oldv),
 185                          "unexpectedly deleted k/v %ld = %ld\n",
 186                          (long)oldk, (long)oldv))
 187                        return 1;
 188        }
 189
 190        if (CHECK(!found_cnt || !found_msk,
 191                  "didn't delete any key entries\n"))
 192                return 1;
 193        if (CHECK(hashmap__size(map) != ELEM_CNT - found_cnt,
 194                  "invalid updated map size (already deleted: %d): %zu\n",
 195                  found_cnt, hashmap__size(map)))
 196                return 1;
 197        if (CHECK(hashmap__capacity(map) != exp_cap(hashmap__size(map)),
 198                  "unexpected map capacity: %zu\n", hashmap__capacity(map)))
 199                return 1;
 200
 201        hashmap__for_each_entry_safe(map, entry, tmp, bkt) {
 202                const void *oldk, *k;
 203                void *oldv, *v;
 204
 205                k = entry->key;
 206                v = entry->value;
 207
 208                found_cnt++;
 209                found_msk |= 1ULL << (long)k;
 210
 211                if (CHECK(!hashmap__delete(map, k, &oldk, &oldv),
 212                          "failed to delete k/v %ld = %ld\n",
 213                          (long)k, (long)v))
 214                        return 1;
 215                if (CHECK(oldk != k || oldv != v,
 216                          "invalid old k/v: expect %ld = %ld, got %ld = %ld\n",
 217                          (long)k, (long)v, (long)oldk, (long)oldv))
 218                        return 1;
 219                if (CHECK(hashmap__delete(map, k, &oldk, &oldv),
 220                          "unexpectedly deleted k/v %ld = %ld\n",
 221                          (long)k, (long)v))
 222                        return 1;
 223        }
 224
 225        if (CHECK(found_cnt != ELEM_CNT || found_msk != (1ULL << ELEM_CNT) - 1,
 226                  "not all keys were deleted: found_cnt:%d, found_msk:%llx\n",
 227                  found_cnt, found_msk))
 228                return 1;
 229        if (CHECK(hashmap__size(map) != 0,
 230                  "invalid updated map size (already deleted: %d): %zu\n",
 231                  found_cnt, hashmap__size(map)))
 232                return 1;
 233
 234        found_cnt = 0;
 235        hashmap__for_each_entry(map, entry, bkt) {
 236                CHECK(false, "unexpected map entries left: %ld = %ld\n",
 237                             (long)entry->key, (long)entry->value);
 238                return 1;
 239        }
 240
 241        hashmap__free(map);
 242        hashmap__for_each_entry(map, entry, bkt) {
 243                CHECK(false, "unexpected map entries left: %ld = %ld\n",
 244                             (long)entry->key, (long)entry->value);
 245                return 1;
 246        }
 247
 248        fprintf(stderr, "OK\n");
 249        return 0;
 250}
 251
 252size_t collision_hash_fn(const void *k, void *ctx)
 253{
 254        return 0;
 255}
 256
 257int test_hashmap_multimap(void)
 258{
 259        void *k1 = (void *)0, *k2 = (void *)1;
 260        struct hashmap_entry *entry;
 261        struct hashmap *map;
 262        long found_msk;
 263        int err, bkt;
 264
 265        fprintf(stderr, "%s: ", __func__);
 266
 267        /* force collisions */
 268        map = hashmap__new(collision_hash_fn, equal_fn, NULL);
 269        if (CHECK(IS_ERR(map), "failed to create map: %ld\n", PTR_ERR(map)))
 270                return 1;
 271
 272
 273        /* set up multimap:
 274         * [0] -> 1, 2, 4;
 275         * [1] -> 8, 16, 32;
 276         */
 277        err = hashmap__append(map, k1, (void *)1);
 278        if (CHECK(err, "failed to add k/v: %d\n", err))
 279                return 1;
 280        err = hashmap__append(map, k1, (void *)2);
 281        if (CHECK(err, "failed to add k/v: %d\n", err))
 282                return 1;
 283        err = hashmap__append(map, k1, (void *)4);
 284        if (CHECK(err, "failed to add k/v: %d\n", err))
 285                return 1;
 286
 287        err = hashmap__append(map, k2, (void *)8);
 288        if (CHECK(err, "failed to add k/v: %d\n", err))
 289                return 1;
 290        err = hashmap__append(map, k2, (void *)16);
 291        if (CHECK(err, "failed to add k/v: %d\n", err))
 292                return 1;
 293        err = hashmap__append(map, k2, (void *)32);
 294        if (CHECK(err, "failed to add k/v: %d\n", err))
 295                return 1;
 296
 297        if (CHECK(hashmap__size(map) != 6,
 298                  "invalid map size: %zu\n", hashmap__size(map)))
 299                return 1;
 300
 301        /* verify global iteration still works and sees all values */
 302        found_msk = 0;
 303        hashmap__for_each_entry(map, entry, bkt) {
 304                found_msk |= (long)entry->value;
 305        }
 306        if (CHECK(found_msk != (1 << 6) - 1,
 307                  "not all keys iterated: %lx\n", found_msk))
 308                return 1;
 309
 310        /* iterate values for key 1 */
 311        found_msk = 0;
 312        hashmap__for_each_key_entry(map, entry, k1) {
 313                found_msk |= (long)entry->value;
 314        }
 315        if (CHECK(found_msk != (1 | 2 | 4),
 316                  "invalid k1 values: %lx\n", found_msk))
 317                return 1;
 318
 319        /* iterate values for key 2 */
 320        found_msk = 0;
 321        hashmap__for_each_key_entry(map, entry, k2) {
 322                found_msk |= (long)entry->value;
 323        }
 324        if (CHECK(found_msk != (8 | 16 | 32),
 325                  "invalid k2 values: %lx\n", found_msk))
 326                return 1;
 327
 328        fprintf(stderr, "OK\n");
 329        return 0;
 330}
 331
 332int test_hashmap_empty()
 333{
 334        struct hashmap_entry *entry;
 335        int bkt;
 336        struct hashmap *map;
 337        void *k = (void *)0;
 338
 339        fprintf(stderr, "%s: ", __func__);
 340
 341        /* force collisions */
 342        map = hashmap__new(hash_fn, equal_fn, NULL);
 343        if (CHECK(IS_ERR(map), "failed to create map: %ld\n", PTR_ERR(map)))
 344                return 1;
 345
 346        if (CHECK(hashmap__size(map) != 0,
 347                  "invalid map size: %zu\n", hashmap__size(map)))
 348                return 1;
 349        if (CHECK(hashmap__capacity(map) != 0,
 350                  "invalid map capacity: %zu\n", hashmap__capacity(map)))
 351                return 1;
 352        if (CHECK(hashmap__find(map, k, NULL), "unexpected find\n"))
 353                return 1;
 354        if (CHECK(hashmap__delete(map, k, NULL, NULL), "unexpected delete\n"))
 355                return 1;
 356
 357        hashmap__for_each_entry(map, entry, bkt) {
 358                CHECK(false, "unexpected iterated entry\n");
 359                return 1;
 360        }
 361        hashmap__for_each_key_entry(map, entry, k) {
 362                CHECK(false, "unexpected key entry\n");
 363                return 1;
 364        }
 365
 366        fprintf(stderr, "OK\n");
 367        return 0;
 368}
 369
 370int main(int argc, char **argv)
 371{
 372        bool failed = false;
 373
 374        if (test_hashmap_generic())
 375                failed = true;
 376        if (test_hashmap_multimap())
 377                failed = true;
 378        if (test_hashmap_empty())
 379                failed = true;
 380
 381        return failed;
 382}
 383