uboot/lib/rsa/rsa-verify.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0+
   2/*
   3 * Copyright (c) 2013, Google Inc.
   4 */
   5
   6#ifndef USE_HOSTCC
   7#include <common.h>
   8#include <fdtdec.h>
   9#include <asm/types.h>
  10#include <asm/byteorder.h>
  11#include <linux/errno.h>
  12#include <asm/types.h>
  13#include <asm/unaligned.h>
  14#include <dm.h>
  15#else
  16#include "fdt_host.h"
  17#include "mkimage.h"
  18#include <fdt_support.h>
  19#endif
  20#include <u-boot/rsa-mod-exp.h>
  21#include <u-boot/rsa.h>
  22
  23/* Default public exponent for backward compatibility */
  24#define RSA_DEFAULT_PUBEXP      65537
  25
  26/**
  27 * rsa_verify_padding() - Verify RSA message padding is valid
  28 *
  29 * Verify a RSA message's padding is consistent with PKCS1.5
  30 * padding as described in the RSA PKCS#1 v2.1 standard.
  31 *
  32 * @msg:        Padded message
  33 * @pad_len:    Number of expected padding bytes
  34 * @algo:       Checksum algo structure having information on DER encoding etc.
  35 * @return 0 on success, != 0 on failure
  36 */
  37static int rsa_verify_padding(const uint8_t *msg, const int pad_len,
  38                              struct checksum_algo *algo)
  39{
  40        int ff_len;
  41        int ret;
  42
  43        /* first byte must be 0x00 */
  44        ret = *msg++;
  45        /* second byte must be 0x01 */
  46        ret |= *msg++ ^ 0x01;
  47        /* next ff_len bytes must be 0xff */
  48        ff_len = pad_len - algo->der_len - 3;
  49        ret |= *msg ^ 0xff;
  50        ret |= memcmp(msg, msg+1, ff_len-1);
  51        msg += ff_len;
  52        /* next byte must be 0x00 */
  53        ret |= *msg++;
  54        /* next der_len bytes must match der_prefix */
  55        ret |= memcmp(msg, algo->der_prefix, algo->der_len);
  56
  57        return ret;
  58}
  59
  60int padding_pkcs_15_verify(struct image_sign_info *info,
  61                           uint8_t *msg, int msg_len,
  62                           const uint8_t *hash, int hash_len)
  63{
  64        struct checksum_algo *checksum = info->checksum;
  65        int ret, pad_len = msg_len - checksum->checksum_len;
  66
  67        /* Check pkcs1.5 padding bytes. */
  68        ret = rsa_verify_padding(msg, pad_len, checksum);
  69        if (ret) {
  70                debug("In RSAVerify(): Padding check failed!\n");
  71                return -EINVAL;
  72        }
  73
  74        /* Check hash. */
  75        if (memcmp((uint8_t *)msg + pad_len, hash, msg_len - pad_len)) {
  76                debug("In RSAVerify(): Hash check failed!\n");
  77                return -EACCES;
  78        }
  79
  80        return 0;
  81}
  82
  83#ifdef CONFIG_FIT_ENABLE_RSASSA_PSS_SUPPORT
  84static void u32_i2osp(uint32_t val, uint8_t *buf)
  85{
  86        buf[0] = (uint8_t)((val >> 24) & 0xff);
  87        buf[1] = (uint8_t)((val >> 16) & 0xff);
  88        buf[2] = (uint8_t)((val >>  8) & 0xff);
  89        buf[3] = (uint8_t)((val >>  0) & 0xff);
  90}
  91
  92/**
  93 * mask_generation_function1() - generate an octet string
  94 *
  95 * Generate an octet string used to check rsa signature.
  96 * It use an input octet string and a hash function.
  97 *
  98 * @checksum:   A Hash function
  99 * @seed:       Specifies an input variable octet string
 100 * @seed_len:   Size of the input octet string
 101 * @output:     Specifies the output octet string
 102 * @output_len: Size of the output octet string
 103 * @return 0 if the octet string was correctly generated, others on error
 104 */
 105static int mask_generation_function1(struct checksum_algo *checksum,
 106                                     uint8_t *seed, int seed_len,
 107                                     uint8_t *output, int output_len)
 108{
 109        struct image_region region[2];
 110        int ret = 0, i, i_output = 0, region_count = 2;
 111        uint32_t counter = 0;
 112        uint8_t buf_counter[4], *tmp;
 113        int hash_len = checksum->checksum_len;
 114
 115        memset(output, 0, output_len);
 116
 117        region[0].data = seed;
 118        region[0].size = seed_len;
 119        region[1].data = &buf_counter[0];
 120        region[1].size = 4;
 121
 122        tmp = malloc(hash_len);
 123        if (!tmp) {
 124                debug("%s: can't allocate array tmp\n", __func__);
 125                ret = -ENOMEM;
 126                goto out;
 127        }
 128
 129        while (i_output < output_len) {
 130                u32_i2osp(counter, &buf_counter[0]);
 131
 132                ret = checksum->calculate(checksum->name,
 133                                          region, region_count,
 134                                          tmp);
 135                if (ret < 0) {
 136                        debug("%s: Error in checksum calculation\n", __func__);
 137                        goto out;
 138                }
 139
 140                i = 0;
 141                while ((i_output < output_len) && (i < hash_len)) {
 142                        output[i_output] = tmp[i];
 143                        i_output++;
 144                        i++;
 145                }
 146
 147                counter++;
 148        }
 149
 150out:
 151        free(tmp);
 152
 153        return ret;
 154}
 155
 156static int compute_hash_prime(struct checksum_algo *checksum,
 157                              uint8_t *pad, int pad_len,
 158                              uint8_t *hash, int hash_len,
 159                              uint8_t *salt, int salt_len,
 160                              uint8_t *hprime)
 161{
 162        struct image_region region[3];
 163        int ret, region_count = 3;
 164
 165        region[0].data = pad;
 166        region[0].size = pad_len;
 167        region[1].data = hash;
 168        region[1].size = hash_len;
 169        region[2].data = salt;
 170        region[2].size = salt_len;
 171
 172        ret = checksum->calculate(checksum->name, region, region_count, hprime);
 173        if (ret < 0) {
 174                debug("%s: Error in checksum calculation\n", __func__);
 175                goto out;
 176        }
 177
 178out:
 179        return ret;
 180}
 181
 182int padding_pss_verify(struct image_sign_info *info,
 183                       uint8_t *msg, int msg_len,
 184                       const uint8_t *hash, int hash_len)
 185{
 186        uint8_t *masked_db = NULL;
 187        int masked_db_len = msg_len - hash_len - 1;
 188        uint8_t *h = NULL, *hprime = NULL;
 189        int h_len = hash_len;
 190        uint8_t *db_mask = NULL;
 191        int db_mask_len = masked_db_len;
 192        uint8_t *db = NULL, *salt = NULL;
 193        int db_len = masked_db_len, salt_len = msg_len - hash_len - 2;
 194        uint8_t pad_zero[8] = { 0 };
 195        int ret, i, leftmost_bits = 1;
 196        uint8_t leftmost_mask;
 197        struct checksum_algo *checksum = info->checksum;
 198
 199        /* first, allocate everything */
 200        masked_db = malloc(masked_db_len);
 201        h = malloc(h_len);
 202        db_mask = malloc(db_mask_len);
 203        db = malloc(db_len);
 204        salt = malloc(salt_len);
 205        hprime = malloc(hash_len);
 206        if (!masked_db || !h || !db_mask || !db || !salt || !hprime) {
 207                printf("%s: can't allocate some buffer\n", __func__);
 208                ret = -ENOMEM;
 209                goto out;
 210        }
 211
 212        /* step 4: check if the last byte is 0xbc */
 213        if (msg[msg_len - 1] != 0xbc) {
 214                printf("%s: invalid pss padding (0xbc is missing)\n", __func__);
 215                ret = -EINVAL;
 216                goto out;
 217        }
 218
 219        /* step 5 */
 220        memcpy(masked_db, msg, masked_db_len);
 221        memcpy(h, msg + masked_db_len, h_len);
 222
 223        /* step 6 */
 224        leftmost_mask = (0xff >> (8 - leftmost_bits)) << (8 - leftmost_bits);
 225        if (masked_db[0] & leftmost_mask) {
 226                printf("%s: invalid pss padding ", __func__);
 227                printf("(leftmost bit of maskedDB not zero)\n");
 228                ret = -EINVAL;
 229                goto out;
 230        }
 231
 232        /* step 7 */
 233        mask_generation_function1(checksum, h, h_len, db_mask, db_mask_len);
 234
 235        /* step 8 */
 236        for (i = 0; i < db_len; i++)
 237                db[i] = masked_db[i] ^ db_mask[i];
 238
 239        /* step 9 */
 240        db[0] &= 0xff >> leftmost_bits;
 241
 242        /* step 10 */
 243        if (db[0] != 0x01) {
 244                printf("%s: invalid pss padding ", __func__);
 245                printf("(leftmost byte of db isn't 0x01)\n");
 246                ret = EINVAL;
 247                goto out;
 248        }
 249
 250        /* step 11 */
 251        memcpy(salt, &db[1], salt_len);
 252
 253        /* step 12 & 13 */
 254        compute_hash_prime(checksum, pad_zero, 8,
 255                           (uint8_t *)hash, hash_len,
 256                           salt, salt_len, hprime);
 257
 258        /* step 14 */
 259        ret = memcmp(h, hprime, hash_len);
 260
 261out:
 262        free(hprime);
 263        free(salt);
 264        free(db);
 265        free(db_mask);
 266        free(h);
 267        free(masked_db);
 268
 269        return ret;
 270}
 271#endif
 272
 273/**
 274 * rsa_verify_key() - Verify a signature against some data using RSA Key
 275 *
 276 * Verify a RSA PKCS1.5 signature against an expected hash using
 277 * the RSA Key properties in prop structure.
 278 *
 279 * @info:       Specifies key and FIT information
 280 * @prop:       Specifies key
 281 * @sig:        Signature
 282 * @sig_len:    Number of bytes in signature
 283 * @hash:       Pointer to the expected hash
 284 * @key_len:    Number of bytes in rsa key
 285 * @return 0 if verified, -ve on error
 286 */
 287static int rsa_verify_key(struct image_sign_info *info,
 288                          struct key_prop *prop, const uint8_t *sig,
 289                          const uint32_t sig_len, const uint8_t *hash,
 290                          const uint32_t key_len)
 291{
 292        int ret;
 293#if !defined(USE_HOSTCC)
 294        struct udevice *mod_exp_dev;
 295#endif
 296        struct checksum_algo *checksum = info->checksum;
 297        struct padding_algo *padding = info->padding;
 298        int hash_len;
 299
 300        if (!prop || !sig || !hash || !checksum)
 301                return -EIO;
 302
 303        if (sig_len != (prop->num_bits / 8)) {
 304                debug("Signature is of incorrect length %d\n", sig_len);
 305                return -EINVAL;
 306        }
 307
 308        debug("Checksum algorithm: %s", checksum->name);
 309
 310        /* Sanity check for stack size */
 311        if (sig_len > RSA_MAX_SIG_BITS / 8) {
 312                debug("Signature length %u exceeds maximum %d\n", sig_len,
 313                      RSA_MAX_SIG_BITS / 8);
 314                return -EINVAL;
 315        }
 316
 317        uint8_t buf[sig_len];
 318        hash_len = checksum->checksum_len;
 319
 320#if !defined(USE_HOSTCC)
 321        ret = uclass_get_device(UCLASS_MOD_EXP, 0, &mod_exp_dev);
 322        if (ret) {
 323                printf("RSA: Can't find Modular Exp implementation\n");
 324                return -EINVAL;
 325        }
 326
 327        ret = rsa_mod_exp(mod_exp_dev, sig, sig_len, prop, buf);
 328#else
 329        ret = rsa_mod_exp_sw(sig, sig_len, prop, buf);
 330#endif
 331        if (ret) {
 332                debug("Error in Modular exponentation\n");
 333                return ret;
 334        }
 335
 336        ret = padding->verify(info, buf, key_len, hash, hash_len);
 337        if (ret) {
 338                debug("In RSAVerify(): padding check failed!\n");
 339                return ret;
 340        }
 341
 342        return 0;
 343}
 344
 345/**
 346 * rsa_verify_with_keynode() - Verify a signature against some data using
 347 * information in node with prperties of RSA Key like modulus, exponent etc.
 348 *
 349 * Parse sign-node and fill a key_prop structure with properties of the
 350 * key.  Verify a RSA PKCS1.5 signature against an expected hash using
 351 * the properties parsed
 352 *
 353 * @info:       Specifies key and FIT information
 354 * @hash:       Pointer to the expected hash
 355 * @sig:        Signature
 356 * @sig_len:    Number of bytes in signature
 357 * @node:       Node having the RSA Key properties
 358 * @return 0 if verified, -ve on error
 359 */
 360static int rsa_verify_with_keynode(struct image_sign_info *info,
 361                                   const void *hash, uint8_t *sig,
 362                                   uint sig_len, int node)
 363{
 364        const void *blob = info->fdt_blob;
 365        struct key_prop prop;
 366        int length;
 367        int ret = 0;
 368
 369        if (node < 0) {
 370                debug("%s: Skipping invalid node", __func__);
 371                return -EBADF;
 372        }
 373
 374        prop.num_bits = fdtdec_get_int(blob, node, "rsa,num-bits", 0);
 375
 376        prop.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0);
 377
 378        prop.public_exponent = fdt_getprop(blob, node, "rsa,exponent", &length);
 379        if (!prop.public_exponent || length < sizeof(uint64_t))
 380                prop.public_exponent = NULL;
 381
 382        prop.exp_len = sizeof(uint64_t);
 383
 384        prop.modulus = fdt_getprop(blob, node, "rsa,modulus", NULL);
 385
 386        prop.rr = fdt_getprop(blob, node, "rsa,r-squared", NULL);
 387
 388        if (!prop.num_bits || !prop.modulus) {
 389                debug("%s: Missing RSA key info", __func__);
 390                return -EFAULT;
 391        }
 392
 393        ret = rsa_verify_key(info, &prop, sig, sig_len, hash,
 394                             info->crypto->key_len);
 395
 396        return ret;
 397}
 398
 399int rsa_verify(struct image_sign_info *info,
 400               const struct image_region region[], int region_count,
 401               uint8_t *sig, uint sig_len)
 402{
 403        const void *blob = info->fdt_blob;
 404        /* Reserve memory for maximum checksum-length */
 405        uint8_t hash[info->crypto->key_len];
 406        int ndepth, noffset;
 407        int sig_node, node;
 408        char name[100];
 409        int ret;
 410
 411        /*
 412         * Verify that the checksum-length does not exceed the
 413         * rsa-signature-length
 414         */
 415        if (info->checksum->checksum_len >
 416            info->crypto->key_len) {
 417                debug("%s: invlaid checksum-algorithm %s for %s\n",
 418                      __func__, info->checksum->name, info->crypto->name);
 419                return -EINVAL;
 420        }
 421
 422        sig_node = fdt_subnode_offset(blob, 0, FIT_SIG_NODENAME);
 423        if (sig_node < 0) {
 424                debug("%s: No signature node found\n", __func__);
 425                return -ENOENT;
 426        }
 427
 428        /* Calculate checksum with checksum-algorithm */
 429        ret = info->checksum->calculate(info->checksum->name,
 430                                        region, region_count, hash);
 431        if (ret < 0) {
 432                debug("%s: Error in checksum calculation\n", __func__);
 433                return -EINVAL;
 434        }
 435
 436        /* See if we must use a particular key */
 437        if (info->required_keynode != -1) {
 438                ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
 439                        info->required_keynode);
 440                if (!ret)
 441                        return ret;
 442        }
 443
 444        /* Look for a key that matches our hint */
 445        snprintf(name, sizeof(name), "key-%s", info->keyname);
 446        node = fdt_subnode_offset(blob, sig_node, name);
 447        ret = rsa_verify_with_keynode(info, hash, sig, sig_len, node);
 448        if (!ret)
 449                return ret;
 450
 451        /* No luck, so try each of the keys in turn */
 452        for (ndepth = 0, noffset = fdt_next_node(info->fit, sig_node, &ndepth);
 453                        (noffset >= 0) && (ndepth > 0);
 454                        noffset = fdt_next_node(info->fit, noffset, &ndepth)) {
 455                if (ndepth == 1 && noffset != node) {
 456                        ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
 457                                                      noffset);
 458                        if (!ret)
 459                                break;
 460                }
 461        }
 462
 463        return ret;
 464}
 465