uboot/lib/rsa/rsa-mod-exp.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2013, Google Inc.
   3 *
   4 * SPDX-License-Identifier:     GPL-2.0+
   5 */
   6
   7#ifndef USE_HOSTCC
   8#include <common.h>
   9#include <fdtdec.h>
  10#include <asm/types.h>
  11#include <asm/byteorder.h>
  12#include <linux/errno.h>
  13#include <asm/types.h>
  14#include <asm/unaligned.h>
  15#else
  16#include "fdt_host.h"
  17#include "mkimage.h"
  18#include <fdt_support.h>
  19#endif
  20#include <u-boot/rsa.h>
  21#include <u-boot/rsa-mod-exp.h>
  22
  23#define UINT64_MULT32(v, multby)  (((uint64_t)(v)) * ((uint32_t)(multby)))
  24
  25#define get_unaligned_be32(a) fdt32_to_cpu(*(uint32_t *)a)
  26#define put_unaligned_be32(a, b) (*(uint32_t *)(b) = cpu_to_fdt32(a))
  27
  28/* Default public exponent for backward compatibility */
  29#define RSA_DEFAULT_PUBEXP      65537
  30
  31/**
  32 * subtract_modulus() - subtract modulus from the given value
  33 *
  34 * @key:        Key containing modulus to subtract
  35 * @num:        Number to subtract modulus from, as little endian word array
  36 */
  37static void subtract_modulus(const struct rsa_public_key *key, uint32_t num[])
  38{
  39        int64_t acc = 0;
  40        uint i;
  41
  42        for (i = 0; i < key->len; i++) {
  43                acc += (uint64_t)num[i] - key->modulus[i];
  44                num[i] = (uint32_t)acc;
  45                acc >>= 32;
  46        }
  47}
  48
  49/**
  50 * greater_equal_modulus() - check if a value is >= modulus
  51 *
  52 * @key:        Key containing modulus to check
  53 * @num:        Number to check against modulus, as little endian word array
  54 * @return 0 if num < modulus, 1 if num >= modulus
  55 */
  56static int greater_equal_modulus(const struct rsa_public_key *key,
  57                                 uint32_t num[])
  58{
  59        int i;
  60
  61        for (i = (int)key->len - 1; i >= 0; i--) {
  62                if (num[i] < key->modulus[i])
  63                        return 0;
  64                if (num[i] > key->modulus[i])
  65                        return 1;
  66        }
  67
  68        return 1;  /* equal */
  69}
  70
  71/**
  72 * montgomery_mul_add_step() - Perform montgomery multiply-add step
  73 *
  74 * Operation: montgomery result[] += a * b[] / n0inv % modulus
  75 *
  76 * @key:        RSA key
  77 * @result:     Place to put result, as little endian word array
  78 * @a:          Multiplier
  79 * @b:          Multiplicand, as little endian word array
  80 */
  81static void montgomery_mul_add_step(const struct rsa_public_key *key,
  82                uint32_t result[], const uint32_t a, const uint32_t b[])
  83{
  84        uint64_t acc_a, acc_b;
  85        uint32_t d0;
  86        uint i;
  87
  88        acc_a = (uint64_t)a * b[0] + result[0];
  89        d0 = (uint32_t)acc_a * key->n0inv;
  90        acc_b = (uint64_t)d0 * key->modulus[0] + (uint32_t)acc_a;
  91        for (i = 1; i < key->len; i++) {
  92                acc_a = (acc_a >> 32) + (uint64_t)a * b[i] + result[i];
  93                acc_b = (acc_b >> 32) + (uint64_t)d0 * key->modulus[i] +
  94                                (uint32_t)acc_a;
  95                result[i - 1] = (uint32_t)acc_b;
  96        }
  97
  98        acc_a = (acc_a >> 32) + (acc_b >> 32);
  99
 100        result[i - 1] = (uint32_t)acc_a;
 101
 102        if (acc_a >> 32)
 103                subtract_modulus(key, result);
 104}
 105
 106/**
 107 * montgomery_mul() - Perform montgomery mutitply
 108 *
 109 * Operation: montgomery result[] = a[] * b[] / n0inv % modulus
 110 *
 111 * @key:        RSA key
 112 * @result:     Place to put result, as little endian word array
 113 * @a:          Multiplier, as little endian word array
 114 * @b:          Multiplicand, as little endian word array
 115 */
 116static void montgomery_mul(const struct rsa_public_key *key,
 117                uint32_t result[], uint32_t a[], const uint32_t b[])
 118{
 119        uint i;
 120
 121        for (i = 0; i < key->len; ++i)
 122                result[i] = 0;
 123        for (i = 0; i < key->len; ++i)
 124                montgomery_mul_add_step(key, result, a[i], b);
 125}
 126
 127/**
 128 * num_pub_exponent_bits() - Number of bits in the public exponent
 129 *
 130 * @key:        RSA key
 131 * @num_bits:   Storage for the number of public exponent bits
 132 */
 133static int num_public_exponent_bits(const struct rsa_public_key *key,
 134                int *num_bits)
 135{
 136        uint64_t exponent;
 137        int exponent_bits;
 138        const uint max_bits = (sizeof(exponent) * 8);
 139
 140        exponent = key->exponent;
 141        exponent_bits = 0;
 142
 143        if (!exponent) {
 144                *num_bits = exponent_bits;
 145                return 0;
 146        }
 147
 148        for (exponent_bits = 1; exponent_bits < max_bits + 1; ++exponent_bits)
 149                if (!(exponent >>= 1)) {
 150                        *num_bits = exponent_bits;
 151                        return 0;
 152                }
 153
 154        return -EINVAL;
 155}
 156
 157/**
 158 * is_public_exponent_bit_set() - Check if a bit in the public exponent is set
 159 *
 160 * @key:        RSA key
 161 * @pos:        The bit position to check
 162 */
 163static int is_public_exponent_bit_set(const struct rsa_public_key *key,
 164                int pos)
 165{
 166        return key->exponent & (1ULL << pos);
 167}
 168
 169/**
 170 * pow_mod() - in-place public exponentiation
 171 *
 172 * @key:        RSA key
 173 * @inout:      Big-endian word array containing value and result
 174 */
 175static int pow_mod(const struct rsa_public_key *key, uint32_t *inout)
 176{
 177        uint32_t *result, *ptr;
 178        uint i;
 179        int j, k;
 180
 181        /* Sanity check for stack size - key->len is in 32-bit words */
 182        if (key->len > RSA_MAX_KEY_BITS / 32) {
 183                debug("RSA key words %u exceeds maximum %d\n", key->len,
 184                      RSA_MAX_KEY_BITS / 32);
 185                return -EINVAL;
 186        }
 187
 188        uint32_t val[key->len], acc[key->len], tmp[key->len];
 189        uint32_t a_scaled[key->len];
 190        result = tmp;  /* Re-use location. */
 191
 192        /* Convert from big endian byte array to little endian word array. */
 193        for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--)
 194                val[i] = get_unaligned_be32(ptr);
 195
 196        if (0 != num_public_exponent_bits(key, &k))
 197                return -EINVAL;
 198
 199        if (k < 2) {
 200                debug("Public exponent is too short (%d bits, minimum 2)\n",
 201                      k);
 202                return -EINVAL;
 203        }
 204
 205        if (!is_public_exponent_bit_set(key, 0)) {
 206                debug("LSB of RSA public exponent must be set.\n");
 207                return -EINVAL;
 208        }
 209
 210        /* the bit at e[k-1] is 1 by definition, so start with: C := M */
 211        montgomery_mul(key, acc, val, key->rr); /* acc = a * RR / R mod n */
 212        /* retain scaled version for intermediate use */
 213        memcpy(a_scaled, acc, key->len * sizeof(a_scaled[0]));
 214
 215        for (j = k - 2; j > 0; --j) {
 216                montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
 217
 218                if (is_public_exponent_bit_set(key, j)) {
 219                        /* acc = tmp * val / R mod n */
 220                        montgomery_mul(key, acc, tmp, a_scaled);
 221                } else {
 222                        /* e[j] == 0, copy tmp back to acc for next operation */
 223                        memcpy(acc, tmp, key->len * sizeof(acc[0]));
 224                }
 225        }
 226
 227        /* the bit at e[0] is always 1 */
 228        montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
 229        montgomery_mul(key, acc, tmp, val); /* acc = tmp * a / R mod M */
 230        memcpy(result, acc, key->len * sizeof(result[0]));
 231
 232        /* Make sure result < mod; result is at most 1x mod too large. */
 233        if (greater_equal_modulus(key, result))
 234                subtract_modulus(key, result);
 235
 236        /* Convert to bigendian byte array */
 237        for (i = key->len - 1, ptr = inout; (int)i >= 0; i--, ptr++)
 238                put_unaligned_be32(result[i], ptr);
 239        return 0;
 240}
 241
 242static void rsa_convert_big_endian(uint32_t *dst, const uint32_t *src, int len)
 243{
 244        int i;
 245
 246        for (i = 0; i < len; i++)
 247                dst[i] = fdt32_to_cpu(src[len - 1 - i]);
 248}
 249
 250int rsa_mod_exp_sw(const uint8_t *sig, uint32_t sig_len,
 251                struct key_prop *prop, uint8_t *out)
 252{
 253        struct rsa_public_key key;
 254        int ret;
 255
 256        if (!prop) {
 257                debug("%s: Skipping invalid prop", __func__);
 258                return -EBADF;
 259        }
 260        key.n0inv = prop->n0inv;
 261        key.len = prop->num_bits;
 262
 263        if (!prop->public_exponent)
 264                key.exponent = RSA_DEFAULT_PUBEXP;
 265        else
 266                key.exponent =
 267                        fdt64_to_cpu(*((uint64_t *)(prop->public_exponent)));
 268
 269        if (!key.len || !prop->modulus || !prop->rr) {
 270                debug("%s: Missing RSA key info", __func__);
 271                return -EFAULT;
 272        }
 273
 274        /* Sanity check for stack size */
 275        if (key.len > RSA_MAX_KEY_BITS || key.len < RSA_MIN_KEY_BITS) {
 276                debug("RSA key bits %u outside allowed range %d..%d\n",
 277                      key.len, RSA_MIN_KEY_BITS, RSA_MAX_KEY_BITS);
 278                return -EFAULT;
 279        }
 280        key.len /= sizeof(uint32_t) * 8;
 281        uint32_t key1[key.len], key2[key.len];
 282
 283        key.modulus = key1;
 284        key.rr = key2;
 285        rsa_convert_big_endian(key.modulus, (uint32_t *)prop->modulus, key.len);
 286        rsa_convert_big_endian(key.rr, (uint32_t *)prop->rr, key.len);
 287        if (!key.modulus || !key.rr) {
 288                debug("%s: Out of memory", __func__);
 289                return -ENOMEM;
 290        }
 291
 292        uint32_t buf[sig_len / sizeof(uint32_t)];
 293
 294        memcpy(buf, sig, sig_len);
 295
 296        ret = pow_mod(&key, buf);
 297        if (ret)
 298                return ret;
 299
 300        memcpy(out, buf, sig_len);
 301
 302        return 0;
 303}
 304