// BN_mod_exp_mont_consttime()
//
// Excerpt from src/crypto/fipsmodule/bn/exponentiation.c, lines 901-1223.


// BN_mod_exp_mont_consttime sets |rr| to |a|^|p| mod |m|.
//
// The result is computed with Montgomery multiplication and a fixed-window,
// left-to-right exponentiation that reads every bit stored in |p| (all
// |p->width| words, not just |BN_num_bits(p)|). As a result, the number of
// window steps does not depend on how many of the top bits of |p| are zero.
//
// Checked preconditions (on failure an error is pushed and 0 is returned):
//   - |m| is odd (BN_R_CALLED_WITH_EVEN_MODULUS);
//   - |m| is non-negative (BN_R_NEGATIVE_NUMBER);
//   - 0 <= |a| < |m| (BN_R_INPUT_NOT_REDUCED).
// NOTE(review): the sign of |p| is not inspected anywhere in this function;
// only the magnitude bits/words of |p| are read.
//
// If |p| has zero width, |rr| is set to 1, or to 0 when |m| is 1.
//
// |mont| may be NULL. In that case a Montgomery context for |m| is built with
// |BN_MONT_CTX_new_consttime| and freed before returning. Internal buffers are
// sized from |mont->N.width| rather than from |m|.
//
// Depending on build flags and operand sizes, one of three implementations
// runs:
//   1. RSAZ_ENABLED: 16-word |a| and |p|, a 1024-bit |m|, and
//      |rsaz_avx2_preferred()| -> |RSAZ_1024_mod_exp_avx2|.
//   2. OPENSSL_BN_ASM_MONT5: a 5-bit window and |top| > 1 -> a 32-entry
//      table written with |bn_scatter5| and read with the *gather5 routines.
//   3. Otherwise: a generic path using |BN_mod_mul_montgomery|, with
//      |copy_to_prebuf|/|copy_from_prebuf| for table access.
//
// Returns 1 on success and 0 on error.
int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                              const BIGNUM *m, BN_CTX *ctx,
                              const BN_MONT_CTX *mont) {
  int i, ret = 0, window, wvalue;
  // Non-NULL only when this function allocated the Montgomery context itself.
  BN_MONT_CTX *new_mont = NULL;

  // Number of precomputed table entries (1 << window).
  int numPowers;
  // Heap block backing |powerbuf|. It stays NULL when |powerbuf| points into
  // the stack buffer |storage|.
  unsigned char *powerbufFree = NULL;
  // Size in bytes of the scratch region of |powerbuf| in use.
  int powerbufLen = 0;
  // Aligned scratch area laid out as: power table, then |tmp|, then |am|
  // (and, on the MONT5 path, a copy of |mont->N.d|).
  BN_ULONG *powerbuf = NULL;
  // |tmp| is the running accumulator. |am| starts as |a| in Montgomery form
  // (the generic path later reuses it for fetched table entries). Both borrow
  // their limbs from |powerbuf|.
  BIGNUM tmp, am;

  if (!BN_is_odd(m)) {
    OPENSSL_PUT_ERROR(BN, BN_R_CALLED_WITH_EVEN_MODULUS);
    return 0;
  }
  if (m->neg) {
    OPENSSL_PUT_ERROR(BN, BN_R_NEGATIVE_NUMBER);
    return 0;
  }
  if (a->neg || BN_ucmp(a, m) >= 0) {
    OPENSSL_PUT_ERROR(BN, BN_R_INPUT_NOT_REDUCED);
    return 0;
  }

  // Use all bits stored in |p|, rather than |BN_num_bits|, so we do not leak
  // whether the top bits are zero.
  int max_bits = p->width * BN_BITS2;
  int bits = max_bits;
  if (bits == 0) {
    // x**0 mod 1 is still zero.
    if (BN_abs_is_word(m, 1)) {
      BN_zero(rr);
      return 1;
    }
    return BN_one(rr);
  }

  // Allocate a montgomery context if it was not supplied by the caller.
  if (mont == NULL) {
    new_mont = BN_MONT_CTX_new_consttime(m, ctx);
    if (new_mont == NULL) {
      goto err;
    }
    mont = new_mont;
  }

  // Use the width in |mont->N|, rather than the copy in |m|. The assembly
  // implementation assumes it can use |top| to size R.
  int top = mont->N.width;

#if defined(OPENSSL_BN_ASM_MONT5) || defined(RSAZ_ENABLED)
  // Share one large stack-allocated buffer between the RSAZ and non-RSAZ code
  // paths. If we were to use separate static buffers for each then there is
  // some chance that both large buffers would be allocated on the stack,
  // causing the stack space requirement to be truly huge (~10KB).
  //
  // NOTE(review): at |err| this buffer is cleansed only when it backed
  // |powerbuf|. The RSAZ path below leaves |powerbuf| NULL, so this function
  // does not scrub |storage| on that path. Presumably RSAZ_1024_mod_exp_avx2
  // cleans it itself; confirm.
  alignas(MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH) BN_ULONG
    storage[MOD_EXP_CTIME_STORAGE_LEN];
#endif
#if defined(RSAZ_ENABLED)
  // If the size of the operands allow it, perform the optimized RSAZ
  // exponentiation. For further information see crypto/fipsmodule/bn/rsaz_exp.c
  // and accompanying assembly modules.
  if (a->width == 16 && p->width == 16 && BN_num_bits(m) == 1024 &&
      rsaz_avx2_preferred()) {
    if (!bn_wexpand(rr, 16)) {
      goto err;
    }
    RSAZ_1024_mod_exp_avx2(rr->d, a->d, p->d, m->d, mont->RR.d, mont->n0[0],
                           storage);
    rr->width = 16;
    rr->neg = 0;
    ret = 1;
    // Success: jump to the shared cleanup (frees |new_mont| if allocated).
    goto err;
  }
#endif

  // Get the window size to use with size of p.
  window = BN_window_bits_for_ctime_exponent_size(bits);
#if defined(OPENSSL_BN_ASM_MONT5)
  if (window >= 5) {
    window = 5;  // ~5% improvement for RSA2048 sign, and even for RSA4096
    // reserve space for mont->N.d[] copy
    powerbufLen += top * sizeof(mont->N.d[0]);
  }
#endif

  // Allocate a buffer large enough to hold all of the pre-computed
  // powers of am, am itself and tmp.
  // In words: |numPowers| table entries of |top| words each, followed by
  // max(2 * top, numPowers) words that hold |tmp| and |am|.
  numPowers = 1 << window;
  powerbufLen +=
      sizeof(m->d[0]) *
      (top * numPowers + ((2 * top) > numPowers ? (2 * top) : numPowers));

#if defined(OPENSSL_BN_ASM_MONT5)
  // Prefer the stack buffer when the scratch area fits in it.
  if ((size_t)powerbufLen <= sizeof(storage)) {
    powerbuf = storage;
  }
  // |storage| is more than large enough to handle 1024-bit inputs.
  assert(powerbuf != NULL || top * BN_BITS2 > 1024);
#endif
  if (powerbuf == NULL) {
    // Over-allocate by one alignment unit so MOD_EXP_CTIME_ALIGN can round
    // the start of |powerbuf| up inside the allocation.
    powerbufFree =
        OPENSSL_malloc(powerbufLen + MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH);
    if (powerbufFree == NULL) {
      goto err;
    }
    powerbuf = (BN_ULONG *)MOD_EXP_CTIME_ALIGN(powerbufFree);
  }
  // Zero the whole scratch region (table, |tmp|, |am|, and any |N| copy).
  OPENSSL_memset(powerbuf, 0, powerbufLen);

  // lay down tmp and am right after powers table
  // Both start empty with room for |top| words. BN_FLG_STATIC_DATA
  // presumably keeps BIGNUM code from freeing or reallocating the borrowed
  // limbs.
  tmp.d = powerbuf + top * numPowers;
  am.d = tmp.d + top;
  tmp.width = am.width = 0;
  tmp.dmax = am.dmax = top;
  tmp.neg = am.neg = 0;
  tmp.flags = am.flags = BN_FLG_STATIC_DATA;

  // |tmp| = 1 in Montgomery form; this becomes table entry 0.
  if (!bn_one_to_montgomery(&tmp, mont, ctx)) {
    goto err;
  }

  // prepare a^1 in Montgomery domain
  assert(!a->neg);
  assert(BN_ucmp(a, m) < 0);
  if (!BN_to_montgomery(&am, a, mont, ctx)) {
    goto err;
  }

#if defined(OPENSSL_BN_ASM_MONT5)
  // This optimization uses ideas from http://eprint.iacr.org/2011/239,
  // specifically optimization of cache-timing attack countermeasures
  // and pre-computation optimization.

  // Dedicated window==4 case improves 512-bit RSA sign by ~15%, but as
  // 512-bit RSA is hardly relevant, we omit it to spare size...
  if (window == 5 && top > 1) {
    const BN_ULONG *n0 = mont->n0;
    // Local copy of the modulus limbs, placed right after |am| in |powerbuf|.
    BN_ULONG *np;

    // BN_to_montgomery can contaminate words above .top
    // [in BN_DEBUG[_DEBUG] build]...
    for (i = am.width; i < top; i++) {
      am.d[i] = 0;
    }
    for (i = tmp.width; i < top; i++) {
      tmp.d[i] = 0;
    }

    // copy mont->N.d[] to improve cache locality
    for (np = am.d + top, i = 0; i < top; i++) {
      np[i] = mont->N.d[i];
    }

    // Fill the 32-entry table so that entry k holds am^k (Montgomery form):
    // entry 0 = one, entry 1 = am, entry 2 = am^2. Powers of two come from
    // repeated squaring. Odd k come from multiplying am by entry k-1
    // (assuming bn_mul_mont_gather5 multiplies its second operand by the
    // indexed table entry), then squaring upward to 2k, 4k, ... < 32.
    bn_scatter5(tmp.d, top, powerbuf, 0);
    // NOTE(review): am.d[am.width..top) was zeroed above and |powerbuf| was
    // memset, so scattering |top| words would store the same values. Using
    // |am.width| makes the scatter length depend on the width of a*R mod N.
    // Confirm whether that width can be secret-dependent here.
    bn_scatter5(am.d, am.width, powerbuf, 1);
    bn_mul_mont(tmp.d, am.d, am.d, np, n0, top);
    bn_scatter5(tmp.d, top, powerbuf, 2);

    // same as above, but uses squaring for 1/2 of operations
    // Entries 4, 8, 16.
    for (i = 4; i < 32; i *= 2) {
      bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
      bn_scatter5(tmp.d, top, powerbuf, i);
    }
    // Entries 3, 5, 7 and their doublings (6, 12, 24, 10, 20, 14, 28).
    for (i = 3; i < 8; i += 2) {
      int j;
      bn_mul_mont_gather5(tmp.d, am.d, powerbuf, np, n0, top, i - 1);
      bn_scatter5(tmp.d, top, powerbuf, i);
      for (j = 2 * i; j < 32; j *= 2) {
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_scatter5(tmp.d, top, powerbuf, j);
      }
    }
    // Entries 9, 11, 13, 15 and their doublings (18, 22, 26, 30).
    for (; i < 16; i += 2) {
      bn_mul_mont_gather5(tmp.d, am.d, powerbuf, np, n0, top, i - 1);
      bn_scatter5(tmp.d, top, powerbuf, i);
      bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
      bn_scatter5(tmp.d, top, powerbuf, 2 * i);
    }
    // Remaining odd entries 17..31.
    for (; i < 32; i += 2) {
      bn_mul_mont_gather5(tmp.d, am.d, powerbuf, np, n0, top, i - 1);
      bn_scatter5(tmp.d, top, powerbuf, i);
    }

    // Read the top (bits % 5) + 1 bits of |p| as the first window, so that
    // the number of bits left is a multiple of 5. Then load that table entry
    // into |tmp|.
    bits--;
    for (wvalue = 0, i = bits % 5; i >= 0; i--, bits--) {
      wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
    }
    bn_gather5(tmp.d, top, powerbuf, wvalue);

    // At this point |bits| is 4 mod 5 and at least -1. (|bits| is the first bit
    // that has not been read yet.)
    assert(bits >= -1 && (bits == -1 || bits % 5 == 4));

    // Scan the exponent one window at a time starting from the most
    // significant bits.
    if (top & 7) {
      // For each 5-bit window: five Montgomery squarings, then a multiply by
      // the table entry selected by the window value.
      while (bits >= 0) {
        for (wvalue = 0, i = 0; i < 5; i++, bits--) {
          wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
        }

        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont_gather5(tmp.d, tmp.d, powerbuf, np, n0, top, wvalue);
      }
    } else {
      // |top| is a multiple of 8. Use |bn_power5| and read each 5-bit window
      // directly from the bytes of |p->d|.
      // NOTE(review): |bn_power5| presumably fuses the five squarings and the
      // gather-multiply of the branch above; confirm against the assembly.
      // The byte indexing assumes little-endian limb layout; confirm that
      // this path is only built for such targets.
      const uint8_t *p_bytes = (const uint8_t *)p->d;
      assert(bits < max_bits);
      // |p = 0| has been handled as a special case, so |max_bits| is at least
      // one word.
      assert(max_bits >= 64);

      // If the first bit to be read lands in the last byte, unroll the first
      // iteration to avoid reading past the bounds of |p->d|. (After the first
      // iteration, we are guaranteed to be past the last byte.) Note |bits|
      // here is the top bit, inclusive.
      if (bits - 4 >= max_bits - 8) {
        // Read five bits from |bits-4| through |bits|, inclusive.
        wvalue = p_bytes[p->width * BN_BYTES - 1];
        wvalue >>= (bits - 4) & 7;
        wvalue &= 0x1f;
        bits -= 5;
        bn_power5(tmp.d, tmp.d, powerbuf, np, n0, top, wvalue);
      }
      while (bits >= 0) {
        // Read five bits from |bits-4| through |bits|, inclusive.
        // A 5-bit field starting at bit offset 0..7 spans at most two bytes,
        // so a 16-bit load suffices. memcpy avoids an unaligned or
        // type-punned read.
        int first_bit = bits - 4;
        uint16_t val;
        OPENSSL_memcpy(&val, p_bytes + (first_bit >> 3), sizeof(val));
        val >>= first_bit & 7;
        val &= 0x1f;
        bits -= 5;
        bn_power5(tmp.d, tmp.d, powerbuf, np, n0, top, val);
      }
    }

    // Convert out of Montgomery form in place. A zero return from
    // |bn_from_montgomery| is not treated as an error. Control instead falls
    // through to the generic |BN_from_montgomery| call after this if/else.
    ret = bn_from_montgomery(tmp.d, tmp.d, NULL, np, n0, top);
    tmp.width = top;
    if (ret) {
      if (!BN_copy(rr, &tmp)) {
        ret = 0;
      }
      goto err;  // non-zero ret means it's not error
    }
  } else
#endif
  {
    // Generic path: fill the table with a^0 .. a^(numPowers-1) (Montgomery
    // form), then run a left-to-right fixed-window exponentiation.
    copy_to_prebuf(&tmp, top, powerbuf, 0, window);
    copy_to_prebuf(&am, top, powerbuf, 1, window);

    // If the window size is greater than 1, then calculate
    // val[i=2..2^winsize-1]. Powers are computed as a*a^(i-1)
    // (even powers could instead be computed as (a^(i/2))^2
    // to use the slight performance advantage of sqr over mul).
    if (window > 1) {
      if (!BN_mod_mul_montgomery(&tmp, &am, &am, mont, ctx)) {
        goto err;
      }

      copy_to_prebuf(&tmp, top, powerbuf, 2, window);

      for (i = 3; i < numPowers; i++) {
        // Calculate a^i = a^(i-1) * a
        if (!BN_mod_mul_montgomery(&tmp, &am, &tmp, mont, ctx)) {
          goto err;
        }

        copy_to_prebuf(&tmp, top, powerbuf, i, window);
      }
    }

    // Read the top (bits % window) + 1 bits as the first window and load
    // the matching table entry into |tmp|.
    bits--;
    for (wvalue = 0, i = bits % window; i >= 0; i--, bits--) {
      wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
    }
    if (!copy_from_prebuf(&tmp, top, powerbuf, wvalue, window)) {
      goto err;
    }

    // Scan the exponent one window at a time starting from the most
    // significant bits.
    while (bits >= 0) {
      wvalue = 0;  // The 'value' of the window

      // Scan the window, squaring the result as we go
      for (i = 0; i < window; i++, bits--) {
        if (!BN_mod_mul_montgomery(&tmp, &tmp, &tmp, mont, ctx)) {
          goto err;
        }
        wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
      }

      // Fetch the appropriate pre-computed value from the pre-buf
      // (|am| is reused as the destination; its original value is no longer
      // needed once the table is built).
      if (!copy_from_prebuf(&am, top, powerbuf, wvalue, window)) {
        goto err;
      }

      // Multiply the result into the intermediate result
      if (!BN_mod_mul_montgomery(&tmp, &tmp, &am, mont, ctx)) {
        goto err;
      }
    }
  }

  // Convert the final result from montgomery to standard format
  // (also reached when the MONT5 |bn_from_montgomery| call returned 0).
  if (!BN_from_montgomery(rr, &tmp, mont, ctx)) {
    goto err;
  }
  ret = 1;

  // Shared cleanup for both success and failure. |ret| carries the result.
  // A stack-backed |powerbuf| is scrubbed with OPENSSL_cleanse. A heap-backed
  // one is only passed to OPENSSL_free here.
  // NOTE(review): presumably OPENSSL_free zeroes the allocation itself;
  // confirm.
err:
  BN_MONT_CTX_free(new_mont);
  if (powerbuf != NULL && powerbufFree == NULL) {
    OPENSSL_cleanse(powerbuf, powerbufLen);
  }
  OPENSSL_free(powerbufFree);
  return (ret);
}