linux/lib/test_objagg.c
<<
>>
Prefs
   1// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
   2/* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
   3
   4#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
   5
   6#include <linux/kernel.h>
   7#include <linux/module.h>
   8#include <linux/slab.h>
   9#include <linux/random.h>
  10#include <linux/objagg.h>
  11
  12struct tokey {
  13        unsigned int id;
  14};
  15
  16#define NUM_KEYS 32
  17
  18static int key_id_index(unsigned int key_id)
  19{
  20        if (key_id >= NUM_KEYS) {
  21                WARN_ON(1);
  22                return 0;
  23        }
  24        return key_id;
  25}
  26
  27#define BUF_LEN 128
  28
  29struct world {
  30        unsigned int root_count;
  31        unsigned int delta_count;
  32        char next_root_buf[BUF_LEN];
  33        struct objagg_obj *objagg_objs[NUM_KEYS];
  34        unsigned int key_refs[NUM_KEYS];
  35};
  36
  37struct root {
  38        struct tokey key;
  39        char buf[BUF_LEN];
  40};
  41
  42struct delta {
  43        unsigned int key_id_diff;
  44};
  45
  46static struct objagg_obj *world_obj_get(struct world *world,
  47                                        struct objagg *objagg,
  48                                        unsigned int key_id)
  49{
  50        struct objagg_obj *objagg_obj;
  51        struct tokey key;
  52        int err;
  53
  54        key.id = key_id;
  55        objagg_obj = objagg_obj_get(objagg, &key);
  56        if (IS_ERR(objagg_obj)) {
  57                pr_err("Key %u: Failed to get object.\n", key_id);
  58                return objagg_obj;
  59        }
  60        if (!world->key_refs[key_id_index(key_id)]) {
  61                world->objagg_objs[key_id_index(key_id)] = objagg_obj;
  62        } else if (world->objagg_objs[key_id_index(key_id)] != objagg_obj) {
  63                pr_err("Key %u: God another object for the same key.\n",
  64                       key_id);
  65                err = -EINVAL;
  66                goto err_key_id_check;
  67        }
  68        world->key_refs[key_id_index(key_id)]++;
  69        return objagg_obj;
  70
  71err_key_id_check:
  72        objagg_obj_put(objagg, objagg_obj);
  73        return ERR_PTR(err);
  74}
  75
  76static void world_obj_put(struct world *world, struct objagg *objagg,
  77                          unsigned int key_id)
  78{
  79        struct objagg_obj *objagg_obj;
  80
  81        if (!world->key_refs[key_id_index(key_id)])
  82                return;
  83        objagg_obj = world->objagg_objs[key_id_index(key_id)];
  84        objagg_obj_put(objagg, objagg_obj);
  85        world->key_refs[key_id_index(key_id)]--;
  86}
  87
  88#define MAX_KEY_ID_DIFF 5
  89
  90static bool delta_check(void *priv, const void *parent_obj, const void *obj)
  91{
  92        const struct tokey *parent_key = parent_obj;
  93        const struct tokey *key = obj;
  94        int diff = key->id - parent_key->id;
  95
  96        return diff >= 0 && diff <= MAX_KEY_ID_DIFF;
  97}
  98
  99static void *delta_create(void *priv, void *parent_obj, void *obj)
 100{
 101        struct tokey *parent_key = parent_obj;
 102        struct world *world = priv;
 103        struct tokey *key = obj;
 104        int diff = key->id - parent_key->id;
 105        struct delta *delta;
 106
 107        if (!delta_check(priv, parent_obj, obj))
 108                return ERR_PTR(-EINVAL);
 109
 110        delta = kzalloc(sizeof(*delta), GFP_KERNEL);
 111        if (!delta)
 112                return ERR_PTR(-ENOMEM);
 113        delta->key_id_diff = diff;
 114        world->delta_count++;
 115        return delta;
 116}
 117
 118static void delta_destroy(void *priv, void *delta_priv)
 119{
 120        struct delta *delta = delta_priv;
 121        struct world *world = priv;
 122
 123        world->delta_count--;
 124        kfree(delta);
 125}
 126
 127static void *root_create(void *priv, void *obj, unsigned int id)
 128{
 129        struct world *world = priv;
 130        struct tokey *key = obj;
 131        struct root *root;
 132
 133        root = kzalloc(sizeof(*root), GFP_KERNEL);
 134        if (!root)
 135                return ERR_PTR(-ENOMEM);
 136        memcpy(&root->key, key, sizeof(root->key));
 137        memcpy(root->buf, world->next_root_buf, sizeof(root->buf));
 138        world->root_count++;
 139        return root;
 140}
 141
 142static void root_destroy(void *priv, void *root_priv)
 143{
 144        struct root *root = root_priv;
 145        struct world *world = priv;
 146
 147        world->root_count--;
 148        kfree(root);
 149}
 150
 151static int test_nodelta_obj_get(struct world *world, struct objagg *objagg,
 152                                unsigned int key_id, bool should_create_root)
 153{
 154        unsigned int orig_root_count = world->root_count;
 155        struct objagg_obj *objagg_obj;
 156        const struct root *root;
 157        int err;
 158
 159        if (should_create_root)
 160                prandom_bytes(world->next_root_buf,
 161                              sizeof(world->next_root_buf));
 162
 163        objagg_obj = world_obj_get(world, objagg, key_id);
 164        if (IS_ERR(objagg_obj)) {
 165                pr_err("Key %u: Failed to get object.\n", key_id);
 166                return PTR_ERR(objagg_obj);
 167        }
 168        if (should_create_root) {
 169                if (world->root_count != orig_root_count + 1) {
 170                        pr_err("Key %u: Root was not created\n", key_id);
 171                        err = -EINVAL;
 172                        goto err_check_root_count;
 173                }
 174        } else {
 175                if (world->root_count != orig_root_count) {
 176                        pr_err("Key %u: Root was incorrectly created\n",
 177                               key_id);
 178                        err = -EINVAL;
 179                        goto err_check_root_count;
 180                }
 181        }
 182        root = objagg_obj_root_priv(objagg_obj);
 183        if (root->key.id != key_id) {
 184                pr_err("Key %u: Root has unexpected key id\n", key_id);
 185                err = -EINVAL;
 186                goto err_check_key_id;
 187        }
 188        if (should_create_root &&
 189            memcmp(world->next_root_buf, root->buf, sizeof(root->buf))) {
 190                pr_err("Key %u: Buffer does not match the expected content\n",
 191                       key_id);
 192                err = -EINVAL;
 193                goto err_check_buf;
 194        }
 195        return 0;
 196
 197err_check_buf:
 198err_check_key_id:
 199err_check_root_count:
 200        objagg_obj_put(objagg, objagg_obj);
 201        return err;
 202}
 203
 204static int test_nodelta_obj_put(struct world *world, struct objagg *objagg,
 205                                unsigned int key_id, bool should_destroy_root)
 206{
 207        unsigned int orig_root_count = world->root_count;
 208
 209        world_obj_put(world, objagg, key_id);
 210
 211        if (should_destroy_root) {
 212                if (world->root_count != orig_root_count - 1) {
 213                        pr_err("Key %u: Root was not destroyed\n", key_id);
 214                        return -EINVAL;
 215                }
 216        } else {
 217                if (world->root_count != orig_root_count) {
 218                        pr_err("Key %u: Root was incorrectly destroyed\n",
 219                               key_id);
 220                        return -EINVAL;
 221                }
 222        }
 223        return 0;
 224}
 225
 226static int check_stats_zero(struct objagg *objagg)
 227{
 228        const struct objagg_stats *stats;
 229        int err = 0;
 230
 231        stats = objagg_stats_get(objagg);
 232        if (IS_ERR(stats))
 233                return PTR_ERR(stats);
 234
 235        if (stats->stats_info_count != 0) {
 236                pr_err("Stats: Object count is not zero while it should be\n");
 237                err = -EINVAL;
 238        }
 239
 240        objagg_stats_put(stats);
 241        return err;
 242}
 243
 244static int check_stats_nodelta(struct objagg *objagg)
 245{
 246        const struct objagg_stats *stats;
 247        int i;
 248        int err;
 249
 250        stats = objagg_stats_get(objagg);
 251        if (IS_ERR(stats))
 252                return PTR_ERR(stats);
 253
 254        if (stats->stats_info_count != NUM_KEYS) {
 255                pr_err("Stats: Unexpected object count (%u expected, %u returned)\n",
 256                       NUM_KEYS, stats->stats_info_count);
 257                err = -EINVAL;
 258                goto stats_put;
 259        }
 260
 261        for (i = 0; i < stats->stats_info_count; i++) {
 262                if (stats->stats_info[i].stats.user_count != 2) {
 263                        pr_err("Stats: incorrect user count\n");
 264                        err = -EINVAL;
 265                        goto stats_put;
 266                }
 267                if (stats->stats_info[i].stats.delta_user_count != 2) {
 268                        pr_err("Stats: incorrect delta user count\n");
 269                        err = -EINVAL;
 270                        goto stats_put;
 271                }
 272        }
 273        err = 0;
 274
 275stats_put:
 276        objagg_stats_put(stats);
 277        return err;
 278}
 279
 280static bool delta_check_dummy(void *priv, const void *parent_obj,
 281                              const void *obj)
 282{
 283        return false;
 284}
 285
 286static void *delta_create_dummy(void *priv, void *parent_obj, void *obj)
 287{
 288        return ERR_PTR(-EOPNOTSUPP);
 289}
 290
 291static void delta_destroy_dummy(void *priv, void *delta_priv)
 292{
 293}
 294
 295static const struct objagg_ops nodelta_ops = {
 296        .obj_size = sizeof(struct tokey),
 297        .delta_check = delta_check_dummy,
 298        .delta_create = delta_create_dummy,
 299        .delta_destroy = delta_destroy_dummy,
 300        .root_create = root_create,
 301        .root_destroy = root_destroy,
 302};
 303
/*
 * No-delta test: with nodelta_ops every distinct key becomes its own root.
 *
 * Gets every key twice (only the first get may create a root), checks the
 * stats, then puts every key twice (only the second put may destroy the
 * root), and finally checks that objagg reports no objects.
 *
 * Returns 0 on success or a negative errno.
 */
