uboot/lib/rsa/rsa-verify.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0+
   2/*
   3 * Copyright (c) 2013, Google Inc.
   4 */
   5
   6#ifndef USE_HOSTCC
   7#include <common.h>
   8#include <fdtdec.h>
   9#include <log.h>
  10#include <malloc.h>
  11#include <asm/types.h>
  12#include <asm/byteorder.h>
  13#include <linux/errno.h>
  14#include <asm/types.h>
  15#include <asm/unaligned.h>
  16#include <dm.h>
  17#else
  18#include "fdt_host.h"
  19#include "mkimage.h"
  20#include <fdt_support.h>
  21#endif
  22#include <linux/kconfig.h>
  23#include <u-boot/rsa-mod-exp.h>
  24#include <u-boot/rsa.h>
  25
  26#ifndef __UBOOT__
  27/*
  28 * NOTE:
  29 * Since host tools, like mkimage, make use of openssl library for
  30 * RSA encryption, rsa_verify_with_pkey()/rsa_gen_key_prop() are
  31 * of no use and should not be compiled in.
  32 * So just turn off CONFIG_RSA_VERIFY_WITH_PKEY.
  33 */
  34
  35#undef CONFIG_RSA_VERIFY_WITH_PKEY
  36#endif
  37
  38/* Default public exponent for backward compatibility */
  39#define RSA_DEFAULT_PUBEXP      65537
  40
  41/**
  42 * rsa_verify_padding() - Verify RSA message padding is valid
  43 *
  44 * Verify a RSA message's padding is consistent with PKCS1.5
  45 * padding as described in the RSA PKCS#1 v2.1 standard.
  46 *
  47 * @msg:        Padded message
  48 * @pad_len:    Number of expected padding bytes
  49 * @algo:       Checksum algo structure having information on DER encoding etc.
  50 * @return 0 on success, != 0 on failure
  51 */
  52static int rsa_verify_padding(const uint8_t *msg, const int pad_len,
  53                              struct checksum_algo *algo)
  54{
  55        int ff_len;
  56        int ret;
  57
  58        /* first byte must be 0x00 */
  59        ret = *msg++;
  60        /* second byte must be 0x01 */
  61        ret |= *msg++ ^ 0x01;
  62        /* next ff_len bytes must be 0xff */
  63        ff_len = pad_len - algo->der_len - 3;
  64        ret |= *msg ^ 0xff;
  65        ret |= memcmp(msg, msg+1, ff_len-1);
  66        msg += ff_len;
  67        /* next byte must be 0x00 */
  68        ret |= *msg++;
  69        /* next der_len bytes must match der_prefix */
  70        ret |= memcmp(msg, algo->der_prefix, algo->der_len);
  71
  72        return ret;
  73}
  74
  75int padding_pkcs_15_verify(struct image_sign_info *info,
  76                           uint8_t *msg, int msg_len,
  77                           const uint8_t *hash, int hash_len)
  78{
  79        struct checksum_algo *checksum = info->checksum;
  80        int ret, pad_len = msg_len - checksum->checksum_len;
  81
  82        /* Check pkcs1.5 padding bytes. */
  83        ret = rsa_verify_padding(msg, pad_len, checksum);
  84        if (ret) {
  85                debug("In RSAVerify(): Padding check failed!\n");
  86                return -EINVAL;
  87        }
  88
  89        /* Check hash. */
  90        if (memcmp((uint8_t *)msg + pad_len, hash, msg_len - pad_len)) {
  91                debug("In RSAVerify(): Hash check failed!\n");
  92                return -EACCES;
  93        }
  94
  95        return 0;
  96}
  97
  98#ifdef CONFIG_FIT_ENABLE_RSASSA_PSS_SUPPORT
  99static void u32_i2osp(uint32_t val, uint8_t *buf)
 100{
 101        buf[0] = (uint8_t)((val >> 24) & 0xff);
 102        buf[1] = (uint8_t)((val >> 16) & 0xff);
 103        buf[2] = (uint8_t)((val >>  8) & 0xff);
 104        buf[3] = (uint8_t)((val >>  0) & 0xff);
 105}
 106
 107/**
 108 * mask_generation_function1() - generate an octet string
 109 *
 110 * Generate an octet string used to check rsa signature.
 111 * It use an input octet string and a hash function.
 112 *
 113 * @checksum:   A Hash function
 114 * @seed:       Specifies an input variable octet string
 115 * @seed_len:   Size of the input octet string
 116 * @output:     Specifies the output octet string
 117 * @output_len: Size of the output octet string
 118 * @return 0 if the octet string was correctly generated, others on error
 119 */
 120static int mask_generation_function1(struct checksum_algo *checksum,
 121                                     uint8_t *seed, int seed_len,
 122                                     uint8_t *output, int output_len)
 123{
 124        struct image_region region[2];
 125        int ret = 0, i, i_output = 0, region_count = 2;
 126        uint32_t counter = 0;
 127        uint8_t buf_counter[4], *tmp;
 128        int hash_len = checksum->checksum_len;
 129
 130        memset(output, 0, output_len);
 131
 132        region[0].data = seed;
 133        region[0].size = seed_len;
 134        region[1].data = &buf_counter[0];
 135        region[1].size = 4;
 136
 137        tmp = malloc(hash_len);
 138        if (!tmp) {
 139                debug("%s: can't allocate array tmp\n", __func__);
 140                ret = -ENOMEM;
 141                goto out;
 142        }
 143
 144        while (i_output < output_len) {
 145                u32_i2osp(counter, &buf_counter[0]);
 146
 147                ret = checksum->calculate(checksum->name,
 148                                          region, region_count,
 149                                          tmp);
 150                if (ret < 0) {
 151                        debug("%s: Error in checksum calculation\n", __func__);
 152                        goto out;
 153                }
 154
 155                i = 0;
 156                while ((i_output < output_len) && (i < hash_len)) {
 157                        output[i_output] = tmp[i];
 158                        i_output++;
 159                        i++;
 160                }
 161
 162                counter++;
 163        }
 164
 165out:
 166        free(tmp);
 167
 168        return ret;
 169}
 170
 171static int compute_hash_prime(struct checksum_algo *checksum,
 172                              uint8_t *pad, int pad_len,
 173                              uint8_t *hash, int hash_len,
 174                              uint8_t *salt, int salt_len,
 175                              uint8_t *hprime)
 176{
 177        struct image_region region[3];
 178        int ret, region_count = 3;
 179
 180        region[0].data = pad;
 181        region[0].size = pad_len;
 182        region[1].data = hash;
 183        region[1].size = hash_len;
 184        region[2].data = salt;
 185        region[2].size = salt_len;
 186
 187        ret = checksum->calculate(checksum->name, region, region_count, hprime);
 188        if (ret < 0) {
 189                debug("%s: Error in checksum calculation\n", __func__);
 190                goto out;
 191        }
 192
 193out:
 194        return ret;
 195}
 196
 197/*
 198 * padding_pss_verify() - verify the pss padding of a signature
 199 *
 200 * Only works with a rsa_pss_saltlen:-2 (default value) right now
 201 * saltlen:-1 "set the salt length to the digest length" is currently
 202 * not supported.
 203 *
 204 * @info:       Specifies key and FIT information
 205 * @msg:        byte array of message, len equal to msg_len
 206 * @msg_len:    Message length
 207 * @hash:       Pointer to the expected hash
 208 * @hash_len:   Length of the hash
 209 */
 210int padding_pss_verify(struct image_sign_info *info,
 211                       uint8_t *msg, int msg_len,
 212                       const uint8_t *hash, int hash_len)
 213{
 214        uint8_t *masked_db = NULL;
 215        int masked_db_len = msg_len - hash_len - 1;
 216        uint8_t *h = NULL, *hprime = NULL;
 217        int h_len = hash_len;
 218        uint8_t *db_mask = NULL;
 219        int db_mask_len = masked_db_len;
 220        uint8_t *db = NULL, *salt = NULL;
 221        int db_len = masked_db_len, salt_len = msg_len - hash_len - 2;
 222        uint8_t pad_zero[8] = { 0 };
 223        int ret, i, leftmost_bits = 1;
 224        uint8_t leftmost_mask;
 225        struct checksum_algo *checksum = info->checksum;
 226
 227        /* first, allocate everything */
 228        masked_db = malloc(masked_db_len);
 229        h = malloc(h_len);
 230        db_mask = malloc(db_mask_len);
 231        db = malloc(db_len);
 232        salt = malloc(salt_len);
 233        hprime = malloc(hash_len);
 234        if (!masked_db || !h || !db_mask || !db || !salt || !hprime) {
 235                printf("%s: can't allocate some buffer\n", __func__);
 236                ret = -ENOMEM;
 237                goto out;
 238        }
 239
 240        /* step 4: check if the last byte is 0xbc */
 241        if (msg[msg_len - 1] != 0xbc) {
 242                printf("%s: invalid pss padding (0xbc is missing)\n", __func__);
 243                ret = -EINVAL;
 244                goto out;
 245        }
 246
 247        /* step 5 */
 248        memcpy(masked_db, msg, masked_db_len);
 249        memcpy(h, msg + masked_db_len, h_len);
 250
 251        /* step 6 */
 252        leftmost_mask = (0xff >> (8 - leftmost_bits)) << (8 - leftmost_bits);
 253        if (masked_db[0] & leftmost_mask) {
 254                printf("%s: invalid pss padding ", __func__);
 255                printf("(leftmost bit of maskedDB not zero)\n");
 256                ret = -EINVAL;
 257                goto out;
 258        }
 259
 260        /* step 7 */
 261        mask_generation_function1(checksum, h, h_len, db_mask, db_mask_len);
 262
 263        /* step 8 */
 264        for (i = 0; i < db_len; i++)
 265                db[i] = masked_db[i] ^ db_mask[i];
 266
 267        /* step 9 */
 268        db[0] &= 0xff >> leftmost_bits;
 269
 270        /* step 10 */
 271        if (db[0] != 0x01) {
 272                printf("%s: invalid pss padding ", __func__);
 273                printf("(leftmost byte of db isn't 0x01)\n");
 274                ret = EINVAL;
 275                goto out;
 276        }
 277
 278        /* step 11 */
 279        memcpy(salt, &db[1], salt_len);
 280
 281        /* step 12 & 13 */
 282        compute_hash_prime(checksum, pad_zero, 8,
 283                           (uint8_t *)hash, hash_len,
 284                           salt, salt_len, hprime);
 285
 286        /* step 14 */
 287        ret = memcmp(h, hprime, hash_len);
 288
 289out:
 290        free(hprime);
 291        free(salt);
 292        free(db);
 293        free(db_mask);
 294        free(h);
 295        free(masked_db);
 296
 297        return ret;
 298}
 299#endif
 300
 301#if CONFIG_IS_ENABLED(FIT_SIGNATURE) || CONFIG_IS_ENABLED(RSA_VERIFY_WITH_PKEY)
 302/**
 303 * rsa_verify_key() - Verify a signature against some data using RSA Key
 304 *
 305 * Verify a RSA PKCS1.5 signature against an expected hash using
 306 * the RSA Key properties in prop structure.
 307 *
 308 * @info:       Specifies key and FIT information
 309 * @prop:       Specifies key
 310 * @sig:        Signature
 311 * @sig_len:    Number of bytes in signature
 312 * @hash:       Pointer to the expected hash
 313 * @key_len:    Number of bytes in rsa key
 314 * @return 0 if verified, -ve on error
 315 */
 316static int rsa_verify_key(struct image_sign_info *info,
 317                          struct key_prop *prop, const uint8_t *sig,
 318                          const uint32_t sig_len, const uint8_t *hash,
 319                          const uint32_t key_len)
 320{
 321        int ret;
 322#if !defined(USE_HOSTCC)
 323        struct udevice *mod_exp_dev;
 324#endif
 325        struct checksum_algo *checksum = info->checksum;
 326        struct padding_algo *padding = info->padding;
 327        int hash_len;
 328
 329        if (!prop || !sig || !hash || !checksum)
 330                return -EIO;
 331
 332        if (sig_len != (prop->num_bits / 8)) {
 333                debug("Signature is of incorrect length %d\n", sig_len);
 334                return -EINVAL;
 335        }
 336
 337        debug("Checksum algorithm: %s", checksum->name);
 338
 339        /* Sanity check for stack size */
 340        if (sig_len > RSA_MAX_SIG_BITS / 8) {
 341                debug("Signature length %u exceeds maximum %d\n", sig_len,
 342                      RSA_MAX_SIG_BITS / 8);
 343                return -EINVAL;
 344        }
 345
 346        uint8_t buf[sig_len];
 347        hash_len = checksum->checksum_len;
 348
 349#if !defined(USE_HOSTCC)
 350        ret = uclass_get_device(UCLASS_MOD_EXP, 0, &mod_exp_dev);
 351        if (ret) {
 352                printf("RSA: Can't find Modular Exp implementation\n");
 353                return -EINVAL;
 354        }
 355
 356        ret = rsa_mod_exp(mod_exp_dev, sig, sig_len, prop, buf);
 357#else
 358        ret = rsa_mod_exp_sw(sig, sig_len, prop, buf);
 359#endif
 360        if (ret) {
 361                debug("Error in Modular exponentation\n");
 362                return ret;
 363        }
 364
 365        ret = padding->verify(info, buf, key_len, hash, hash_len);
 366        if (ret) {
 367                debug("In RSAVerify(): padding check failed!\n");
 368                return ret;
 369        }
 370
 371        return 0;
 372}
 373#endif
 374
 375#if CONFIG_IS_ENABLED(RSA_VERIFY_WITH_PKEY)
 376/**
 377 * rsa_verify_with_pkey() - Verify a signature against some data using
 378 * only modulus and exponent as RSA key properties.
 379 * @info:       Specifies key information
 380 * @hash:       Pointer to the expected hash
 381 * @sig:        Signature
 382 * @sig_len:    Number of bytes in signature
 383 *
 384 * Parse a RSA public key blob in DER format pointed to in @info and fill
 385 * a key_prop structure with properties of the key. Then verify a RSA PKCS1.5
 386 * signature against an expected hash using the calculated properties.
 387 *
 388 * Return       0 if verified, -ve on error
 389 */
 390int rsa_verify_with_pkey(struct image_sign_info *info,
 391                         const void *hash, uint8_t *sig, uint sig_len)
 392{
 393        struct key_prop *prop;
 394        int ret;
 395
 396        /* Public key is self-described to fill key_prop */
 397        ret = rsa_gen_key_prop(info->key, info->keylen, &prop);
 398        if (ret) {
 399                debug("Generating necessary parameter for decoding failed\n");
 400                return ret;
 401        }
 402
 403        ret = rsa_verify_key(info, prop, sig, sig_len, hash,
 404                             info->crypto->key_len);
 405
 406        rsa_free_key_prop(prop);
 407
 408        return ret;
 409}
 410#else
 411int rsa_verify_with_pkey(struct image_sign_info *info,
 412                         const void *hash, uint8_t *sig, uint sig_len)
 413{
 414        return -EACCES;
 415}
 416#endif
 417
 418#if CONFIG_IS_ENABLED(FIT_SIGNATURE)
 419/**
 420 * rsa_verify_with_keynode() - Verify a signature against some data using
 421 * information in node with prperties of RSA Key like modulus, exponent etc.
 422 *
 423 * Parse sign-node and fill a key_prop structure with properties of the
 424 * key.  Verify a RSA PKCS1.5 signature against an expected hash using
 425 * the properties parsed
 426 *
 427 * @info:       Specifies key and FIT information
 428 * @hash:       Pointer to the expected hash
 429 * @sig:        Signature
 430 * @sig_len:    Number of bytes in signature
 431 * @node:       Node having the RSA Key properties
 432 * @return 0 if verified, -ve on error
 433 */
 434static int rsa_verify_with_keynode(struct image_sign_info *info,
 435                                   const void *hash, uint8_t *sig,
 436                                   uint sig_len, int node)
 437{
 438        const void *blob = info->fdt_blob;
 439        struct key_prop prop;
 440        int length;
 441        int ret = 0;
 442        const char *algo;
 443
 444        if (node < 0) {
 445                debug("%s: Skipping invalid node", __func__);
 446                return -EBADF;
 447        }
 448
 449        algo = fdt_getprop(blob, node, "algo", NULL);
 450        if (strcmp(info->name, algo)) {
 451                debug("%s: Wrong algo: have %s, expected %s", __func__,
 452                      info->name, algo);
 453                return -EFAULT;
 454        }
 455
 456        prop.num_bits = fdtdec_get_int(blob, node, "rsa,num-bits", 0);
 457
 458        prop.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0);
 459
 460        prop.public_exponent = fdt_getprop(blob, node, "rsa,exponent", &length);
 461        if (!prop.public_exponent || length < sizeof(uint64_t))
 462                prop.public_exponent = NULL;
 463
 464        prop.exp_len = sizeof(uint64_t);
 465
 466        prop.modulus = fdt_getprop(blob, node, "rsa,modulus", NULL);
 467
 468        prop.rr = fdt_getprop(blob, node, "rsa,r-squared", NULL);
 469
 470        if (!prop.num_bits || !prop.modulus || !prop.rr) {
 471                debug("%s: Missing RSA key info", __func__);
 472                return -EFAULT;
 473        }
 474
 475        ret = rsa_verify_key(info, &prop, sig, sig_len, hash,
 476                             info->crypto->key_len);
 477
 478        return ret;
 479}
 480#else
 481static int rsa_verify_with_keynode(struct image_sign_info *info,
 482                                   const void *hash, uint8_t *sig,
 483                                   uint sig_len, int node)
 484{
 485        return -EACCES;
 486}
 487#endif
 488
 489int rsa_verify_hash(struct image_sign_info *info,
 490                    const uint8_t *hash, uint8_t *sig, uint sig_len)
 491{
 492        int ret = -EACCES;
 493
 494        if (CONFIG_IS_ENABLED(RSA_VERIFY_WITH_PKEY) && !info->fdt_blob) {
 495                /* don't rely on fdt properties */
 496                ret = rsa_verify_with_pkey(info, hash, sig, sig_len);
 497
 498                return ret;
 499        }
 500
 501        if (CONFIG_IS_ENABLED(FIT_SIGNATURE)) {
 502                const void *blob = info->fdt_blob;
 503                int ndepth, noffset;
 504                int sig_node, node;
 505                char name[100];
 506
 507                sig_node = fdt_subnode_offset(blob, 0, FIT_SIG_NODENAME);
 508                if (sig_node < 0) {
 509                        debug("%s: No signature node found\n", __func__);
 510                        return -ENOENT;
 511                }
 512
 513                /* See if we must use a particular key */
 514                if (info->required_keynode != -1) {
 515                        ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
 516                                                      info->required_keynode);
 517                        return ret;
 518                }
 519
 520                /* Look for a key that matches our hint */
 521                snprintf(name, sizeof(name), "key-%s", info->keyname);
 522                node = fdt_subnode_offset(blob, sig_node, name);
 523                ret = rsa_verify_with_keynode(info, hash, sig, sig_len, node);
 524                if (!ret)
 525                        return ret;
 526
 527                /* No luck, so try each of the keys in turn */
 528                for (ndepth = 0, noffset = fdt_next_node(blob, sig_node,
 529                                                         &ndepth);
 530                     (noffset >= 0) && (ndepth > 0);
 531                     noffset = fdt_next_node(blob, noffset, &ndepth)) {
 532                        if (ndepth == 1 && noffset != node) {
 533                                ret = rsa_verify_with_keynode(info, hash,
 534                                                              sig, sig_len,
 535                                                              noffset);
 536                                if (!ret)
 537                                        break;
 538                        }
 539                }
 540        }
 541
 542        return ret;
 543}
 544
 545int rsa_verify(struct image_sign_info *info,
 546               const struct image_region region[], int region_count,
 547               uint8_t *sig, uint sig_len)
 548{
 549        /* Reserve memory for maximum checksum-length */
 550        uint8_t hash[info->crypto->key_len];
 551        int ret;
 552
 553        /*
 554         * Verify that the checksum-length does not exceed the
 555         * rsa-signature-length
 556         */
 557        if (info->checksum->checksum_len >
 558            info->crypto->key_len) {
 559                debug("%s: invlaid checksum-algorithm %s for %s\n",
 560                      __func__, info->checksum->name, info->crypto->name);
 561                return -EINVAL;
 562        }
 563
 564        /* Calculate checksum with checksum-algorithm */
 565        ret = info->checksum->calculate(info->checksum->name,
 566                                        region, region_count, hash);
 567        if (ret < 0) {
 568                debug("%s: Error in checksum calculation\n", __func__);
 569                return -EINVAL;
 570        }
 571
 572        return rsa_verify_hash(info, hash, sig, sig_len);
 573}
 574