linux/lib/mpi/mpih-mul.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-or-later
   2/* mpihelp-mul.c  -  MPI helper functions
   3 * Copyright (C) 1994, 1996, 1998, 1999,
   4 *               2000 Free Software Foundation, Inc.
   5 *
   6 * This file is part of GnuPG.
   7 *
   8 * Note: This code is heavily based on the GNU MP Library.
   9 *       Actually it's the same code with only minor changes in the
  10 *       way the data is stored; this is to support the abstraction
  11 *       of an optional secure memory allocation which may be used
  12 *       to avoid revealing of sensitive data due to paging etc.
  13 *       The GNU MP Library itself is published under the LGPL;
  14 *       however I decided to publish this code under the plain GPL.
  15 */
  16
  17#include <linux/string.h>
  18#include "mpi-internal.h"
  19#include "longlong.h"
  20
  21#define MPN_MUL_N_RECURSE(prodp, up, vp, size, tspace)          \
  22        do {                                                    \
  23                if ((size) < KARATSUBA_THRESHOLD)               \
  24                        mul_n_basecase(prodp, up, vp, size);    \
  25                else                                            \
  26                        mul_n(prodp, up, vp, size, tspace);     \
  27        } while (0);
  28
  29#define MPN_SQR_N_RECURSE(prodp, up, size, tspace)              \
  30        do {                                                    \
  31                if ((size) < KARATSUBA_THRESHOLD)               \
  32                        mpih_sqr_n_basecase(prodp, up, size);   \
  33                else                                            \
  34                        mpih_sqr_n(prodp, up, size, tspace);    \
  35        } while (0);
  36
  37/* Multiply the natural numbers u (pointed to by UP) and v (pointed to by VP),
  38 * both with SIZE limbs, and store the result at PRODP.  2 * SIZE limbs are
  39 * always stored.  Return the most significant limb.
  40 *
  41 * Argument constraints:
  42 * 1. PRODP != UP and PRODP != VP, i.e. the destination
  43 *    must be distinct from the multiplier and the multiplicand.
  44 *
  45 *
  46 * Handle simple cases with traditional multiplication.
  47 *
  48 * This is the most critical code of multiplication.  All multiplies rely
  49 * on this, both small and huge.  Small ones arrive here immediately.  Huge
  50 * ones arrive here as this is the base case for Karatsuba's recursive
  51 * algorithm below.
  52 */
  53
  54static mpi_limb_t
  55mul_n_basecase(mpi_ptr_t prodp, mpi_ptr_t up, mpi_ptr_t vp, mpi_size_t size)
  56{
  57        mpi_size_t i;
  58        mpi_limb_t cy;
  59        mpi_limb_t v_limb;
  60
  61        /* Multiply by the first limb in V separately, as the result can be
  62         * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
  63        v_limb = vp[0];
  64        if (v_limb <= 1) {
  65                if (v_limb == 1)
  66                        MPN_COPY(prodp, up, size);
  67                else
  68                        MPN_ZERO(prodp, size);
  69                cy = 0;
  70        } else
  71                cy = mpihelp_mul_1(prodp, up, size, v_limb);
  72
  73        prodp[size] = cy;
  74        prodp++;
  75
  76        /* For each iteration in the outer loop, multiply one limb from
  77         * U with one limb from V, and add it to PROD.  */
  78        for (i = 1; i < size; i++) {
  79                v_limb = vp[i];
  80                if (v_limb <= 1) {
  81                        cy = 0;
  82                        if (v_limb == 1)
  83                                cy = mpihelp_add_n(prodp, prodp, up, size);
  84                } else
  85                        cy = mpihelp_addmul_1(prodp, up, size, v_limb);
  86
  87                prodp[size] = cy;
  88                prodp++;
  89        }
  90
  91        return cy;
  92}
  93
  94static void
  95mul_n(mpi_ptr_t prodp, mpi_ptr_t up, mpi_ptr_t vp,
  96                mpi_size_t size, mpi_ptr_t tspace)
  97{
  98        if (size & 1) {
  99                /* The size is odd, and the code below doesn't handle that.
 100                 * Multiply the least significant (size - 1) limbs with a recursive
 101                 * call, and handle the most significant limb of S1 and S2
 102                 * separately.
 103                 * A slightly faster way to do this would be to make the Karatsuba
 104                 * code below behave as if the size were even, and let it check for
 105                 * odd size in the end.  I.e., in essence move this code to the end.
 106                 * Doing so would save us a recursive call, and potentially make the
 107                 * stack grow a lot less.
 108                 */
 109                mpi_size_t esize = size - 1;    /* even size */
 110                mpi_limb_t cy_limb;
 111
 112                MPN_MUL_N_RECURSE(prodp, up, vp, esize, tspace);
 113                cy_limb = mpihelp_addmul_1(prodp + esize, up, esize, vp[esize]);
 114                prodp[esize + esize] = cy_limb;
 115                cy_limb = mpihelp_addmul_1(prodp + esize, vp, size, up[esize]);
 116                prodp[esize + size] = cy_limb;
 117        } else {
 118                /* Anatolij Alekseevich Karatsuba's divide-and-conquer algorithm.
 119                 *
 120                 * Split U in two pieces, U1 and U0, such that
 121                 * U = U0 + U1*(B**n),
 122                 * and V in V1 and V0, such that
 123                 * V = V0 + V1*(B**n).
 124                 *
 125                 * UV is then computed recursively using the identity
 126                 *
 127                 *        2n   n          n                     n
 128                 * UV = (B  + B )U V  +  B (U -U )(V -V )  +  (B + 1)U V
 129                 *                1 1        1  0   0  1              0 0
 130                 *
 131                 * Where B = 2**BITS_PER_MP_LIMB.
 132                 */
 133                mpi_size_t hsize = size >> 1;
 134                mpi_limb_t cy;
 135                int negflg;
 136
 137                /* Product H.      ________________  ________________
 138                 *                |_____U1 x V1____||____U0 x V0_____|
 139                 * Put result in upper part of PROD and pass low part of TSPACE
 140                 * as new TSPACE.
 141                 */
 142                MPN_MUL_N_RECURSE(prodp + size, up + hsize, vp + hsize, hsize,
 143                                  tspace);
 144
 145                /* Product M.      ________________
 146                 *                |_(U1-U0)(V0-V1)_|
 147                 */
 148                if (mpihelp_cmp(up + hsize, up, hsize) >= 0) {
 149                        mpihelp_sub_n(prodp, up + hsize, up, hsize);
 150                        negflg = 0;
 151                } else {
 152                        mpihelp_sub_n(prodp, up, up + hsize, hsize);
 153                        negflg = 1;
 154                }
 155                if (mpihelp_cmp(vp + hsize, vp, hsize) >= 0) {
 156                        mpihelp_sub_n(prodp + hsize, vp + hsize, vp, hsize);
 157                        negflg ^= 1;
 158                } else {
 159                        mpihelp_sub_n(prodp + hsize, vp, vp + hsize, hsize);
 160                        /* No change of NEGFLG.  */
 161                }
 162                /* Read temporary operands from low part of PROD.
 163                 * Put result in low part of TSPACE using upper part of TSPACE
 164                 * as new TSPACE.
 165                 */
 166                MPN_MUL_N_RECURSE(tspace, prodp, prodp + hsize, hsize,
 167                                  tspace + size);
 168
 169                /* Add/copy product H. */
 170                MPN_COPY(prodp + hsize, prodp + size, hsize);
 171                cy = mpihelp_add_n(prodp + size, prodp + size,
 172                                   prodp + size + hsize, hsize);
 173
 174                /* Add product M (if NEGFLG M is a negative number) */
 175                if (negflg)
 176                        cy -=
 177                            mpihelp_sub_n(prodp + hsize, prodp + hsize, tspace,
 178                                          size);
 179                else
 180                        cy +=
 181                            mpihelp_add_n(prodp + hsize, prodp + hsize, tspace,
 182                                          size);
 183
 184                /* Product L.      ________________  ________________
 185                 *                |________________||____U0 x V0_____|
 186                 * Read temporary operands from low part of PROD.
 187                 * Put result in low part of TSPACE using upper part of TSPACE
 188                 * as new TSPACE.
 189                 */
 190                MPN_MUL_N_RECURSE(tspace, up, vp, hsize, tspace + size);
 191
 192                /* Add/copy Product L (twice) */
 193
 194                cy += mpihelp_add_n(prodp + hsize, prodp + hsize, tspace, size);
 195                if (cy)
 196                        mpihelp_add_1(prodp + hsize + size,
 197                                      prodp + hsize + size, hsize, cy);
 198
 199                MPN_COPY(prodp, tspace, hsize);
 200                cy = mpihelp_add_n(prodp + hsize, prodp + hsize, tspace + hsize,
 201                                   hsize);
 202                if (cy)
 203                        mpihelp_add_1(prodp + size, prodp + size, size, 1);
 204        }
 205}
 206
 207void mpih_sqr_n_basecase(mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t size)
 208{
 209        mpi_size_t i;
 210        mpi_limb_t cy_limb;
 211        mpi_limb_t v_limb;
 212
 213        /* Multiply by the first limb in V separately, as the result can be
 214         * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
 215        v_limb = up[0];
 216        if (v_limb <= 1) {
 217                if (v_limb == 1)
 218                        MPN_COPY(prodp, up, size);
 219                else
 220                        MPN_ZERO(prodp, size);
 221                cy_limb = 0;
 222        } else
 223                cy_limb = mpihelp_mul_1(prodp, up, size, v_limb);
 224
 225        prodp[size] = cy_limb;
 226        prodp++;
 227
 228        /* For each iteration in the outer loop, multiply one limb from
 229         * U with one limb from V, and add it to PROD.  */
 230        for (i = 1; i < size; i++) {
 231                v_limb = up[i];
 232                if (v_limb <= 1) {
 233                        cy_limb = 0;
 234                        if (v_limb == 1)
 235                                cy_limb = mpihelp_add_n(prodp, prodp, up, size);
 236                } else
 237                        cy_limb = mpihelp_addmul_1(prodp, up, size, v_limb);
 238
 239                prodp[size] = cy_limb;
 240                prodp++;
 241        }
 242}
 243
 244void
 245mpih_sqr_n(mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t size, mpi_ptr_t tspace)
 246{
 247        if (size & 1) {
 248                /* The size is odd, and the code below doesn't handle that.
 249                 * Multiply the least significant (size - 1) limbs with a recursive
 250                 * call, and handle the most significant limb of S1 and S2
 251                 * separately.
 252                 * A slightly faster way to do this would be to make the Karatsuba
 253                 * code below behave as if the size were even, and let it check for
 254                 * odd size in the end.  I.e., in essence move this code to the end.
 255                 * Doing so would save us a recursive call, and potentially make the
 256                 * stack grow a lot less.
 257                 */
 258                mpi_size_t esize = size - 1;    /* even size */
 259                mpi_limb_t cy_limb;
 260
 261                MPN_SQR_N_RECURSE(prodp, up, esize, tspace);
 262                cy_limb = mpihelp_addmul_1(prodp + esize, up, esize, up[esize]);
 263                prodp[esize + esize] = cy_limb;
 264                cy_limb = mpihelp_addmul_1(prodp + esize, up, size, up[esize]);
 265
 266                prodp[esize + size] = cy_limb;
 267        } else {
 268                mpi_size_t hsize = size >> 1;
 269                mpi_limb_t cy;
 270
 271                /* Product H.      ________________  ________________
 272                 *                |_____U1 x U1____||____U0 x U0_____|
 273                 * Put result in upper part of PROD and pass low part of TSPACE
 274                 * as new TSPACE.
 275                 */
 276                MPN_SQR_N_RECURSE(prodp + size, up + hsize, hsize, tspace);
 277
 278                /* Product M.      ________________
 279                 *                |_(U1-U0)(U0-U1)_|
 280                 */
 281                if (mpihelp_cmp(up + hsize, up, hsize) >= 0)
 282                        mpihelp_sub_n(prodp, up + hsize, up, hsize);
 283                else
 284                        mpihelp_sub_n(prodp, up, up + hsize, hsize);
 285
 286                /* Read temporary operands from low part of PROD.
 287                 * Put result in low part of TSPACE using upper part of TSPACE
 288                 * as new TSPACE.  */
 289                MPN_SQR_N_RECURSE(tspace, prodp, hsize, tspace + size);
 290
 291                /* Add/copy product H  */
 292                MPN_COPY(prodp + hsize, prodp + size, hsize);
 293                cy = mpihelp_add_n(prodp + size, prodp + size,
 294                                   prodp + size + hsize, hsize);
 295
 296                /* Add product M (if NEGFLG M is a negative number).  */
 297                cy -= mpihelp_sub_n(prodp + hsize, prodp + hsize, tspace, size);
 298
 299                /* Product L.      ________________  ________________
 300                 *                |________________||____U0 x U0_____|
 301                 * Read temporary operands from low part of PROD.
 302                 * Put result in low part of TSPACE using upper part of TSPACE
 303                 * as new TSPACE.  */
 304                MPN_SQR_N_RECURSE(tspace, up, hsize, tspace + size);
 305
 306                /* Add/copy Product L (twice).  */
 307                cy += mpihelp_add_n(prodp + hsize, prodp + hsize, tspace, size);
 308                if (cy)
 309                        mpihelp_add_1(prodp + hsize + size,
 310                                      prodp + hsize + size, hsize, cy);
 311
 312                MPN_COPY(prodp, tspace, hsize);
 313                cy = mpihelp_add_n(prodp + hsize, prodp + hsize, tspace + hsize,
 314                                   hsize);
 315                if (cy)
 316                        mpihelp_add_1(prodp + size, prodp + size, size, 1);
 317        }
 318}
 319
 320int
 321mpihelp_mul_karatsuba_case(mpi_ptr_t prodp,
 322                           mpi_ptr_t up, mpi_size_t usize,
 323                           mpi_ptr_t vp, mpi_size_t vsize,
 324                           struct karatsuba_ctx *ctx)
 325{
 326        mpi_limb_t cy;
 327
 328        if (!ctx->tspace || ctx->tspace_size < vsize) {
 329                if (ctx->tspace)
 330                        mpi_free_limb_space(ctx->tspace);
 331                ctx->tspace = mpi_alloc_limb_space(2 * vsize);
 332                if (!ctx->tspace)
 333                        return -ENOMEM;
 334                ctx->tspace_size = vsize;
 335        }
 336
 337        MPN_MUL_N_RECURSE(prodp, up, vp, vsize, ctx->tspace);
 338
 339        prodp += vsize;
 340        up += vsize;
 341        usize -= vsize;
 342        if (usize >= vsize) {
 343                if (!ctx->tp || ctx->tp_size < vsize) {
 344                        if (ctx->tp)
 345                                mpi_free_limb_space(ctx->tp);
 346                        ctx->tp = mpi_alloc_limb_space(2 * vsize);
 347                        if (!ctx->tp) {
 348                                if (ctx->tspace)
 349                                        mpi_free_limb_space(ctx->tspace);
 350                                ctx->tspace = NULL;
 351                                return -ENOMEM;
 352                        }
 353                        ctx->tp_size = vsize;
 354                }
 355
 356                do {
 357                        MPN_MUL_N_RECURSE(ctx->tp, up, vp, vsize, ctx->tspace);
 358                        cy = mpihelp_add_n(prodp, prodp, ctx->tp, vsize);
 359                        mpihelp_add_1(prodp + vsize, ctx->tp + vsize, vsize,
 360                                      cy);
 361                        prodp += vsize;
 362                        up += vsize;
 363                        usize -= vsize;
 364                } while (usize >= vsize);
 365        }
 366
 367        if (usize) {
 368                if (usize < KARATSUBA_THRESHOLD) {
 369                        mpi_limb_t tmp;
 370                        if (mpihelp_mul(ctx->tspace, vp, vsize, up, usize, &tmp)
 371                            < 0)
 372                                return -ENOMEM;
 373                } else {
 374                        if (!ctx->next) {
 375                                ctx->next = kzalloc(sizeof *ctx, GFP_KERNEL);
 376                                if (!ctx->next)
 377                                        return -ENOMEM;
 378                        }
 379                        if (mpihelp_mul_karatsuba_case(ctx->tspace,
 380                                                       vp, vsize,
 381                                                       up, usize,
 382                                                       ctx->next) < 0)
 383                                return -ENOMEM;
 384                }
 385
 386                cy = mpihelp_add_n(prodp, prodp, ctx->tspace, vsize);
 387                mpihelp_add_1(prodp + vsize, ctx->tspace + vsize, usize, cy);
 388        }
 389
 390        return 0;
 391}
 392
 393void mpihelp_release_karatsuba_ctx(struct karatsuba_ctx *ctx)
 394{
 395        struct karatsuba_ctx *ctx2;
 396
 397        if (ctx->tp)
 398                mpi_free_limb_space(ctx->tp);
 399        if (ctx->tspace)
 400                mpi_free_limb_space(ctx->tspace);
 401        for (ctx = ctx->next; ctx; ctx = ctx2) {
 402                ctx2 = ctx->next;
 403                if (ctx->tp)
 404                        mpi_free_limb_space(ctx->tp);
 405                if (ctx->tspace)
 406                        mpi_free_limb_space(ctx->tspace);
 407                kfree(ctx);
 408        }
 409}
 410
 411/* Multiply the natural numbers u (pointed to by UP, with USIZE limbs)
 412 * and v (pointed to by VP, with VSIZE limbs), and store the result at
 413 * PRODP.  USIZE + VSIZE limbs are always stored, but if the input
 414 * operands are normalized.  Return the most significant limb of the
 415 * result.
 416 *
 417 * NOTE: The space pointed to by PRODP is overwritten before finished
 418 * with U and V, so overlap is an error.
 419 *
 420 * Argument constraints:
 421 * 1. USIZE >= VSIZE.
 422 * 2. PRODP != UP and PRODP != VP, i.e. the destination
 423 *    must be distinct from the multiplier and the multiplicand.
 424 */
 425
 426int
 427mpihelp_mul(mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t usize,
 428            mpi_ptr_t vp, mpi_size_t vsize, mpi_limb_t *_result)
 429{
 430        mpi_ptr_t prod_endp = prodp + usize + vsize - 1;
 431        mpi_limb_t cy;
 432        struct karatsuba_ctx ctx;
 433
 434        if (vsize < KARATSUBA_THRESHOLD) {
 435                mpi_size_t i;
 436                mpi_limb_t v_limb;
 437
 438                if (!vsize) {
 439                        *_result = 0;
 440                        return 0;
 441                }
 442
 443                /* Multiply by the first limb in V separately, as the result can be
 444                 * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
 445                v_limb = vp[0];
 446                if (v_limb <= 1) {
 447                        if (v_limb == 1)
 448                                MPN_COPY(prodp, up, usize);
 449                        else
 450                                MPN_ZERO(prodp, usize);
 451                        cy = 0;
 452                } else
 453                        cy = mpihelp_mul_1(prodp, up, usize, v_limb);
 454
 455                prodp[usize] = cy;
 456                prodp++;
 457
 458                /* For each iteration in the outer loop, multiply one limb from
 459                 * U with one limb from V, and add it to PROD.  */
 460                for (i = 1; i < vsize; i++) {
 461                        v_limb = vp[i];
 462                        if (v_limb <= 1) {
 463                                cy = 0;
 464                                if (v_limb == 1)
 465                                        cy = mpihelp_add_n(prodp, prodp, up,
 466                                                           usize);
 467                        } else
 468                                cy = mpihelp_addmul_1(prodp, up, usize, v_limb);
 469
 470                        prodp[usize] = cy;
 471                        prodp++;
 472                }
 473
 474                *_result = cy;
 475                return 0;
 476        }
 477
 478        memset(&ctx, 0, sizeof ctx);
 479        if (mpihelp_mul_karatsuba_case(prodp, up, usize, vp, vsize, &ctx) < 0)
 480                return -ENOMEM;
 481        mpihelp_release_karatsuba_ctx(&ctx);
 482        *_result = *prod_endp;
 483        return 0;
 484}
 485