static int test_nodelta(void)
{
	struct world world = {};
	struct objagg *objagg;
	int i;
	int err;

	objagg = objagg_create(&nodelta_ops, NULL, &world);
	if (IS_ERR(objagg))
		return PTR_ERR(objagg);

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_first_zero;

	/* First round of gets, the root objects should be created */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, true);
		if (err)
			goto err_obj_first_get;
	}

	/* Do the second round of gets, all roots are already created,
	 * make sure that no new root is created
	 */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, false);
		if (err)
			goto err_obj_second_get;
	}

	err = check_stats_nodelta(objagg);
	if (err)
		goto err_stats_nodelta;

	/* First round of puts drops the second references; roots must stay */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, false);
		if (err)
			goto err_obj_first_put;
	}
	/* Second round of puts drops the last references; roots must go */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, true);
		if (err)
			goto err_obj_second_put;
	}

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_second_zero;

	objagg_destroy(objagg);
	return 0;

	/*
	 * Unwind. The first label group is reached while keys 0..i-1 still
	 * hold their second reference (i == NUM_KEYS after the second get
	 * loop completed). Those are dropped, then i is reset to NUM_KEYS and
	 * control falls through to drop the first reference of every key.
	 * The second label group is reached with keys 0..i-1 holding one
	 * reference each. world_obj_put() ignores keys with no references.
	 */
