linux/crypto/ecc.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2013, Kenneth MacKay
   3 * All rights reserved.
   4 *
   5 * Redistribution and use in source and binary forms, with or without
   6 * modification, are permitted provided that the following conditions are
   7 * met:
   8 *  * Redistributions of source code must retain the above copyright
   9 *   notice, this list of conditions and the following disclaimer.
  10 *  * Redistributions in binary form must reproduce the above copyright
  11 *    notice, this list of conditions and the following disclaimer in the
  12 *    documentation and/or other materials provided with the distribution.
  13 *
  14 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  15 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  16 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  17 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  18 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  19 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  20 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25 */
  26
  27#include <linux/random.h>
  28#include <linux/slab.h>
  29#include <linux/swab.h>
  30#include <linux/fips.h>
  31#include <crypto/ecdh.h>
  32#include <crypto/rng.h>
  33
  34#include "ecc.h"
  35#include "ecc_curve_defs.h"
  36
  37typedef struct {
  38        u64 m_low;
  39        u64 m_high;
  40} uint128_t;
  41
  42static inline const struct ecc_curve *ecc_get_curve(unsigned int curve_id)
  43{
  44        switch (curve_id) {
  45        /* In FIPS mode only allow P256 and higher */
  46        case ECC_CURVE_NIST_P192:
  47                return fips_enabled ? NULL : &nist_p192;
  48        case ECC_CURVE_NIST_P256:
  49                return &nist_p256;
  50        default:
  51                return NULL;
  52        }
  53}
  54
  55static u64 *ecc_alloc_digits_space(unsigned int ndigits)
  56{
  57        size_t len = ndigits * sizeof(u64);
  58
  59        if (!len)
  60                return NULL;
  61
  62        return kmalloc(len, GFP_KERNEL);
  63}
  64
  65static void ecc_free_digits_space(u64 *space)
  66{
  67        kzfree(space);
  68}
  69
  70static struct ecc_point *ecc_alloc_point(unsigned int ndigits)
  71{
  72        struct ecc_point *p = kmalloc(sizeof(*p), GFP_KERNEL);
  73
  74        if (!p)
  75                return NULL;
  76
  77        p->x = ecc_alloc_digits_space(ndigits);
  78        if (!p->x)
  79                goto err_alloc_x;
  80
  81        p->y = ecc_alloc_digits_space(ndigits);
  82        if (!p->y)
  83                goto err_alloc_y;
  84
  85        p->ndigits = ndigits;
  86
  87        return p;
  88
  89err_alloc_y:
  90        ecc_free_digits_space(p->x);
  91err_alloc_x:
  92        kfree(p);
  93        return NULL;
  94}
  95
  96static void ecc_free_point(struct ecc_point *p)
  97{
  98        if (!p)
  99                return;
 100
 101        kzfree(p->x);
 102        kzfree(p->y);
 103        kzfree(p);
 104}
 105
 106static void vli_clear(u64 *vli, unsigned int ndigits)
 107{
 108        int i;
 109
 110        for (i = 0; i < ndigits; i++)
 111                vli[i] = 0;
 112}
 113
 114/* Returns true if vli == 0, false otherwise. */
 115static bool vli_is_zero(const u64 *vli, unsigned int ndigits)
 116{
 117        int i;
 118
 119        for (i = 0; i < ndigits; i++) {
 120                if (vli[i])
 121                        return false;
 122        }
 123
 124        return true;
 125}
 126
 127/* Returns nonzero if bit bit of vli is set. */
 128static u64 vli_test_bit(const u64 *vli, unsigned int bit)
 129{
 130        return (vli[bit / 64] & ((u64)1 << (bit % 64)));
 131}
 132
 133/* Counts the number of 64-bit "digits" in vli. */
 134static unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits)
 135{
 136        int i;
 137
 138        /* Search from the end until we find a non-zero digit.
 139         * We do it in reverse because we expect that most digits will
 140         * be nonzero.
 141         */
 142        for (i = ndigits - 1; i >= 0 && vli[i] == 0; i--);
 143
 144        return (i + 1);
 145}
 146
 147/* Counts the number of bits required for vli. */
 148static unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits)
 149{
 150        unsigned int i, num_digits;
 151        u64 digit;
 152
 153        num_digits = vli_num_digits(vli, ndigits);
 154        if (num_digits == 0)
 155                return 0;
 156
 157        digit = vli[num_digits - 1];
 158        for (i = 0; digit; i++)
 159                digit >>= 1;
 160
 161        return ((num_digits - 1) * 64 + i);
 162}
 163
 164/* Sets dest = src. */
 165static void vli_set(u64 *dest, const u64 *src, unsigned int ndigits)
 166{
 167        int i;
 168
 169        for (i = 0; i < ndigits; i++)
 170                dest[i] = src[i];
 171}
 172
 173/* Returns sign of left - right. */
 174static int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits)
 175{
 176        int i;
 177
 178        for (i = ndigits - 1; i >= 0; i--) {
 179                if (left[i] > right[i])
 180                        return 1;
 181                else if (left[i] < right[i])
 182                        return -1;
 183        }
 184
 185        return 0;
 186}
 187
 188/* Computes result = in << c, returning carry. Can modify in place
 189 * (if result == in). 0 < shift < 64.
 190 */
 191static u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift,
 192                      unsigned int ndigits)
 193{
 194        u64 carry = 0;
 195        int i;
 196
 197        for (i = 0; i < ndigits; i++) {
 198                u64 temp = in[i];
 199
 200                result[i] = (temp << shift) | carry;
 201                carry = temp >> (64 - shift);
 202        }
 203
 204        return carry;
 205}
 206
 207/* Computes vli = vli >> 1. */
 208static void vli_rshift1(u64 *vli, unsigned int ndigits)
 209{
 210        u64 *end = vli;
 211        u64 carry = 0;
 212
 213        vli += ndigits;
 214
 215        while (vli-- > end) {
 216                u64 temp = *vli;
 217                *vli = (temp >> 1) | carry;
 218                carry = temp << 63;
 219        }
 220}
 221
 222/* Computes result = left + right, returning carry. Can modify in place. */
 223static u64 vli_add(u64 *result, const u64 *left, const u64 *right,
 224                   unsigned int ndigits)
 225{
 226        u64 carry = 0;
 227        int i;
 228
 229        for (i = 0; i < ndigits; i++) {
 230                u64 sum;
 231
 232                sum = left[i] + right[i] + carry;
 233                if (sum != left[i])
 234                        carry = (sum < left[i]);
 235
 236                result[i] = sum;
 237        }
 238
 239        return carry;
 240}
 241
 242/* Computes result = left - right, returning borrow. Can modify in place. */
 243static u64 vli_sub(u64 *result, const u64 *left, const u64 *right,
 244                   unsigned int ndigits)
 245{
 246        u64 borrow = 0;
 247        int i;
 248
 249        for (i = 0; i < ndigits; i++) {
 250                u64 diff;
 251
 252                diff = left[i] - right[i] - borrow;
 253                if (diff != left[i])
 254                        borrow = (diff > left[i]);
 255
 256                result[i] = diff;
 257        }
 258
 259        return borrow;
 260}
 261
 262static uint128_t mul_64_64(u64 left, u64 right)
 263{
 264        u64 a0 = left & 0xffffffffull;
 265        u64 a1 = left >> 32;
 266        u64 b0 = right & 0xffffffffull;
 267        u64 b1 = right >> 32;
 268        u64 m0 = a0 * b0;
 269        u64 m1 = a0 * b1;
 270        u64 m2 = a1 * b0;
 271        u64 m3 = a1 * b1;
 272        uint128_t result;
 273
 274        m2 += (m0 >> 32);
 275        m2 += m1;
 276
 277        /* Overflow */
 278        if (m2 < m1)
 279                m3 += 0x100000000ull;
 280
 281        result.m_low = (m0 & 0xffffffffull) | (m2 << 32);
 282        result.m_high = m3 + (m2 >> 32);
 283
 284        return result;
 285}
 286
 287static uint128_t add_128_128(uint128_t a, uint128_t b)
 288{
 289        uint128_t result;
 290
 291        result.m_low = a.m_low + b.m_low;
 292        result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low);
 293
 294        return result;
 295}
 296
 297static void vli_mult(u64 *result, const u64 *left, const u64 *right,
 298                     unsigned int ndigits)
 299{
 300        uint128_t r01 = { 0, 0 };
 301        u64 r2 = 0;
 302        unsigned int i, k;
 303
 304        /* Compute each digit of result in sequence, maintaining the
 305         * carries.
 306         */
 307        for (k = 0; k < ndigits * 2 - 1; k++) {
 308                unsigned int min;
 309
 310                if (k < ndigits)
 311                        min = 0;
 312                else
 313                        min = (k + 1) - ndigits;
 314
 315                for (i = min; i <= k && i < ndigits; i++) {
 316                        uint128_t product;
 317
 318                        product = mul_64_64(left[i], right[k - i]);
 319
 320                        r01 = add_128_128(r01, product);
 321                        r2 += (r01.m_high < product.m_high);
 322                }
 323
 324                result[k] = r01.m_low;
 325                r01.m_low = r01.m_high;
 326                r01.m_high = r2;
 327                r2 = 0;
 328        }
 329
 330        result[ndigits * 2 - 1] = r01.m_low;
 331}
 332
 333static void vli_square(u64 *result, const u64 *left, unsigned int ndigits)
 334{
 335        uint128_t r01 = { 0, 0 };
 336        u64 r2 = 0;
 337        int i, k;
 338
 339        for (k = 0; k < ndigits * 2 - 1; k++) {
 340                unsigned int min;
 341
 342                if (k < ndigits)
 343                        min = 0;
 344                else
 345                        min = (k + 1) - ndigits;
 346
 347                for (i = min; i <= k && i <= k - i; i++) {
 348                        uint128_t product;
 349
 350                        product = mul_64_64(left[i], left[k - i]);
 351
 352                        if (i < k - i) {
 353                                r2 += product.m_high >> 63;
 354                                product.m_high = (product.m_high << 1) |
 355                                                 (product.m_low >> 63);
 356                                product.m_low <<= 1;
 357                        }
 358
 359                        r01 = add_128_128(r01, product);
 360                        r2 += (r01.m_high < product.m_high);
 361                }
 362
 363                result[k] = r01.m_low;
 364                r01.m_low = r01.m_high;
 365                r01.m_high = r2;
 366                r2 = 0;
 367        }
 368
 369        result[ndigits * 2 - 1] = r01.m_low;
 370}
 371
 372/* Computes result = (left + right) % mod.
 373 * Assumes that left < mod and right < mod, result != mod.
 374 */
 375static void vli_mod_add(u64 *result, const u64 *left, const u64 *right,
 376                        const u64 *mod, unsigned int ndigits)
 377{
 378        u64 carry;
 379
 380        carry = vli_add(result, left, right, ndigits);
 381
 382        /* result > mod (result = mod + remainder), so subtract mod to
 383         * get remainder.
 384         */
 385        if (carry || vli_cmp(result, mod, ndigits) >= 0)
 386                vli_sub(result, result, mod, ndigits);
 387}
 388
 389/* Computes result = (left - right) % mod.
 390 * Assumes that left < mod and right < mod, result != mod.
 391 */
 392static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right,
 393                        const u64 *mod, unsigned int ndigits)
 394{
 395        u64 borrow = vli_sub(result, left, right, ndigits);
 396
 397        /* In this case, p_result == -diff == (max int) - diff.
 398         * Since -x % d == d - x, we can get the correct result from
 399         * result + mod (with overflow).
 400         */
 401        if (borrow)
 402                vli_add(result, result, mod, ndigits);
 403}
 404
 405/* Computes p_result = p_product % curve_p.
 406 * See algorithm 5 and 6 from
 407 * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf
 408 */
 409static void vli_mmod_fast_192(u64 *result, const u64 *product,
 410                              const u64 *curve_prime, u64 *tmp)
 411{
 412        const unsigned int ndigits = 3;
 413        int carry;
 414
 415        vli_set(result, product, ndigits);
 416
 417        vli_set(tmp, &product[3], ndigits);
 418        carry = vli_add(result, result, tmp, ndigits);
 419
 420        tmp[0] = 0;
 421        tmp[1] = product[3];
 422        tmp[2] = product[4];
 423        carry += vli_add(result, result, tmp, ndigits);
 424
 425        tmp[0] = tmp[1] = product[5];
 426        tmp[2] = 0;
 427        carry += vli_add(result, result, tmp, ndigits);
 428
 429        while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
 430                carry -= vli_sub(result, result, curve_prime, ndigits);
 431}
 432
 433/* Computes result = product % curve_prime
 434 * from http://www.nsa.gov/ia/_files/nist-routines.pdf
 435 */
 436static void vli_mmod_fast_256(u64 *result, const u64 *product,
 437                              const u64 *curve_prime, u64 *tmp)
 438{
 439        int carry;
 440        const unsigned int ndigits = 4;
 441
 442        /* t */
 443        vli_set(result, product, ndigits);
 444
 445        /* s1 */
 446        tmp[0] = 0;
 447        tmp[1] = product[5] & 0xffffffff00000000ull;
 448        tmp[2] = product[6];
 449        tmp[3] = product[7];
 450        carry = vli_lshift(tmp, tmp, 1, ndigits);
 451        carry += vli_add(result, result, tmp, ndigits);
 452
 453        /* s2 */
 454        tmp[1] = product[6] << 32;
 455        tmp[2] = (product[6] >> 32) | (product[7] << 32);
 456        tmp[3] = product[7] >> 32;
 457        carry += vli_lshift(tmp, tmp, 1, ndigits);
 458        carry += vli_add(result, result, tmp, ndigits);
 459
 460        /* s3 */
 461        tmp[0] = product[4];
 462        tmp[1] = product[5] & 0xffffffff;
 463        tmp[2] = 0;
 464        tmp[3] = product[7];
 465        carry += vli_add(result, result, tmp, ndigits);
 466
 467        /* s4 */
 468        tmp[0] = (product[4] >> 32) | (product[5] << 32);
 469        tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull);
 470        tmp[2] = product[7];
 471        tmp[3] = (product[6] >> 32) | (product[4] << 32);
 472        carry += vli_add(result, result, tmp, ndigits);
 473
 474        /* d1 */
 475        tmp[0] = (product[5] >> 32) | (product[6] << 32);
 476        tmp[1] = (product[6] >> 32);
 477        tmp[2] = 0;
 478        tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32);
 479        carry -= vli_sub(result, result, tmp, ndigits);
 480
 481        /* d2 */
 482        tmp[0] = product[6];
 483        tmp[1] = product[7];
 484        tmp[2] = 0;
 485        tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull);
 486        carry -= vli_sub(result, result, tmp, ndigits);
 487
 488        /* d3 */
 489        tmp[0] = (product[6] >> 32) | (product[7] << 32);
 490        tmp[1] = (product[7] >> 32) | (product[4] << 32);
 491        tmp[2] = (product[4] >> 32) | (product[5] << 32);
 492        tmp[3] = (product[6] << 32);
 493        carry -= vli_sub(result, result, tmp, ndigits);
 494
 495        /* d4 */
 496        tmp[0] = product[7];
 497        tmp[1] = product[4] & 0xffffffff00000000ull;
 498        tmp[2] = product[5];
 499        tmp[3] = product[6] & 0xffffffff00000000ull;
 500        carry -= vli_sub(result, result, tmp, ndigits);
 501
 502        if (carry < 0) {
 503                do {
 504                        carry += vli_add(result, result, curve_prime, ndigits);
 505                } while (carry < 0);
 506        } else {
 507                while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
 508                        carry -= vli_sub(result, result, curve_prime, ndigits);
 509        }
 510}
 511
 512/* Computes result = product % curve_prime
 513 *  from http://www.nsa.gov/ia/_files/nist-routines.pdf
 514*/
 515static bool vli_mmod_fast(u64 *result, u64 *product,
 516                          const u64 *curve_prime, unsigned int ndigits)
 517{
 518        u64 tmp[2 * ECC_MAX_DIGITS];
 519
 520        switch (ndigits) {
 521        case 3:
 522                vli_mmod_fast_192(result, product, curve_prime, tmp);
 523                break;
 524        case 4:
 525                vli_mmod_fast_256(result, product, curve_prime, tmp);
 526                break;
 527        default:
 528                pr_err("unsupports digits size!\n");
 529                return false;
 530        }
 531
 532        return true;
 533}
 534
 535/* Computes result = (left * right) % curve_prime. */
 536static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right,
 537                              const u64 *curve_prime, unsigned int ndigits)
 538{
 539        u64 product[2 * ECC_MAX_DIGITS];
 540
 541        vli_mult(product, left, right, ndigits);
 542        vli_mmod_fast(result, product, curve_prime, ndigits);
 543}
 544
 545/* Computes result = left^2 % curve_prime. */
 546static void vli_mod_square_fast(u64 *result, const u64 *left,
 547                                const u64 *curve_prime, unsigned int ndigits)
 548{
 549        u64 product[2 * ECC_MAX_DIGITS];
 550
 551        vli_square(product, left, ndigits);
 552        vli_mmod_fast(result, product, curve_prime, ndigits);
 553}
 554
 555#define EVEN(vli) (!(vli[0] & 1))
 556/* Computes result = (1 / p_input) % mod. All VLIs are the same size.
 557 * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide"
 558 * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf
 559 */
 560static void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod,
 561                        unsigned int ndigits)
 562{
 563        u64 a[ECC_MAX_DIGITS], b[ECC_MAX_DIGITS];
 564        u64 u[ECC_MAX_DIGITS], v[ECC_MAX_DIGITS];
 565        u64 carry;
 566        int cmp_result;
 567
 568        if (vli_is_zero(input, ndigits)) {
 569                vli_clear(result, ndigits);
 570                return;
 571        }
 572
 573        vli_set(a, input, ndigits);
 574        vli_set(b, mod, ndigits);
 575        vli_clear(u, ndigits);
 576        u[0] = 1;
 577        vli_clear(v, ndigits);
 578
 579        while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) {
 580                carry = 0;
 581
 582                if (EVEN(a)) {
 583                        vli_rshift1(a, ndigits);
 584
 585                        if (!EVEN(u))
 586                                carry = vli_add(u, u, mod, ndigits);
 587
 588                        vli_rshift1(u, ndigits);
 589                        if (carry)
 590                                u[ndigits - 1] |= 0x8000000000000000ull;
 591                } else if (EVEN(b)) {
 592                        vli_rshift1(b, ndigits);
 593
 594                        if (!EVEN(v))
 595                                carry = vli_add(v, v, mod, ndigits);
 596
 597                        vli_rshift1(v, ndigits);
 598                        if (carry)
 599                                v[ndigits - 1] |= 0x8000000000000000ull;
 600                } else if (cmp_result > 0) {
 601                        vli_sub(a, a, b, ndigits);
 602                        vli_rshift1(a, ndigits);
 603
 604                        if (vli_cmp(u, v, ndigits) < 0)
 605                                vli_add(u, u, mod, ndigits);
 606
 607                        vli_sub(u, u, v, ndigits);
 608                        if (!EVEN(u))
 609                                carry = vli_add(u, u, mod, ndigits);
 610
 611                        vli_rshift1(u, ndigits);
 612                        if (carry)
 613                                u[ndigits - 1] |= 0x8000000000000000ull;
 614                } else {
 615                        vli_sub(b, b, a, ndigits);
 616                        vli_rshift1(b, ndigits);
 617
 618                        if (vli_cmp(v, u, ndigits) < 0)
 619                                vli_add(v, v, mod, ndigits);
 620
 621                        vli_sub(v, v, u, ndigits);
 622                        if (!EVEN(v))
 623                                carry = vli_add(v, v, mod, ndigits);
 624
 625                        vli_rshift1(v, ndigits);
 626                        if (carry)
 627                                v[ndigits - 1] |= 0x8000000000000000ull;
 628                }
 629        }
 630
 631        vli_set(result, u, ndigits);
 632}
 633
 634/* ------ Point operations ------ */
 635
 636/* Returns true if p_point is the point at infinity, false otherwise. */
 637static bool ecc_point_is_zero(const struct ecc_point *point)
 638{
 639        return (vli_is_zero(point->x, point->ndigits) &&
 640                vli_is_zero(point->y, point->ndigits));
 641}
 642
 643/* Point multiplication algorithm using Montgomery's ladder with co-Z
 644 * coordinates. From http://eprint.iacr.org/2011/338.pdf
 645 */
 646
 647/* Double in place */
 648static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1,
 649                                      u64 *curve_prime, unsigned int ndigits)
 650{
 651        /* t1 = x, t2 = y, t3 = z */
 652        u64 t4[ECC_MAX_DIGITS];
 653        u64 t5[ECC_MAX_DIGITS];
 654
 655        if (vli_is_zero(z1, ndigits))
 656                return;
 657
 658        /* t4 = y1^2 */
 659        vli_mod_square_fast(t4, y1, curve_prime, ndigits);
 660        /* t5 = x1*y1^2 = A */
 661        vli_mod_mult_fast(t5, x1, t4, curve_prime, ndigits);
 662        /* t4 = y1^4 */
 663        vli_mod_square_fast(t4, t4, curve_prime, ndigits);
 664        /* t2 = y1*z1 = z3 */
 665        vli_mod_mult_fast(y1, y1, z1, curve_prime, ndigits);
 666        /* t3 = z1^2 */
 667        vli_mod_square_fast(z1, z1, curve_prime, ndigits);
 668
 669        /* t1 = x1 + z1^2 */
 670        vli_mod_add(x1, x1, z1, curve_prime, ndigits);
 671        /* t3 = 2*z1^2 */
 672        vli_mod_add(z1, z1, z1, curve_prime, ndigits);
 673        /* t3 = x1 - z1^2 */
 674        vli_mod_sub(z1, x1, z1, curve_prime, ndigits);
 675        /* t1 = x1^2 - z1^4 */
 676        vli_mod_mult_fast(x1, x1, z1, curve_prime, ndigits);
 677
 678        /* t3 = 2*(x1^2 - z1^4) */
 679        vli_mod_add(z1, x1, x1, curve_prime, ndigits);
 680        /* t1 = 3*(x1^2 - z1^4) */
 681        vli_mod_add(x1, x1, z1, curve_prime, ndigits);
 682        if (vli_test_bit(x1, 0)) {
 683                u64 carry = vli_add(x1, x1, curve_prime, ndigits);
 684
 685                vli_rshift1(x1, ndigits);
 686                x1[ndigits - 1] |= carry << 63;
 687        } else {
 688                vli_rshift1(x1, ndigits);
 689        }
 690        /* t1 = 3/2*(x1^2 - z1^4) = B */
 691
 692        /* t3 = B^2 */
 693        vli_mod_square_fast(z1, x1, curve_prime, ndigits);
 694        /* t3 = B^2 - A */
 695        vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
 696        /* t3 = B^2 - 2A = x3 */
 697        vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
 698        /* t5 = A - x3 */
 699        vli_mod_sub(t5, t5, z1, curve_prime, ndigits);
 700        /* t1 = B * (A - x3) */
 701        vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
 702        /* t4 = B * (A - x3) - y1^4 = y3 */
 703        vli_mod_sub(t4, x1, t4, curve_prime, ndigits);
 704
 705        vli_set(x1, z1, ndigits);
 706        vli_set(z1, y1, ndigits);
 707        vli_set(y1, t4, ndigits);
 708}
 709
 710/* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */
 711static void apply_z(u64 *x1, u64 *y1, u64 *z, u64 *curve_prime,
 712                    unsigned int ndigits)
 713{
 714        u64 t1[ECC_MAX_DIGITS];
 715
 716        vli_mod_square_fast(t1, z, curve_prime, ndigits);    /* z^2 */
 717        vli_mod_mult_fast(x1, x1, t1, curve_prime, ndigits); /* x1 * z^2 */
 718        vli_mod_mult_fast(t1, t1, z, curve_prime, ndigits);  /* z^3 */
 719        vli_mod_mult_fast(y1, y1, t1, curve_prime, ndigits); /* y1 * z^3 */
 720}
 721
 722/* P = (x1, y1) => 2P, (x2, y2) => P' */
 723static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
 724                                u64 *p_initial_z, u64 *curve_prime,
 725                                unsigned int ndigits)
 726{
 727        u64 z[ECC_MAX_DIGITS];
 728
 729        vli_set(x2, x1, ndigits);
 730        vli_set(y2, y1, ndigits);
 731
 732        vli_clear(z, ndigits);
 733        z[0] = 1;
 734
 735        if (p_initial_z)
 736                vli_set(z, p_initial_z, ndigits);
 737
 738        apply_z(x1, y1, z, curve_prime, ndigits);
 739
 740        ecc_point_double_jacobian(x1, y1, z, curve_prime, ndigits);
 741
 742        apply_z(x2, y2, z, curve_prime, ndigits);
 743}
 744
 745/* Input P = (x1, y1, Z), Q = (x2, y2, Z)
 746 * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3)
 747 * or P => P', Q => P + Q
 748 */
 749static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime,
 750                     unsigned int ndigits)
 751{
 752        /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
 753        u64 t5[ECC_MAX_DIGITS];
 754
 755        /* t5 = x2 - x1 */
 756        vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
 757        /* t5 = (x2 - x1)^2 = A */
 758        vli_mod_square_fast(t5, t5, curve_prime, ndigits);
 759        /* t1 = x1*A = B */
 760        vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
 761        /* t3 = x2*A = C */
 762        vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits);
 763        /* t4 = y2 - y1 */
 764        vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
 765        /* t5 = (y2 - y1)^2 = D */
 766        vli_mod_square_fast(t5, y2, curve_prime, ndigits);
 767
 768        /* t5 = D - B */
 769        vli_mod_sub(t5, t5, x1, curve_prime, ndigits);
 770        /* t5 = D - B - C = x3 */
 771        vli_mod_sub(t5, t5, x2, curve_prime, ndigits);
 772        /* t3 = C - B */
 773        vli_mod_sub(x2, x2, x1, curve_prime, ndigits);
 774        /* t2 = y1*(C - B) */
 775        vli_mod_mult_fast(y1, y1, x2, curve_prime, ndigits);
 776        /* t3 = B - x3 */
 777        vli_mod_sub(x2, x1, t5, curve_prime, ndigits);
 778        /* t4 = (y2 - y1)*(B - x3) */
 779        vli_mod_mult_fast(y2, y2, x2, curve_prime, ndigits);
 780        /* t4 = y3 */
 781        vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
 782
 783        vli_set(x2, t5, ndigits);
 784}
 785
 786/* Input P = (x1, y1, Z), Q = (x2, y2, Z)
 787 * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3)
 788 * or P => P - Q, Q => P + Q
 789 */
 790static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime,
 791                       unsigned int ndigits)
 792{
 793        /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
 794        u64 t5[ECC_MAX_DIGITS];
 795        u64 t6[ECC_MAX_DIGITS];
 796        u64 t7[ECC_MAX_DIGITS];
 797
 798        /* t5 = x2 - x1 */
 799        vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
 800        /* t5 = (x2 - x1)^2 = A */
 801        vli_mod_square_fast(t5, t5, curve_prime, ndigits);
 802        /* t1 = x1*A = B */
 803        vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
 804        /* t3 = x2*A = C */
 805        vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits);
 806        /* t4 = y2 + y1 */
 807        vli_mod_add(t5, y2, y1, curve_prime, ndigits);
 808        /* t4 = y2 - y1 */
 809        vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
 810
 811        /* t6 = C - B */
 812        vli_mod_sub(t6, x2, x1, curve_prime, ndigits);
 813        /* t2 = y1 * (C - B) */
 814        vli_mod_mult_fast(y1, y1, t6, curve_prime, ndigits);
 815        /* t6 = B + C */
 816        vli_mod_add(t6, x1, x2, curve_prime, ndigits);
 817        /* t3 = (y2 - y1)^2 */
 818        vli_mod_square_fast(x2, y2, curve_prime, ndigits);
 819        /* t3 = x3 */
 820        vli_mod_sub(x2, x2, t6, curve_prime, ndigits);
 821
 822        /* t7 = B - x3 */
 823        vli_mod_sub(t7, x1, x2, curve_prime, ndigits);
 824        /* t4 = (y2 - y1)*(B - x3) */
 825        vli_mod_mult_fast(y2, y2, t7, curve_prime, ndigits);
 826        /* t4 = y3 */
 827        vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
 828
 829        /* t7 = (y2 + y1)^2 = F */
 830        vli_mod_square_fast(t7, t5, curve_prime, ndigits);
 831        /* t7 = x3' */
 832        vli_mod_sub(t7, t7, t6, curve_prime, ndigits);
 833        /* t6 = x3' - B */
 834        vli_mod_sub(t6, t7, x1, curve_prime, ndigits);
 835        /* t6 = (y2 + y1)*(x3' - B) */
 836        vli_mod_mult_fast(t6, t6, t5, curve_prime, ndigits);
 837        /* t2 = y3' */
 838        vli_mod_sub(y1, t6, y1, curve_prime, ndigits);
 839
 840        vli_set(x1, t7, ndigits);
 841}
 842
 843static void ecc_point_mult(struct ecc_point *result,
 844                           const struct ecc_point *point, const u64 *scalar,
 845                           u64 *initial_z, u64 *curve_prime,
 846                           unsigned int ndigits)
 847{
 848        /* R0 and R1 */
 849        u64 rx[2][ECC_MAX_DIGITS];
 850        u64 ry[2][ECC_MAX_DIGITS];
 851        u64 z[ECC_MAX_DIGITS];
 852        int i, nb;
 853        int num_bits = vli_num_bits(scalar, ndigits);
 854
 855        vli_set(rx[1], point->x, ndigits);
 856        vli_set(ry[1], point->y, ndigits);
 857
 858        xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve_prime,
 859                            ndigits);
 860
 861        for (i = num_bits - 2; i > 0; i--) {
 862                nb = !vli_test_bit(scalar, i);
 863                xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime,
 864                           ndigits);
 865                xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime,
 866                         ndigits);
 867        }
 868
 869        nb = !vli_test_bit(scalar, 0);
 870        xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime,
 871                   ndigits);
 872
 873        /* Find final 1/Z value. */
 874        /* X1 - X0 */
 875        vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits);
 876        /* Yb * (X1 - X0) */
 877        vli_mod_mult_fast(z, z, ry[1 - nb], curve_prime, ndigits);
 878        /* xP * Yb * (X1 - X0) */
 879        vli_mod_mult_fast(z, z, point->x, curve_prime, ndigits);
 880
 881        /* 1 / (xP * Yb * (X1 - X0)) */
 882        vli_mod_inv(z, z, curve_prime, point->ndigits);
 883
 884        /* yP / (xP * Yb * (X1 - X0)) */
 885        vli_mod_mult_fast(z, z, point->y, curve_prime, ndigits);
 886        /* Xb * yP / (xP * Yb * (X1 - X0)) */
 887        vli_mod_mult_fast(z, z, rx[1 - nb], curve_prime, ndigits);
 888        /* End 1/Z calculation */
 889
 890        xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime, ndigits);
 891
 892        apply_z(rx[0], ry[0], z, curve_prime, ndigits);
 893
 894        vli_set(result->x, rx[0], ndigits);
 895        vli_set(result->y, ry[0], ndigits);
 896}
 897
 898static inline void ecc_swap_digits(const u64 *in, u64 *out,
 899                                   unsigned int ndigits)
 900{
 901        int i;
 902
 903        for (i = 0; i < ndigits; i++)
 904                out[i] = __swab64(in[ndigits - 1 - i]);
 905}
 906
 907int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits,
 908                     const u64 *private_key, unsigned int private_key_len)
 909{
 910        int nbytes;
 911        const struct ecc_curve *curve = ecc_get_curve(curve_id);
 912
 913        if (!private_key)
 914                return -EINVAL;
 915
 916        nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
 917
 918        if (private_key_len != nbytes)
 919                return -EINVAL;
 920
 921        if (vli_is_zero(private_key, ndigits))
 922                return -EINVAL;
 923
 924        /* Make sure the private key is in the range [1, n-1]. */
 925        if (vli_cmp(curve->n, private_key, ndigits) != 1)
 926                return -EINVAL;
 927
 928        return 0;
 929}
 930
 931/*
 932 * ECC private keys are generated using the method of extra random bits,
 933 * equivalent to that described in FIPS 186-4, Appendix B.4.1.
 934 *
 935 * d = (c mod(n–1)) + 1    where c is a string of random bits, 64 bits longer
 936 *                         than requested
 937 * 0 <= c mod(n-1) <= n-2  and implies that
 938 * 1 <= d <= n-1
 939 *
 940 * This method generates a private key uniformly distributed in the range
 941 * [1, n-1].
 942 */
 943int ecc_gen_privkey(unsigned int curve_id, unsigned int ndigits, u64 *privkey)
 944{
 945        const struct ecc_curve *curve = ecc_get_curve(curve_id);
 946        u64 priv[ECC_MAX_DIGITS];
 947        unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
 948        unsigned int nbits = vli_num_bits(curve->n, ndigits);
 949        int err;
 950
 951        /* Check that N is included in Table 1 of FIPS 186-4, section 6.1.1 */
 952        if (nbits < 160 || ndigits > ARRAY_SIZE(priv))
 953                return -EINVAL;
 954
 955        /*
 956         * FIPS 186-4 recommends that the private key should be obtained from a
 957         * RBG with a security strength equal to or greater than the security
 958         * strength associated with N.
 959         *
 960         * The maximum security strength identified by NIST SP800-57pt1r4 for
 961         * ECC is 256 (N >= 512).
 962         *
 963         * This condition is met by the default RNG because it selects a favored
 964         * DRBG with a security strength of 256.
 965         */
 966        if (crypto_get_default_rng())
 967                return -EFAULT;
 968
 969        err = crypto_rng_get_bytes(crypto_default_rng, (u8 *)priv, nbytes);
 970        crypto_put_default_rng();
 971        if (err)
 972                return err;
 973
 974        if (vli_is_zero(priv, ndigits))
 975                return -EINVAL;
 976
 977        /* Make sure the private key is in the range [1, n-1]. */
 978        if (vli_cmp(curve->n, priv, ndigits) != 1)
 979                return -EINVAL;
 980
 981        ecc_swap_digits(priv, privkey, ndigits);
 982
 983        return 0;
 984}
 985
 986int ecc_make_pub_key(unsigned int curve_id, unsigned int ndigits,
 987                     const u64 *private_key, u64 *public_key)
 988{
 989        int ret = 0;
 990        struct ecc_point *pk;
 991        u64 priv[ECC_MAX_DIGITS];
 992        const struct ecc_curve *curve = ecc_get_curve(curve_id);
 993
 994        if (!private_key || !curve || ndigits > ARRAY_SIZE(priv)) {
 995                ret = -EINVAL;
 996                goto out;
 997        }
 998
 999        ecc_swap_digits(private_key, priv, ndigits);
1000
1001        pk = ecc_alloc_point(ndigits);
1002        if (!pk) {
1003                ret = -ENOMEM;
1004                goto out;
1005        }
1006
1007        ecc_point_mult(pk, &curve->g, priv, NULL, curve->p, ndigits);
1008        if (ecc_point_is_zero(pk)) {
1009                ret = -EAGAIN;
1010                goto err_free_point;
1011        }
1012
1013        ecc_swap_digits(pk->x, public_key, ndigits);
1014        ecc_swap_digits(pk->y, &public_key[ndigits], ndigits);
1015
1016err_free_point:
1017        ecc_free_point(pk);
1018out:
1019        return ret;
1020}
1021
1022/* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */
1023static int ecc_is_pubkey_valid_partial(const struct ecc_curve *curve,
1024                                       struct ecc_point *pk)
1025{
1026        u64 yy[ECC_MAX_DIGITS], xxx[ECC_MAX_DIGITS], w[ECC_MAX_DIGITS];
1027
1028        /* Check 1: Verify key is not the zero point. */
1029        if (ecc_point_is_zero(pk))
1030                return -EINVAL;
1031
1032        /* Check 2: Verify key is in the range [1, p-1]. */
1033        if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1)
1034                return -EINVAL;
1035        if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1)
1036                return -EINVAL;
1037
1038        /* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */
1039        vli_mod_square_fast(yy, pk->y, curve->p, pk->ndigits); /* y^2 */
1040        vli_mod_square_fast(xxx, pk->x, curve->p, pk->ndigits); /* x^2 */
1041        vli_mod_mult_fast(xxx, xxx, pk->x, curve->p, pk->ndigits); /* x^3 */
1042        vli_mod_mult_fast(w, curve->a, pk->x, curve->p, pk->ndigits); /* a·x */
1043        vli_mod_add(w, w, curve->b, curve->p, pk->ndigits); /* a·x + b */
1044        vli_mod_add(w, w, xxx, curve->p, pk->ndigits); /* x^3 + a·x + b */
1045        if (vli_cmp(yy, w, pk->ndigits) != 0) /* Equation */
1046                return -EINVAL;
1047
1048        return 0;
1049
1050}
1051
1052int crypto_ecdh_shared_secret(unsigned int curve_id, unsigned int ndigits,
1053                              const u64 *private_key, const u64 *public_key,
1054                              u64 *secret)
1055{
1056        int ret = 0;
1057        struct ecc_point *product, *pk;
1058        u64 priv[ECC_MAX_DIGITS];
1059        u64 rand_z[ECC_MAX_DIGITS];
1060        unsigned int nbytes;
1061        const struct ecc_curve *curve = ecc_get_curve(curve_id);
1062
1063        if (!private_key || !public_key || !curve ||
1064            ndigits > ARRAY_SIZE(priv) || ndigits > ARRAY_SIZE(rand_z)) {
1065                ret = -EINVAL;
1066                goto out;
1067        }
1068
1069        nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
1070
1071        get_random_bytes(rand_z, nbytes);
1072
1073        pk = ecc_alloc_point(ndigits);
1074        if (!pk) {
1075                ret = -ENOMEM;
1076                goto out;
1077        }
1078
1079        ecc_swap_digits(public_key, pk->x, ndigits);
1080        ecc_swap_digits(&public_key[ndigits], pk->y, ndigits);
1081        ret = ecc_is_pubkey_valid_partial(curve, pk);
1082        if (ret)
1083                goto err_alloc_product;
1084
1085        ecc_swap_digits(private_key, priv, ndigits);
1086
1087        product = ecc_alloc_point(ndigits);
1088        if (!product) {
1089                ret = -ENOMEM;
1090                goto err_alloc_product;
1091        }
1092
1093        ecc_point_mult(product, pk, priv, rand_z, curve->p, ndigits);
1094
1095        ecc_swap_digits(product->x, secret, ndigits);
1096
1097        if (ecc_point_is_zero(product))
1098                ret = -EFAULT;
1099
1100        ecc_free_point(product);
1101err_alloc_product:
1102        ecc_free_point(pk);
1103out:
1104        return ret;
1105}
1106