in Sources/CCryptoBoringSSL/crypto/hrss/hrss.c [979:1214]
// poly_mul_vec_aux multiplies the polynomials |a| and |b|, each stored as |n|
// vectors of eight coefficients, and writes the 2*|n|-vector product to |out|.
//
// |scratch| is working space: this level keeps the Karatsuba middle product in
// its first 2*ceil(|n|/2) vectors and hands everything after that to the
// recursive calls. |n| must be at least two, since only n == 2 and n == 3 are
// base cases and smaller values would never reach one.
//
// Background on the design: [HRSS] describes a multiplication that uses Toom-4
// at the top level followed by two layers of Karatsuba. Karatsuba is the
// Toom-2 instance of the general Toom–Cook scheme, which splits an input
// n ways and produces 2n-1 multiplications of the parts. From 704 coefficients
// (701 rounded up to gain factors of two), Toom-4 yields seven multiplications
// of degree-174 polynomials; each Karatsuba round triples the number of
// multiplications and halves their size, so two rounds give 63 multiplications
// of degree-44 polynomials. They then (apparently) gather the 63 coefficients
// of each power into vectors, per input, and run further Karatsuba rounds on
// those vectors until bottoming out in schoolbook multiplication.
//
// Something similar was tried for NEON, whose 128-bit vectors hold eight
// coefficients: a function performing Karatsuba on eight multiplications at
// once, driven by a Go script that decomposed from degree-704 (Karatsuba in
// non-transposed form) down to degree-44 and batched the resulting 81
// multiplications into groups of eight, with the one left over handled
// directly. It worked but was significantly slower than the simple algorithm
// below. That may have been a misunderstanding of [HRSS], or it may be that
// Clang generates poor code from NEON intrinsics on ARMv7 (which it does: the
// code it produces for this function is not good either).
//
// The algorithm here simply applies Karatsuba all the way down without ever
// transposing. Once the operands are degree-16 or degree-24 they are
// multiplied schoolbook-style with vector intrinsics: each of the eight
// phase-shifts of one input is formed, multiplied point-wise by a word of the
// other input, and accumulated into the result at the matching offset. That
// wastes 33% (degree-16) or 25% (degree-24) of the multiplies and adds, but
// performs acceptably.
static void poly_mul_vec_aux(vec_t *restrict out, vec_t *restrict scratch,
                             const vec_t *restrict a, const vec_t *restrict b,
                             const size_t n) {
  if (n == 2) {
    // Base case: two vectors by two vectors, giving four vectors.
    static const vec_t kZero = {0};
    vec_t prod[4];
    vec_t shifted[3];

    // Phase zero: multiply the unshifted |a| by word zero of each vector of
    // |b|.
    prod[0] = vec_mul(a[0], vec_get_word(b[0], 0));
    prod[1] = vec_fma(vec_mul(a[1], vec_get_word(b[0], 0)), a[0],
                      vec_get_word(b[1], 0));
    prod[2] = vec_mul(a[1], vec_get_word(b[1], 0));
    prod[3] = kZero;

    // |shifted| holds |a| with one spare vector on the end to receive the
    // words that each phase-shift pushes past the second vector.
    shifted[0] = a[0];
    shifted[1] = a[1];
    shifted[2] = kZero;

// ACCUMULATE adds |shifted| times word |lane| of b[b_idx] into the three
// vectors of |prod| starting at |row|. |b_idx| and |lane| are always literals
// so that the lane index stays a compile-time constant.
#define ACCUMULATE(row, b_idx, lane)                                        \
  do {                                                                      \
    prod[(row) + 0] = vec_fma(prod[(row) + 0], shifted[0],                  \
                              vec_get_word(b[b_idx], lane));                \
    prod[(row) + 1] = vec_fma(prod[(row) + 1], shifted[1],                  \
                              vec_get_word(b[b_idx], lane));                \
    prod[(row) + 2] = vec_fma(prod[(row) + 2], shifted[2],                  \
                              vec_get_word(b[b_idx], lane));                \
  } while (0)

// PHASE forms the next phase-shift of |a| and then accumulates it against
// word |lane| of both vectors of |b|.
#define PHASE(lane)            \
  do {                         \
    vec3_rshift_word(shifted); \
    ACCUMULATE(0, 0, lane);    \
    ACCUMULATE(1, 1, lane);    \
  } while (0)

    PHASE(1);
    PHASE(2);
    PHASE(3);
    PHASE(4);
    PHASE(5);
    PHASE(6);
    PHASE(7);

#undef PHASE
#undef ACCUMULATE

    memcpy(out, prod, sizeof(prod));
    return;
  }

  if (n == 3) {
    // Base case: three vectors by three vectors, giving six vectors.
    static const vec_t kZero = {0};
    vec_t prod[6];
    vec_t shifted[4];

    // Phase zero: multiply the unshifted |a| by word zero of b[0], b[1] and
    // b[2] in turn. The first product to land in each slot of |prod| is a
    // plain multiply; later ones accumulate.
    prod[0] = vec_mul(a[0], vec_get_word(b[0], 0));
    prod[1] = vec_mul(a[1], vec_get_word(b[0], 0));
    prod[2] = vec_mul(a[2], vec_get_word(b[0], 0));

    prod[1] = vec_fma(prod[1], a[0], vec_get_word(b[1], 0));
    prod[2] = vec_fma(prod[2], a[1], vec_get_word(b[1], 0));
    prod[3] = vec_mul(a[2], vec_get_word(b[1], 0));

    prod[2] = vec_fma(prod[2], a[0], vec_get_word(b[2], 0));
    prod[3] = vec_fma(prod[3], a[1], vec_get_word(b[2], 0));
    prod[4] = vec_mul(a[2], vec_get_word(b[2], 0));

    prod[5] = kZero;

    // |shifted| holds |a| with one spare vector on the end to receive the
    // words that each phase-shift pushes past the third vector.
    shifted[0] = a[0];
    shifted[1] = a[1];
    shifted[2] = a[2];
    shifted[3] = kZero;

// ACCUMULATE adds |shifted| times word |lane| of b[b_idx] into the four
// vectors of |prod| starting at |row|. |b_idx| and |lane| are always literals
// so that the lane index stays a compile-time constant.
#define ACCUMULATE(row, b_idx, lane)                                        \
  do {                                                                      \
    prod[(row) + 0] = vec_fma(prod[(row) + 0], shifted[0],                  \
                              vec_get_word(b[b_idx], lane));                \
    prod[(row) + 1] = vec_fma(prod[(row) + 1], shifted[1],                  \
                              vec_get_word(b[b_idx], lane));                \
    prod[(row) + 2] = vec_fma(prod[(row) + 2], shifted[2],                  \
                              vec_get_word(b[b_idx], lane));                \
    prod[(row) + 3] = vec_fma(prod[(row) + 3], shifted[3],                  \
                              vec_get_word(b[b_idx], lane));                \
  } while (0)

// PHASE forms the next phase-shift of |a| and then accumulates it against
// word |lane| of all three vectors of |b|.
#define PHASE(lane)            \
  do {                         \
    vec4_rshift_word(shifted); \
    ACCUMULATE(0, 0, lane);    \
    ACCUMULATE(1, 1, lane);    \
    ACCUMULATE(2, 2, lane);    \
  } while (0)

    PHASE(1);
    PHASE(2);
    PHASE(3);
    PHASE(4);
    PHASE(5);
    PHASE(6);
    PHASE(7);

#undef PHASE
#undef ACCUMULATE

    memcpy(out, prod, sizeof(prod));
    return;
  }

  // Recursive case: Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm
  //
  // The inputs are split into a low part of |lo| vectors and a high part of
  // |hi| vectors. For odd |n| the parts differ in length, and the low part is
  // always the shorter one.
  const size_t lo = n / 2;
  const size_t hi = n - lo;
  const vec_t *const a_hi = a + lo;
  const vec_t *const b_hi = b + lo;

  // Use |out| as temporary storage for the two sums: a_1 + a_0 occupies the
  // first |hi| vectors and b_1 + b_0 the next |hi| vectors.
  vec_t *const sum_a = out;
  vec_t *const sum_b = out + hi;
  for (size_t i = 0; i < lo; i++) {
    sum_a[i] = vec_add(a_hi[i], a[i]);
    sum_b[i] = vec_add(b_hi[i], b[i]);
  }
  if (hi != lo) {
    // The high parts have one more vector than the low parts; it has nothing
    // to be added to, so it is copied across unchanged.
    sum_a[lo] = a_hi[lo];
    sum_b[lo] = b_hi[lo];
  }

  // The middle product takes the first 2*|hi| vectors of |scratch|; whatever
  // follows is scratch space for the recursive calls.
  vec_t *const middle = scratch;
  vec_t *const child_scratch = scratch + 2 * hi;

  // middle = (a_1 + a_0) × (b_1 + b_0). This must be computed first, because
  // the sums live in |out| and the next two calls overwrite it.
  poly_mul_vec_aux(middle, child_scratch, sum_a, sum_b, hi);

  // high_prod = a_1 × b_1, placed in |out| after the space for a_0 × b_0.
  vec_t *const low_prod = out;
  vec_t *const high_prod = out + 2 * lo;
  poly_mul_vec_aux(high_prod, child_scratch, a_hi, b_hi, hi);

  // low_prod = a_0 × b_0.
  poly_mul_vec_aux(low_prod, child_scratch, a, b, lo);

  // middle -= low_prod + high_prod. Only the first 2*|lo| vectors have a
  // contribution from |low_prod|.
  for (size_t i = 0; i < 2 * lo; i++) {
    middle[i] = vec_sub(middle[i], vec_add(low_prod[i], high_prod[i]));
  }
  // When |n| is odd, |high_prod| is two vectors longer than |low_prod|, and
  // those two are subtracted on their own. The loop is empty for even |n|.
  for (size_t i = 2 * lo; i < 2 * hi; i++) {
    middle[i] = vec_sub(middle[i], high_prod[i]);
  }

  // Fold the middle term into the output, offset by |lo| vectors.
  vec_t *const out_mid = out + lo;
  for (size_t i = 0; i < 2 * hi; i++) {
    out_mid[i] = vec_add(out_mid[i], middle[i]);
  }
}