err_stats_nodelta:
err_obj_first_put:
err_obj_second_get:
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);

	i = NUM_KEYS;
err_obj_first_get:
err_obj_second_put:
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);
err_stats_first_zero:
err_stats_second_zero:
	objagg_destroy(objagg);
	return err;
}
 373
 374static const struct objagg_ops delta_ops = {
 375        .obj_size = sizeof(struct tokey),
 376        .delta_check = delta_check,
 377        .delta_create = delta_create,
 378        .delta_destroy = delta_destroy,
 379        .root_create = root_create,
 380        .root_destroy = root_destroy,
 381};
 382
 383enum action {
 384        ACTION_GET,
 385        ACTION_PUT,
 386};
 387
 388enum expect_delta {
 389        EXPECT_DELTA_SAME,
 390        EXPECT_DELTA_INC,
 391        EXPECT_DELTA_DEC,
 392};
 393
 394enum expect_root {
 395        EXPECT_ROOT_SAME,
 396        EXPECT_ROOT_INC,
 397        EXPECT_ROOT_DEC,
 398};
 399
/*
 * Expected stats for one object, compared against a struct
 * objagg_obj_stats_info by check_expect_stats_nums() (counts and root flag)
 * and check_expect_stats_key_id() (key id).
 */
struct expect_stats_info {
	struct objagg_obj_stats stats;	/* expected user/delta user counts */
	bool is_root;			/* expected root (true) or delta (false) */
	unsigned int key_id;		/* expected key id, see obj_to_key_id() */
};
 405
 406struct expect_stats {
 407        unsigned int info_count;
 408        struct expect_stats_info info[NUM_KEYS];
 409};
 410
