static void poly_mul_vec_aux()

in Sources/CCryptoBoringSSL/crypto/hrss/hrss.c [979:1214]


static void poly_mul_vec_aux(vec_t *restrict out, vec_t *restrict scratch,
                             const vec_t *restrict a, const vec_t *restrict b,
                             const size_t n) {
  // In [HRSS], the authors describe the technique they used for polynomial
  // multiplication: Toom-4 at the top level, followed by two layers of
  // Karatsuba. Karatsuba is a specific instance of the general Toom–Cook
  // decomposition, which splits each input n ways and produces 2n-1
  // multiplications of those parts. So, starting with 704 coefficients
  // (rounded up from 701 to have more factors of two), Toom-4 gives seven
  // multiplications of degree-176 polynomials. Each round of Karatsuba (which
  // is Toom-2) increases the number of multiplications by a factor of three
  // while halving the size of the values being multiplied. So two rounds gives
  // 63 multiplications of degree-44 polynomials. Then they (I think) form
  // vectors by gathering all 63 coefficients of each power together, for each
  // input, and doing more rounds of Karatsuba on the vectors until they bottom-
  // out somewhere with schoolbook multiplication.
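  // (Concretely: 704 coefficients -> 7 multiplications of size 176 after
  // Toom-4 -> 21 of size 88 -> 63 of size 44 after the two Karatsuba rounds.)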
  //
  // I tried something like that for NEON. NEON vectors are 128 bits so hold
  // eight coefficients. I wrote a function that did Karatsuba on eight
  // multiplications at the same time, using such vectors, and a Go script that
  // decomposed from degree-704, with Karatsuba in non-transposed form, until it
  // reached multiplications of degree-44. It batched up those 81
  // multiplications into lots of eight with a single one left over (which was
  // handled directly).
  //
  // It worked, but it was significantly slower than the dumb algorithm used
  // below. Potentially that was because I misunderstood how [HRSS] did it, or
  // because Clang is bad at generating good code from NEON intrinsics on ARMv7.
  // (Which is true: the code generated by Clang for the below is pretty crap.)
  //
  // This algorithm is much simpler. It just does Karatsuba decomposition all
  // the way down and never transposes. When it gets down to degree-16 or
  // degree-24 values, they are multiplied using schoolbook multiplication and
  // vector intrinsics. The vector operations form each of the eight phase-
  // shifts of one of the inputs, point-wise multiply, and then add into the
  // result at the correct place. This means that 33% (degree-16) or 25%
  // (degree-24) of the multiplies and adds are wasted, but it does ok.
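  //
  // (The waste is just padding: in the degree-16 case the 16 coefficients of
  // one input are carried in three vectors of eight lanes, so 8 of the 24
  // lanes are always zero; in the degree-24 case it is 8 of 32 lanes.)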
  if (n == 2) {
    vec_t result[4];
    vec_t vec_a[3];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = kZero;
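    // The 16 coefficients of |a| sit in the low two of these three vectors;
    // the zero third vector gives the word-by-word shifted copies below room
    // to grow.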

    result[0] = vec_mul(vec_a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(vec_a[1], vec_get_word(b[0], 0));

    result[1] = vec_fma(result[1], vec_a[0], vec_get_word(b[1], 0));
    result[2] = vec_mul(vec_a[1], vec_get_word(b[1], 0));
    result[3] = kZero;

    vec3_rshift_word(vec_a);

#define BLOCK(x, y)                                                      \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] =                                                      \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
  } while (0)
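
    // BLOCK(x, y) accumulates the contribution of word y of |b|: at each call
    // site, |vec_a| holds |a| shifted up by y%8 words, and its product with
    // b_y is added into result[x..x+2], where x == y/8, so the term lands at
    // word offset y.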

    BLOCK(0, 1);
    BLOCK(1, 9);

    vec3_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);

    vec3_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);

    vec3_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);

    vec3_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);

    vec3_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);

    vec3_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);

#undef BLOCK

    memcpy(out, result, sizeof(result));
    return;
  }

  if (n == 3) {
    vec_t result[6];
    vec_t vec_a[4];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = a[2];
    vec_a[3] = kZero;
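    // Same layout trick as the n == 2 case: the 24 coefficients of |a| go in
    // the low three of four vectors, with a zero vector of headroom for the
    // shifts.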

    result[0] = vec_mul(a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(a[1], vec_get_word(b[0], 0));
    result[2] = vec_mul(a[2], vec_get_word(b[0], 0));

#define BLOCK_PRE(x, y)                                                  \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] = vec_mul(vec_a[2], vec_get_word(b[y / 8], y % 8));    \
  } while (0)
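
    // BLOCK_PRE is like BLOCK (defined below) except that the highest result
    // vector it touches has not been written yet, so that lane is initialised
    // with vec_mul rather than accumulated with vec_fma. It handles the
    // unshifted |a| times words 8 and 16 of |b|.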

    BLOCK_PRE(1, 8);
    BLOCK_PRE(2, 16);

    result[5] = kZero;

    vec4_rshift_word(vec_a);

#define BLOCK(x, y)                                                      \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] =                                                      \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
    result[x + 3] =                                                      \
        vec_fma(result[x + 3], vec_a[3], vec_get_word(b[y / 8], y % 8)); \
  } while (0)
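
    // As in the n == 2 case, BLOCK(x, y) adds the product of the shifted |a|
    // and word y of |b| at word offset y, now spread across four result
    // vectors.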

    BLOCK(0, 1);
    BLOCK(1, 9);
    BLOCK(2, 17);

    vec4_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);
    BLOCK(2, 18);

    vec4_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);
    BLOCK(2, 19);

    vec4_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);
    BLOCK(2, 20);

    vec4_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);
    BLOCK(2, 21);

    vec4_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);
    BLOCK(2, 22);

    vec4_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);
    BLOCK(2, 23);

#undef BLOCK
#undef BLOCK_PRE

    memcpy(out, result, sizeof(result));

    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm
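  //
  // With the inputs split at |low_len| vectors into (a_0, a_1) and (b_0, b_1),
  //   a×b = a_0×b_0
  //       + ((a_1 + a_0)×(b_1 + b_0) - a_1×b_1 - a_0×b_0), shifted up by
  //         low_len vectors,
  //       + a_1×b_1, shifted up by 2·low_len vectors,
  // so only three half-sized multiplications are needed.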

  // When |n| is odd, the two "halves" will have different lengths. The first is
  // always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const vec_t *a_high = &a[low_len];
  const vec_t *b_high = &b[low_len];

  // Store a_1 + a_0 in the first |high_len| elements of |out| and b_1 + b_0 in
  // the next |high_len| elements.
  for (size_t i = 0; i < low_len; i++) {
    out[i] = vec_add(a_high[i], a[i]);
    out[high_len + i] = vec_add(b_high[i], b[i]);
  }
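  // When |n| is odd the high halves are one vector longer than the low halves;
  // there is nothing to add the extra vector to, so copy it through.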
  if (high_len != low_len) {
    out[low_len] = a_high[low_len];
    out[high_len + low_len] = b_high[low_len];
  }

  vec_t *const child_scratch = &scratch[2 * high_len];
  // Calculate (a_1 + a_0) × (b_1 + b_0) and write to scratch buffer.
  poly_mul_vec_aux(scratch, child_scratch, out, &out[high_len], high_len);
  // Calculate a_1 × b_1.
  poly_mul_vec_aux(&out[low_len * 2], child_scratch, a_high, b_high, high_len);
  // Calculate a_0 × b_0.
  poly_mul_vec_aux(out, child_scratch, a, b, low_len);

  // Subtract those last two products from the first.
  for (size_t i = 0; i < low_len * 2; i++) {
    scratch[i] = vec_sub(scratch[i], vec_add(out[i], out[low_len * 2 + i]));
  }
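  // When |n| is odd, (a_1 + a_0)×(b_1 + b_0) and a_1×b_1 span 2·high_len
  // vectors but a_0×b_0 spans only 2·low_len, so the top two vectors of
  // |scratch| have only a_1×b_1 to subtract.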
  if (low_len != high_len) {
    scratch[low_len * 2] = vec_sub(scratch[low_len * 2], out[low_len * 4]);
    scratch[low_len * 2 + 1] =
        vec_sub(scratch[low_len * 2 + 1], out[low_len * 4 + 1]);
  }

  // Add the middle product into the output.
  for (size_t i = 0; i < high_len * 2; i++) {
    out[low_len + i] = vec_add(out[low_len + i], scratch[i]);
  }
}
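
The same recursion can be sketched without the vector types. The following is a
minimal scalar model for illustration only; it is not part of hrss.c and the
name scalar_mul_aux is made up. Each element is a single uint16_t coefficient
rather than a vector of eight, the schoolbook base case stands in for the
vectorised n == 2 / n == 3 paths, and |scratch| is consumed 2*high_len elements
per level, matching |child_scratch| above.

#include <stddef.h>
#include <stdint.h>
#include <string.h>

static void scalar_mul_aux(uint16_t *restrict out, uint16_t *restrict scratch,
                           const uint16_t *restrict a,
                           const uint16_t *restrict b, const size_t n) {
  if (n <= 4) {
    // Schoolbook multiplication; arithmetic is implicitly mod 2^16, as with
    // the uint16_t lanes in the vector code.
    memset(out, 0, 2 * n * sizeof(uint16_t));
    for (size_t i = 0; i < n; i++) {
      for (size_t j = 0; j < n; j++) {
        out[i + j] += (uint16_t)((uint32_t)a[i] * b[j]);
      }
    }
    return;
  }

  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const uint16_t *a_high = &a[low_len];
  const uint16_t *b_high = &b[low_len];

  // a_1 + a_0 into out[0..high_len), b_1 + b_0 into out[high_len..2*high_len).
  for (size_t i = 0; i < low_len; i++) {
    out[i] = (uint16_t)(a_high[i] + a[i]);
    out[high_len + i] = (uint16_t)(b_high[i] + b[i]);
  }
  if (high_len != low_len) {
    out[low_len] = a_high[low_len];
    out[high_len + low_len] = b_high[low_len];
  }

  uint16_t *const child_scratch = &scratch[2 * high_len];
  // (a_1 + a_0) × (b_1 + b_0), a_1 × b_1 and a_0 × b_0, as above.
  scalar_mul_aux(scratch, child_scratch, out, &out[high_len], high_len);
  scalar_mul_aux(&out[low_len * 2], child_scratch, a_high, b_high, high_len);
  scalar_mul_aux(out, child_scratch, a, b, low_len);

  // Subtract the outer products from the middle one, then add it in at an
  // offset of |low_len| elements.
  for (size_t i = 0; i < low_len * 2; i++) {
    scratch[i] = (uint16_t)(scratch[i] - out[i] - out[low_len * 2 + i]);
  }
  if (low_len != high_len) {
    scratch[low_len * 2] = (uint16_t)(scratch[low_len * 2] - out[low_len * 4]);
    scratch[low_len * 2 + 1] =
        (uint16_t)(scratch[low_len * 2 + 1] - out[low_len * 4 + 1]);
  }
  for (size_t i = 0; i < high_len * 2; i++) {
    out[low_len + i] = (uint16_t)(out[low_len + i] + scratch[i]);
  }
}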