linux/lib/mpi/mpih-mul.c
// SPDX-License-Identifier: GPL-2.0-or-later
/* mpihelp-mul.c  -  MPI helper functions
 * Copyright (C) 1994, 1996, 1998, 1999,
 *               2000 Free Software Foundation, Inc.
 *
 * This file is part of GnuPG.
 *
 * Note: This code is heavily based on the GNU MP Library.
 *       Actually it's the same code with only minor changes in the
 *       way the data is stored; this is to support the abstraction
 *       of an optional secure memory allocation which may be used
 *       to avoid revealing sensitive data due to paging etc.
 *       The GNU MP Library itself is published under the LGPL;
 *       however I decided to publish this code under the plain GPL.
 */

#include <linux/string.h>
#include "mpi-internal.h"
#include "longlong.h"

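/* Recursion helpers: operands shorter than KARATSUBA_THRESHOLD limbs go to
 * the quadratic base case (the assumption behind the threshold being that
 * the recursion overhead is not worth it for small sizes); everything else
 * recurses into the Karatsuba routines below.
 */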
#define MPN_MUL_N_RECURSE(prodp, up, vp, size, tspace)          \
        do {                                                    \
                if ((size) < KARATSUBA_THRESHOLD)               \
                        mul_n_basecase(prodp, up, vp, size);    \
                else                                            \
                        mul_n(prodp, up, vp, size, tspace);     \
        } while (0)

#define MPN_SQR_N_RECURSE(prodp, up, size, tspace)              \
        do {                                                    \
                if ((size) < KARATSUBA_THRESHOLD)               \
                        mpih_sqr_n_basecase(prodp, up, size);   \
                else                                            \
                        mpih_sqr_n(prodp, up, size, tspace);    \
        } while (0)

/* Multiply the natural numbers u (pointed to by UP) and v (pointed to by VP),
 * both with SIZE limbs, and store the result at PRODP.  2 * SIZE limbs are
 * always stored.  Return the most significant limb.
 *
 * Argument constraints:
 * 1. PRODP != UP and PRODP != VP, i.e. the destination
 *    must be distinct from the multiplier and the multiplicand.
 *
 * Handle simple cases with traditional multiplication.
 *
 * This is the most critical code of multiplication.  All multiplies rely
 * on this, both small and huge.  Small ones arrive here immediately.  Huge
 * ones arrive here as this is the base case for Karatsuba's recursive
 * algorithm below.
 */
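/* As a small worked example (a sketch for SIZE = 2, with
 * B = 2**BITS_PER_MP_LIMB, U = u1*B + u0 and V = v1*B + v0), the routine
 * below computes
 *
 *      PROD = U*v0 + (U*v1)*B
 *           = u0*v0 + (u1*v0 + u0*v1)*B + u1*v1*B**2:
 *
 * the first limb of V initialises PROD via mpihelp_mul_1() (or MPN_COPY /
 * MPN_ZERO for the trivial limbs 1 and 0), and every further limb of V is
 * accumulated one limb higher via mpihelp_addmul_1().
 */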

static mpi_limb_t
mul_n_basecase(mpi_ptr_t prodp, mpi_ptr_t up, mpi_ptr_t vp, mpi_size_t size)
{
        mpi_size_t i;
        mpi_limb_t cy;
        mpi_limb_t v_limb;

        /* Multiply by the first limb in V separately, as the result can be
         * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
        v_limb = vp[0];
        if (v_limb <= 1) {
                if (v_limb == 1)
                        MPN_COPY(prodp, up, size);
                else
                        MPN_ZERO(prodp, size);
                cy = 0;
        } else
                cy = mpihelp_mul_1(prodp, up, size, v_limb);

        prodp[size] = cy;
        prodp++;

        /* For each remaining limb of V, multiply all of U by that limb and
         * add the product into PROD, one limb further up each time.  */
        for (i = 1; i < size; i++) {
                v_limb = vp[i];
                if (v_limb <= 1) {
                        cy = 0;
                        if (v_limb == 1)
                                cy = mpihelp_add_n(prodp, prodp, up, size);
                } else
                        cy = mpihelp_addmul_1(prodp, up, size, v_limb);

                prodp[size] = cy;
                prodp++;
        }

        return cy;
}

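/* Karatsuba multiplication of U and V, both SIZE limbs, storing 2 * SIZE
 * limbs at PRODP.  TSPACE must provide 2 * SIZE limbs of scratch space
 * (callers such as mpihelp_mul_n() allocate exactly that).  Only reached
 * when SIZE is at least KARATSUBA_THRESHOLD.
 */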
static void
mul_n(mpi_ptr_t prodp, mpi_ptr_t up, mpi_ptr_t vp,
                mpi_size_t size, mpi_ptr_t tspace)
{
        if (size & 1) {
                /* The size is odd, and the code below doesn't handle that.
                 * Multiply the least significant (size - 1) limbs with a recursive
                 * call, and handle the most significant limb of U and V
                 * separately.
                 * A slightly faster way to do this would be to make the Karatsuba
                 * code below behave as if the size were even, and let it check for
                 * odd size in the end.  I.e., in essence move this code to the end.
                 * Doing so would save us a recursive call, and potentially make the
                 * stack grow a lot less.
                 */
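                /* Sketch of the split being used here (with e = size - 1,
                 * B = 2**BITS_PER_MP_LIMB, U = U0 + u[e]*B**e and
                 * V = V0 + v[e]*B**e):
                 *
                 *      U*V = U0*V0 + U0*v[e]*B**e + u[e]*V*B**e,
                 *
                 * which is exactly what the recursive call plus the two
                 * mpihelp_addmul_1() calls below accumulate.
                 */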
                mpi_size_t esize = size - 1;    /* even size */
                mpi_limb_t cy_limb;

                MPN_MUL_N_RECURSE(prodp, up, vp, esize, tspace);
                cy_limb = mpihelp_addmul_1(prodp + esize, up, esize, vp[esize]);
                prodp[esize + esize] = cy_limb;
                cy_limb = mpihelp_addmul_1(prodp + esize, vp, size, up[esize]);
                prodp[esize + size] = cy_limb;
        } else {
                /* Anatolij Alekseevich Karatsuba's divide-and-conquer algorithm.
                 *
                 * Split U in two pieces, U1 and U0, such that
                 * U = U0 + U1*(B**n),
                 * and V in V1 and V0, such that
                 * V = V0 + V1*(B**n).
                 *
                 * UV is then computed recursively using the identity
                 *
                 *        2n   n          n                     n
                 * UV = (B  + B )U V  +  B (U -U )(V -V )  +  (B + 1)U V
                 *                1 1        1  0   0  1              0 0
                 *
                 * Where B = 2**BITS_PER_MP_LIMB.
                 */
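                /* Expanding the middle term checks the identity (a sketch):
                 *
                 *      (U1-U0)*(V0-V1) = U1*V0 + U0*V1 - U1*V1 - U0*V0,
                 *
                 * so the right-hand side above collapses to
                 *
                 *      B**(2n)*U1*V1 + B**n*(U1*V0 + U0*V1) + U0*V0 = U*V,
                 *
                 * with n = hsize limbs, i.e. B**n is a shift by hsize limbs.
                 */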
                mpi_size_t hsize = size >> 1;
                mpi_limb_t cy;
                int negflg;

                /* Product H.      ________________  ________________
                 *                |_____U1 x V1____||____U0 x V0_____|
                 * Put result in upper part of PROD and pass low part of TSPACE
                 * as new TSPACE.
                 */
                MPN_MUL_N_RECURSE(prodp + size, up + hsize, vp + hsize, hsize,
                                  tspace);

                /* Product M.      ________________
                 *                |_(U1-U0)(V0-V1)_|
                 */
                if (mpihelp_cmp(up + hsize, up, hsize) >= 0) {
                        mpihelp_sub_n(prodp, up + hsize, up, hsize);
                        negflg = 0;
                } else {
                        mpihelp_sub_n(prodp, up, up + hsize, hsize);
                        negflg = 1;
                }
                if (mpihelp_cmp(vp + hsize, vp, hsize) >= 0) {
                        mpihelp_sub_n(prodp + hsize, vp + hsize, vp, hsize);
                        negflg ^= 1;
                } else {
                        mpihelp_sub_n(prodp + hsize, vp, vp + hsize, hsize);
                        /* No change of NEGFLG.  */
                }
                /* Read temporary operands from low part of PROD.
                 * Put result in low part of TSPACE using upper part of TSPACE
                 * as new TSPACE.
                 */
                MPN_MUL_N_RECURSE(tspace, prodp, prodp + hsize, hsize,
                                  tspace + size);

                /* Add/copy product H. */
                MPN_COPY(prodp + hsize, prodp + size, hsize);
                cy = mpihelp_add_n(prodp + size, prodp + size,
                                   prodp + size + hsize, hsize);
                /* Add product M (if NEGFLG is set, M is negative and must be
                 * subtracted).  */
                if (negflg)
                        cy -=
                            mpihelp_sub_n(prodp + hsize, prodp + hsize, tspace,
                                          size);
                else
                        cy +=
                            mpihelp_add_n(prodp + hsize, prodp + hsize, tspace,
                                          size);

                /* Product L.      ________________  ________________
                 *                |________________||____U0 x V0_____|
                 * Read operands from the low halves of UP and VP.
                 * Put result in low part of TSPACE using upper part of TSPACE
                 * as new TSPACE.
                 */
                MPN_MUL_N_RECURSE(tspace, up, vp, hsize, tspace + size);

                /* Add/copy Product L (twice) */

                cy += mpihelp_add_n(prodp + hsize, prodp + hsize, tspace, size);
                if (cy)
                        mpihelp_add_1(prodp + hsize + size,
                                      prodp + hsize + size, hsize, cy);

                MPN_COPY(prodp, tspace, hsize);
                cy = mpihelp_add_n(prodp + hsize, prodp + hsize, tspace + hsize,
                                   hsize);
                if (cy)
                        mpihelp_add_1(prodp + size, prodp + size, size, 1);
        }
}

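/* Square the natural number U (pointed to by UP, SIZE limbs) and store the
 * 2 * SIZE limb result at PRODP.  This is mul_n_basecase() specialised for
 * vp == up; see the comments there.
 */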
void mpih_sqr_n_basecase(mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t size)
{
        mpi_size_t i;
        mpi_limb_t cy_limb;
        mpi_limb_t v_limb;

        /* Multiply by the first limb of U separately, as the result can be
         * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
        v_limb = up[0];
        if (v_limb <= 1) {
                if (v_limb == 1)
                        MPN_COPY(prodp, up, size);
                else
                        MPN_ZERO(prodp, size);
                cy_limb = 0;
        } else
                cy_limb = mpihelp_mul_1(prodp, up, size, v_limb);

        prodp[size] = cy_limb;
        prodp++;

        /* For each remaining limb of U, multiply all of U by that limb and
         * add the product into PROD, one limb further up each time.  */
        for (i = 1; i < size; i++) {
                v_limb = up[i];
                if (v_limb <= 1) {
                        cy_limb = 0;
                        if (v_limb == 1)
                                cy_limb = mpihelp_add_n(prodp, prodp, up, size);
                } else
                        cy_limb = mpihelp_addmul_1(prodp, up, size, v_limb);

                prodp[size] = cy_limb;
                prodp++;
        }
}

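/* Karatsuba squaring of U (SIZE limbs) into PRODP (2 * SIZE limbs).  As for
 * mul_n(), TSPACE must provide 2 * SIZE limbs of scratch space and SIZE is
 * at least KARATSUBA_THRESHOLD when this is reached.
 */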
void
mpih_sqr_n(mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t size, mpi_ptr_t tspace)
{
        if (size & 1) {
                /* The size is odd, and the code below doesn't handle that.
                 * Square the least significant (size - 1) limbs with a recursive
                 * call, and handle the most significant limb of U
                 * separately.
                 * A slightly faster way to do this would be to make the Karatsuba
                 * code below behave as if the size were even, and let it check for
                 * odd size in the end.  I.e., in essence move this code to the end.
                 * Doing so would save us a recursive call, and potentially make the
                 * stack grow a lot less.
                 */
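                /* Sketch of the split (with e = size - 1,
                 * B = 2**BITS_PER_MP_LIMB and U = U0 + u[e]*B**e):
                 *
                 *      U*U = U0*U0 + U0*u[e]*B**e + u[e]*U*B**e,
                 *
                 * matching the recursive square plus the two
                 * mpihelp_addmul_1() calls below.
                 */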
                mpi_size_t esize = size - 1;    /* even size */
                mpi_limb_t cy_limb;

                MPN_SQR_N_RECURSE(prodp, up, esize, tspace);
                cy_limb = mpihelp_addmul_1(prodp + esize, up, esize, up[esize]);
                prodp[esize + esize] = cy_limb;
                cy_limb = mpihelp_addmul_1(prodp + esize, up, size, up[esize]);

                prodp[esize + size] = cy_limb;
        } else {
                mpi_size_t hsize = size >> 1;
                mpi_limb_t cy;

                /* Product H.      ________________  ________________
                 *                |_____U1 x U1____||____U0 x U0_____|
                 * Put result in upper part of PROD and pass low part of TSPACE
                 * as new TSPACE.
                 */
                MPN_SQR_N_RECURSE(prodp + size, up + hsize, hsize, tspace);

                /* Product M.      ________________
                 *                |_(U1-U0)(U0-U1)_|
                 */
                if (mpihelp_cmp(up + hsize, up, hsize) >= 0)
                        mpihelp_sub_n(prodp, up + hsize, up, hsize);
                else
                        mpihelp_sub_n(prodp, up, up + hsize, hsize);

                /* Read temporary operands from low part of PROD.
                 * Put result in low part of TSPACE using upper part of TSPACE
                 * as new TSPACE.  */
                MPN_SQR_N_RECURSE(tspace, prodp, hsize, tspace + size);

                /* Add/copy product H  */
                MPN_COPY(prodp + hsize, prodp + size, hsize);
                cy = mpihelp_add_n(prodp + size, prodp + size,
                                   prodp + size + hsize, hsize);

                /* Subtract product M: (U1-U0)*(U0-U1) is never positive, so
                 * unlike mul_n() no NEGFLG bookkeeping is needed here.  */
                cy -= mpihelp_sub_n(prodp + hsize, prodp + hsize, tspace, size);

                /* Product L.      ________________  ________________
                 *                |________________||____U0 x U0_____|
                 * Read the operand from the low half of UP.
                 * Put result in low part of TSPACE using upper part of TSPACE
                 * as new TSPACE.  */
                MPN_SQR_N_RECURSE(tspace, up, hsize, tspace + size);

                /* Add/copy Product L (twice).  */
                cy += mpihelp_add_n(prodp + hsize, prodp + hsize, tspace, size);
                if (cy)
                        mpihelp_add_1(prodp + hsize + size,
                                      prodp + hsize + size, hsize, cy);

                MPN_COPY(prodp, tspace, hsize);
                cy = mpihelp_add_n(prodp + hsize, prodp + hsize, tspace + hsize,
                                   hsize);
                if (cy)
                        mpihelp_add_1(prodp + size, prodp + size, size, 1);
        }
}


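/* Multiply U and V, both SIZE limbs, and store the 2 * SIZE limb result at
 * PRODP.  When UP == VP this dispatches to the squaring code, and above
 * KARATSUBA_THRESHOLD it allocates the 2 * SIZE limbs of scratch space the
 * recursive routines need.
 */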
void mpihelp_mul_n(mpi_ptr_t prodp,
                mpi_ptr_t up, mpi_ptr_t vp, mpi_size_t size)
{
        if (up == vp) {
                if (size < KARATSUBA_THRESHOLD)
                        mpih_sqr_n_basecase(prodp, up, size);
                else {
                        mpi_ptr_t tspace;
                        tspace = mpi_alloc_limb_space(2 * size);
                        mpih_sqr_n(prodp, up, size, tspace);
                        mpi_free_limb_space(tspace);
                }
        } else {
                if (size < KARATSUBA_THRESHOLD)
                        mul_n_basecase(prodp, up, vp, size);
                else {
                        mpi_ptr_t tspace;
                        tspace = mpi_alloc_limb_space(2 * size);
                        mul_n(prodp, up, vp, size, tspace);
                        mpi_free_limb_space(tspace);
                }
        }
}

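/* Multiply U (USIZE limbs) by V (VSIZE limbs, USIZE >= VSIZE) into PRODP,
 * walking U in VSIZE-limb chunks: each chunk is multiplied by V with the
 * balanced code above and accumulated into PRODP; a final chunk shorter
 * than VSIZE is handled via mpihelp_mul() or a recursive call.  The ctx
 * caches the scratch buffers between calls.  Returns 0 or -ENOMEM.
 */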
int
mpihelp_mul_karatsuba_case(mpi_ptr_t prodp,
                           mpi_ptr_t up, mpi_size_t usize,
                           mpi_ptr_t vp, mpi_size_t vsize,
                           struct karatsuba_ctx *ctx)
{
        mpi_limb_t cy;

        if (!ctx->tspace || ctx->tspace_size < vsize) {
                if (ctx->tspace)
                        mpi_free_limb_space(ctx->tspace);
                ctx->tspace = mpi_alloc_limb_space(2 * vsize);
                if (!ctx->tspace)
                        return -ENOMEM;
                ctx->tspace_size = vsize;
        }

        MPN_MUL_N_RECURSE(prodp, up, vp, vsize, ctx->tspace);

        prodp += vsize;
        up += vsize;
        usize -= vsize;
        if (usize >= vsize) {
                if (!ctx->tp || ctx->tp_size < vsize) {
                        if (ctx->tp)
                                mpi_free_limb_space(ctx->tp);
                        ctx->tp = mpi_alloc_limb_space(2 * vsize);
                        if (!ctx->tp) {
                                if (ctx->tspace)
                                        mpi_free_limb_space(ctx->tspace);
                                ctx->tspace = NULL;
                                return -ENOMEM;
                        }
                        ctx->tp_size = vsize;
                }

                do {
                        MPN_MUL_N_RECURSE(ctx->tp, up, vp, vsize, ctx->tspace);
                        cy = mpihelp_add_n(prodp, prodp, ctx->tp, vsize);
                        mpihelp_add_1(prodp + vsize, ctx->tp + vsize, vsize,
                                      cy);
                        prodp += vsize;
                        up += vsize;
                        usize -= vsize;
                } while (usize >= vsize);
        }

        if (usize) {
                if (usize < KARATSUBA_THRESHOLD) {
                        mpi_limb_t tmp;
                        if (mpihelp_mul(ctx->tspace, vp, vsize, up, usize, &tmp)
                            < 0)
                                return -ENOMEM;
                } else {
                        if (!ctx->next) {
                                ctx->next = kzalloc(sizeof *ctx, GFP_KERNEL);
                                if (!ctx->next)
                                        return -ENOMEM;
                        }
                        if (mpihelp_mul_karatsuba_case(ctx->tspace,
                                                       vp, vsize,
                                                       up, usize,
                                                       ctx->next) < 0)
                                return -ENOMEM;
                }

                cy = mpihelp_add_n(prodp, prodp, ctx->tspace, vsize);
                mpihelp_add_1(prodp + vsize, ctx->tspace + vsize, usize, cy);
        }

        return 0;
}

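/* Free the scratch buffers cached in a karatsuba_ctx chain.  The first ctx
 * is owned by the caller (mpihelp_mul() keeps it on the stack); any chained
 * contexts were kzalloc'ed in mpihelp_mul_karatsuba_case() and are kfree'd
 * here.
 */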
void mpihelp_release_karatsuba_ctx(struct karatsuba_ctx *ctx)
{
        struct karatsuba_ctx *ctx2;

        if (ctx->tp)
                mpi_free_limb_space(ctx->tp);
        if (ctx->tspace)
                mpi_free_limb_space(ctx->tspace);
        for (ctx = ctx->next; ctx; ctx = ctx2) {
                ctx2 = ctx->next;
                if (ctx->tp)
                        mpi_free_limb_space(ctx->tp);
                if (ctx->tspace)
                        mpi_free_limb_space(ctx->tspace);
                kfree(ctx);
        }
}

/* Multiply the natural numbers u (pointed to by UP, with USIZE limbs)
 * and v (pointed to by VP, with VSIZE limbs), and store the result at
 * PRODP.  USIZE + VSIZE limbs are always stored, even when the most
 * significant limb of the result is zero.  That most significant limb is
 * stored through _RESULT; the return value is 0 on success or -ENOMEM.
 *
 * NOTE: The space pointed to by PRODP is overwritten before U and V are
 * fully read, so overlap is an error.
 *
 * Argument constraints:
 * 1. USIZE >= VSIZE.
 * 2. PRODP != UP and PRODP != VP, i.e. the destination
 *    must be distinct from the multiplier and the multiplicand.
 */

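/* A minimal calling sketch (an illustration, assuming the limb buffers come
 * from mpi_alloc_limb_space(); the product needs USIZE + VSIZE limbs):
 *
 *      mpi_limb_t msl;
 *      mpi_ptr_t prod = mpi_alloc_limb_space(usize + vsize);
 *
 *      if (prod) {
 *              if (mpihelp_mul(prod, up, usize, vp, vsize, &msl) == 0)
 *                      ... use prod[0 .. usize + vsize - 1] and msl ...
 *              mpi_free_limb_space(prod);
 *      }
 */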
int
mpihelp_mul(mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t usize,
            mpi_ptr_t vp, mpi_size_t vsize, mpi_limb_t *_result)
{
        mpi_ptr_t prod_endp = prodp + usize + vsize - 1;
        mpi_limb_t cy;
        struct karatsuba_ctx ctx;

        if (vsize < KARATSUBA_THRESHOLD) {
                mpi_size_t i;
                mpi_limb_t v_limb;

                if (!vsize) {
                        *_result = 0;
                        return 0;
                }

                /* Multiply by the first limb in V separately, as the result can be
                 * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
                v_limb = vp[0];
                if (v_limb <= 1) {
                        if (v_limb == 1)
                                MPN_COPY(prodp, up, usize);
                        else
                                MPN_ZERO(prodp, usize);
                        cy = 0;
                } else
                        cy = mpihelp_mul_1(prodp, up, usize, v_limb);

                prodp[usize] = cy;
                prodp++;

                /* For each remaining limb of V, multiply all of U by that
                 * limb and add the product into PROD, one limb further up
                 * each time.  */
                for (i = 1; i < vsize; i++) {
                        v_limb = vp[i];
                        if (v_limb <= 1) {
                                cy = 0;
                                if (v_limb == 1)
                                        cy = mpihelp_add_n(prodp, prodp, up,
                                                           usize);
                        } else
                                cy = mpihelp_addmul_1(prodp, up, usize, v_limb);

                        prodp[usize] = cy;
                        prodp++;
                }

                *_result = cy;
                return 0;
        }

        memset(&ctx, 0, sizeof(ctx));
        if (mpihelp_mul_karatsuba_case(prodp, up, usize, vp, vsize, &ctx) < 0) {
                /* Release any scratch space the failed attempt left behind. */
                mpihelp_release_karatsuba_ctx(&ctx);
                return -ENOMEM;
        }
        mpihelp_release_karatsuba_ctx(&ctx);
        *_result = *prod_endp;
        return 0;
}