/*
 * One scripted step of the delta test. action_items[] initializes these
 * positionally, so the member order must not change.
 */
struct action_item {
	unsigned int key_id;			/* key to get or put */
	enum action action;			/* get or put */
	enum expect_delta expect_delta;		/* expected delta count change */
	enum expect_root expect_root;		/* expected root count change */
	struct expect_stats expect_stats;	/* expected stats afterwards */
};
 418
 419#define EXPECT_STATS(count, ...)                \
 420{                                               \
 421        .info_count = count,                    \
 422        .info = { __VA_ARGS__ }                 \
 423}
 424
 425#define ROOT(key_id, user_count, delta_user_count)      \
 426        {{user_count, delta_user_count}, true, key_id}
 427
 428#define DELTA(key_id, user_count)                       \
 429        {{user_count, user_count}, false, key_id}
 430
/*
 * Scripted sequence run by test_delta(). Each entry gets or puts one key,
 * states how the delta and root counts must change, and lists the stats
 * expected afterwards. The trailing comment shows the references held
 * after the step: "r:" lists root users by key, and "d:" lists delta users
 * as key^parent_key. delta_check() lets a key be a delta of a parent whose
 * id is at most MAX_KEY_ID_DIFF smaller.
 */
static const struct action_item action_items[] = {
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(1, ROOT(1, 1, 1)),
	},	/* r: 1			d: */
	{
		7, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(2, ROOT(1, 1, 1), ROOT(7, 1, 1)),
	},	/* r: 1, 7		d: */
	{
		3, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(1, 1, 2), ROOT(7, 1, 1),
				DELTA(3, 1)),
	},	/* r: 1, 7		d: 3^1 */
	{
		5, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 3), ROOT(7, 1, 1),
				DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 5^1 */
	{
		3, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 4), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 3^1, 5^1 */
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 2, 5), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7		d: 3^1, 3^1, 5^1 */
	{
		30, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(5, ROOT(1, 2, 5), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 2), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1), DELTA(8, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 4), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(1, 2, 3), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(1, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 7, 30		d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(1, 0, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 7, 30		d: 5^1, 8^7, 8^7 */
	{
		5, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30		d: 8^7, 8^7 */
	{
		5, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(4, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(5, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7 */
	{
		6, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 4), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 2), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 1), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(5, 1, 2), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(5, 1, 3), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 7, 30, 5		d: 6^5, 8^5 */
	{
		7, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(4, ROOT(5, 1, 3), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 30, 5		d: 6^5, 8^5 */
	{
		30, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(5, 1, 3),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 5			d: 6^5, 8^5 */
	{
		5, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(5, 0, 2),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r:			d: 6^5, 8^5 */
	{
		6, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(2, ROOT(5, 0, 1),
				DELTA(8, 1)),
	},	/* r:			d: 8^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(0, ),
	},	/* r:			d: */
};
 560
 561static int check_expect(struct world *world,
 562                        const struct action_item *action_item,
 563                        unsigned int orig_delta_count,
 564                        unsigned int orig_root_count)
 565{
 566        unsigned int key_id = action_item->key_id;
 567
 568        switch (action_item->expect_delta) {
 569        case EXPECT_DELTA_SAME:
 570                if (orig_delta_count != world->delta_count) {
 571                        pr_err("Key %u: Delta count changed while expected to remain the same.\n",
 572                               key_id);
 573                        return -EINVAL;
 574                }
 575                break;
 576        case EXPECT_DELTA_INC:
 577                if (WARN_ON(action_item->action == ACTION_PUT))
 578                        return -EINVAL;
 579                if (orig_delta_count + 1 != world->delta_count) {
 580                        pr_err("Key %u: Delta count was not incremented.\n",
 581                               key_id);
 582                        return -EINVAL;
 583                }
 584                break;
 585        case EXPECT_DELTA_DEC:
 586                if (WARN_ON(action_item->action == ACTION_GET))
 587                        return -EINVAL;
 588                if (orig_delta_count - 1 != world->delta_count) {
 589                        pr_err("Key %u: Delta count was not decremented.\n",
 590                               key_id);
 591                        return -EINVAL;
 592                }
 593                break;
 594        }
 595
 596        switch (action_item->expect_root) {
 597        case EXPECT_ROOT_SAME:
 598                if (orig_root_count != world->root_count) {
 599                        pr_err("Key %u: Root count changed while expected to remain the same.\n",
 600                               key_id);
 601                        return -EINVAL;
 602                }
 603                break;
 604        case EXPECT_ROOT_INC:
 605                if (WARN_ON(action_item->action == ACTION_PUT))
 606                        return -EINVAL;
 607                if (orig_root_count + 1 != world->root_count) {
 608                        pr_err("Key %u: Root count was not incremented.\n",
 609                               key_id);
 610                        return -EINVAL;
 611                }
 612                break;
 613        case EXPECT_ROOT_DEC:
 614                if (WARN_ON(action_item->action == ACTION_GET))
 615                        return -EINVAL;
 616                if (orig_root_count - 1 != world->root_count) {
 617                        pr_err("Key %u: Root count was not decremented.\n",
 618                               key_id);
 619                        return -EINVAL;
 620                }
 621        }
 622
 623        return 0;
 624}
 625
 626static unsigned int obj_to_key_id(struct objagg_obj *objagg_obj)
 627{
 628        const struct tokey *root_key;
 629        const struct delta *delta;
 630        unsigned int key_id;
 631
 632        root_key = objagg_obj_root_priv(objagg_obj);
 633        key_id = root_key->id;
 634        delta = objagg_obj_delta_priv(objagg_obj);
 635        if (delta)
 636                key_id += delta->key_id_diff;
 637        return key_id;
 638}
 639
 640static int
 641check_expect_stats_nums(const struct objagg_obj_stats_info *stats_info,
 642                        const struct expect_stats_info *expect_stats_info,
 643                        const char **errmsg)
 644{
 645        if (stats_info->is_root != expect_stats_info->is_root) {
 646                if (errmsg)
 647                        *errmsg = "Incorrect root/delta indication";
 648                return -EINVAL;
 649        }
 650        if (stats_info->stats.user_count !=
 651            expect_stats_info->stats.user_count) {
 652                if (errmsg)
 653                        *errmsg = "Incorrect user count";
 654                return -EINVAL;
 655        }
 656        if (stats_info->stats.delta_user_count !=
 657            expect_stats_info->stats.delta_user_count) {
 658                if (errmsg)
 659                        *errmsg = "Incorrect delta user count";
 660                return -EINVAL;
 661        }
 662        return 0;
 663}
 664
 665static int
 666check_expect_stats_key_id(const struct objagg_obj_stats_info *stats_info,
 667                          const struct expect_stats_info *expect_stats_info,
 668                          const char **errmsg)
 669{
 670        if (obj_to_key_id(stats_info->objagg_obj) !=
 671            expect_stats_info->key_id) {
 672                if (errmsg)
 673                        *errmsg = "incorrect key id";
 674                return -EINVAL;
 675        }
 676        return 0;
 677}
 678
 679static int check_expect_stats_neigh(const struct objagg_stats *stats,
 680                                    const struct expect_stats *expect_stats,
 681                                    int pos)
 682{
 683        int i;
 684        int err;
 685
 686        for (i = pos - 1; i >= 0; i--) {
 687                err = check_expect_stats_nums(&stats->stats_info[i],
 688                                              &expect_stats->info[pos], NULL);
 689                if (err)
 690                        break;
 691                err = check_expect_stats_key_id(&stats->stats_info[i],
 692                                                &expect_stats->info[pos], NULL);
 693                if (!err)
 694                        return 0;
 695        }
 696        for (i = pos + 1; i < stats->stats_info_count; i++) {
 697                err = check_expect_stats_nums(&stats->stats_info[i],
 698                                              &expect_stats->info[pos], NULL);
 699                if (err)
 700                        break;
 701                err = check_expect_stats_key_id(&stats->stats_info[i],
 702                                                &expect_stats->info[pos], NULL);
 703                if (!err)
 704                        return 0;
 705        }
 706        return -EINVAL;
 707}
 708
 709static int __check_expect_stats(const struct objagg_stats *stats,
 710                                const struct expect_stats *expect_stats,
 711                                const char **errmsg)
 712{
 713        int i;
 714        int err;
 715
 716        if (stats->stats_info_count != expect_stats->info_count) {
 717                *errmsg = "Unexpected object count";
 718                return -EINVAL;
 719        }
 720
 721        for (i = 0; i < stats->stats_info_count; i++) {
 722                err = check_expect_stats_nums(&stats->stats_info[i],
 723                                              &expect_stats->info[i], errmsg);
 724                if (err)
 725                        return err;
 726                err = check_expect_stats_key_id(&stats->stats_info[i],
 727                                                &expect_stats->info[i], errmsg);
 728                if (err) {
 729                        /* It is possible that one of the neighbor stats with
 730                         * same numbers have the correct key id, so check it
 731                         */
 732                        err = check_expect_stats_neigh(stats, expect_stats, i);
 733                        if (err)
 734                                return err;
 735                }
 736        }
 737        return 0;
 738}
 739
 740static int check_expect_stats(struct objagg *objagg,
 741                              const struct expect_stats *expect_stats,
 742                              const char **errmsg)
 743{
 744        const struct objagg_stats *stats;
 745        int err;
 746
 747        stats = objagg_stats_get(objagg);
 748        if (IS_ERR(stats)) {
 749                *errmsg = "objagg_stats_get() failed.";
 750                return PTR_ERR(stats);
 751        }
 752        err = __check_expect_stats(stats, expect_stats, errmsg);
 753        objagg_stats_put(stats);
 754        return err;
 755}
 756
 757static int test_delta_action_item(struct world *world,
 758                                  struct objagg *objagg,
 759                                  const struct action_item *action_item,
 760                                  bool inverse)
 761{
 762        unsigned int orig_delta_count = world->delta_count;
 763        unsigned int orig_root_count = world->root_count;
 764        unsigned int key_id = action_item->key_id;
 765        enum action action = action_item->action;
 766        struct objagg_obj *objagg_obj;
 767        const char *errmsg;
 768        int err;
 769
 770        if (inverse)
 771                action = action == ACTION_GET ? ACTION_PUT : ACTION_GET;
 772
 773        switch (action) {
 774        case ACTION_GET:
 775                objagg_obj = world_obj_get(world, objagg, key_id);
 776                if (IS_ERR(objagg_obj))
 777                        return PTR_ERR(objagg_obj);
 778                break;
 779        case ACTION_PUT:
 780                world_obj_put(world, objagg, key_id);
 781                break;
 782        }
 783
 784        if (inverse)
 785                return 0;
 786        err = check_expect(world, action_item,
 787                           orig_delta_count, orig_root_count);
 788        if (err)
 789                goto errout;
 790
 791        err = check_expect_stats(objagg, &action_item->expect_stats, &errmsg);
 792        if (err) {
 793                pr_err("Key %u: Stats: %s\n", action_item->key_id, errmsg);
 794                goto errout;
 795        }
 796
 797        return 0;
 798
 799errout:
 800        /* This can only happen when action is not inversed.
 801         * So in case of an error, cleanup by doing inverse action.
 802         */
 803        test_delta_action_item(world, objagg, action_item, true);
 804        return err;
 805}
 806
 807static int test_delta(void)
 808{
 809        struct world world = {};
 810        struct objagg *objagg;
 811        int i;
 812        int err;
 813
 814        objagg = objagg_create(&delta_ops, NULL, &world);
 815        if (IS_ERR(objagg))
 816                return PTR_ERR(objagg);
 817
 818        for (i = 0; i < ARRAY_SIZE(action_items); i++) {
 819                err = test_delta_action_item(&world, objagg,
 820                                             &action_items[i], false);
 821                if (err)
 822                        goto err_do_action_item;
 823        }
 824
 825        objagg_destroy(objagg);
 826        return 0;
 827
 828err_do_action_item:
 829        for (i--; i >= 0; i--)
 830                test_delta_action_item(&world, objagg, &action_items[i], true);
 831
 832        objagg_destroy(objagg);
 833        return err;
 834}
 835
/*
 * One hints test scenario, consumed by test_hints_case().
 */
struct hints_case {
	/* Keys requested, in order, from each objagg under test. */
	const unsigned int *key_ids;
	/* Number of entries in @key_ids. */
	size_t key_ids_count;
	/* Stats expected from the objagg built without hints. */
	struct expect_stats expect_stats;
	/*
	 * Stats expected from the computed hints, and also from the second
	 * objagg built with those hints.
	 */
	struct expect_stats expect_stats_hints;
};
 842
/*
 * Key sequence for the hints test. Repeated ids are requested again, so
 * some keys end up with more than one user.
 */
static const unsigned int hints_case_key_ids[] = {
	1, 7, 3, 5, 3, 1, 30, 8, 8, 5, 6, 8,
};
 846
/*
 * The hints test scenario.
 *
 * NOTE(review): EXPECT_STATS/ROOT/DELTA are defined earlier in this file.
 * The numbers are consistent with the following readings, to be confirmed
 * against the macro definitions:
 *   EXPECT_STATS(object count, ...)
 *   ROOT(key, user count, user count including deltas)
 *   DELTA(key, user count)
 */
static const struct hints_case hints_case = {
	.key_ids = hints_case_key_ids,
	.key_ids_count = ARRAY_SIZE(hints_case_key_ids),
	/* Layout produced without hints, in key request order. */
	.expect_stats =
		EXPECT_STATS(7, ROOT(1, 2, 7), ROOT(7, 1, 4), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(3, 2),
				DELTA(5, 2), DELTA(6, 1)),
	/* Layout expected from the SIMPLE_GREEDY hints and the hinted objagg. */
	.expect_stats_hints =
		EXPECT_STATS(7, ROOT(3, 2, 9), ROOT(1, 2, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(5, 2),
				DELTA(6, 1), DELTA(7, 1)),
};
 859
 860static void __pr_debug_stats(const struct objagg_stats *stats)
 861{
 862        int i;
 863
 864        for (i = 0; i < stats->stats_info_count; i++)
 865                pr_debug("Stat index %d key %u: u %d, d %d, %s\n", i,
 866                         obj_to_key_id(stats->stats_info[i].objagg_obj),
 867                         stats->stats_info[i].stats.user_count,
 868                         stats->stats_info[i].stats.delta_user_count,
 869                         stats->stats_info[i].is_root ? "root" : "noroot");
 870}
 871
 872static void pr_debug_stats(struct objagg *objagg)
 873{
 874        const struct objagg_stats *stats;
 875
 876        stats = objagg_stats_get(objagg);
 877        if (IS_ERR(stats))
 878                return;
 879        __pr_debug_stats(stats);
 880        objagg_stats_put(stats);
 881}
 882
 883static void pr_debug_hints_stats(struct objagg_hints *objagg_hints)
 884{
 885        const struct objagg_stats *stats;
 886
 887        stats = objagg_hints_stats_get(objagg_hints);
 888        if (IS_ERR(stats))
 889                return;
 890        __pr_debug_stats(stats);
 891        objagg_stats_put(stats);
 892}
 893
 894static int check_expect_hints_stats(struct objagg_hints *objagg_hints,
 895                                    const struct expect_stats *expect_stats,
 896                                    const char **errmsg)
 897{
 898        const struct objagg_stats *stats;
 899        int err;
 900
 901        stats = objagg_hints_stats_get(objagg_hints);
 902        if (IS_ERR(stats))
 903                return PTR_ERR(stats);
 904        err = __check_expect_stats(stats, expect_stats, errmsg);
 905        objagg_stats_put(stats);
 906        return err;
 907}
 908
 909static int test_hints_case(const struct hints_case *hints_case)
 910{
 911        struct objagg_obj *objagg_obj;
 912        struct objagg_hints *hints;
 913        struct world world2 = {};
 914        struct world world = {};
 915        struct objagg *objagg2;
 916        struct objagg *objagg;
 917        const char *errmsg;
 918        int i;
 919        int err;
 920
 921        objagg = objagg_create(&delta_ops, NULL, &world);
 922        if (IS_ERR(objagg))
 923                return PTR_ERR(objagg);
 924
 925        for (i = 0; i < hints_case->key_ids_count; i++) {
 926                objagg_obj = world_obj_get(&world, objagg,
 927                                           hints_case->key_ids[i]);
 928                if (IS_ERR(objagg_obj)) {
 929                        err = PTR_ERR(objagg_obj);
 930                        goto err_world_obj_get;
 931                }
 932        }
 933
 934        pr_debug_stats(objagg);
 935        err = check_expect_stats(objagg, &hints_case->expect_stats, &errmsg);
 936        if (err) {
 937                pr_err("Stats: %s\n", errmsg);
 938                goto err_check_expect_stats;
 939        }
 940
 941        hints = objagg_hints_get(objagg, OBJAGG_OPT_ALGO_SIMPLE_GREEDY);
 942        if (IS_ERR(hints)) {
 943                err = PTR_ERR(hints);
 944                goto err_hints_get;
 945        }
 946
 947        pr_debug_hints_stats(hints);
 948        err = check_expect_hints_stats(hints, &hints_case->expect_stats_hints,
 949                                       &errmsg);
 950        if (err) {
 951                pr_err("Hints stats: %s\n", errmsg);
 952                goto err_check_expect_hints_stats;
 953        }
 954
 955        objagg2 = objagg_create(&delta_ops, hints, &world2);
 956        if (IS_ERR(objagg2))
 957                return PTR_ERR(objagg2);
 958
 959        for (i = 0; i < hints_case->key_ids_count; i++) {
 960                objagg_obj = world_obj_get(&world2, objagg2,
 961                                           hints_case->key_ids[i]);
 962                if (IS_ERR(objagg_obj)) {
 963                        err = PTR_ERR(objagg_obj);
 964                        goto err_world2_obj_get;
 965                }
 966        }
 967
 968        pr_debug_stats(objagg2);
 969        err = check_expect_stats(objagg2, &hints_case->expect_stats_hints,
 970                                 &errmsg);
 971        if (err) {
 972                pr_err("Stats2: %s\n", errmsg);
 973                goto err_check_expect_stats2;
 974        }
 975
 976        err = 0;
 977
 978err_check_expect_stats2:
 979err_world2_obj_get:
 980        for (i--; i >= 0; i--)
 981                world_obj_put(&world2, objagg, hints_case->key_ids[i]);
 982        objagg_hints_put(hints);
 983        objagg_destroy(objagg2);
 984        i = hints_case->key_ids_count;
 985err_check_expect_hints_stats:
 986err_hints_get:
 987err_check_expect_stats:
 988err_world_obj_get:
 989        for (i--; i >= 0; i--)
 990                world_obj_put(&world, objagg, hints_case->key_ids[i]);
 991
 992        objagg_destroy(objagg);
 993        return err;
 994}
 995static int test_hints(void)
 996{
 997        return test_hints_case(&hints_case);
 998}
 999
1000static int __init test_objagg_init(void)
1001{
1002        int err;
1003
1004        err = test_nodelta();
1005        if (err)
1006                return err;
1007        err = test_delta();
1008        if (err)
1009                return err;
1010        return test_hints();
1011}
1012
/*
 * Intentionally empty: the tests run from test_objagg_init() (e.g.
 * test_delta(), test_hints_case()) destroy their objaggs before returning.
 */
static void __exit test_objagg_exit(void)
{
}
1016
/* Module registration: the tests run once, at module load time. */
module_init(test_objagg_init);
module_exit(test_objagg_exit);
MODULE_LICENSE("Dual BSD/GPL");
MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>");
MODULE_DESCRIPTION("Test module for objagg");
1022