in lattices/lattice_utils.cpp [259:291]
/* Inner (dot) product of two float vectors of length d, using AVX/SSE.
 *
 * The loop accumulates 8 lanes at a time in a 256-bit register. The two
 * 128-bit halves are then folded together. At most one group of 4 is
 * handled with SSE, and the final 0..3 elements are read through
 * masked_read. All full-width loads use the unaligned forms
 * (_mm256_loadu_ps / _mm_loadu_ps), so x and y need no special alignment.
 *
 * @param x  first vector; must point to at least d readable floats
 * @param y  second vector; must point to at least d readable floats
 * @param d  number of elements (0 is allowed and yields 0.0f)
 * @return   sum over i of x[i] * y[i]
 *
 * The summation order differs from a sequential scalar loop, so the result
 * may differ from a naive implementation in the low-order bits.
 */
float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();
    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
        d -= 8;
    }

    /* Fold the upper and lower 128-bit halves of the AVX accumulator.
     * Use the _mm_add_ps intrinsic instead of `+=`: arithmetic operators on
     * __m128 are a GCC/Clang vector extension and do not compile on MSVC. */
    __m128 msum2 = _mm_add_ps (_mm256_extractf128_ps (msum1, 1),
                               _mm256_extractf128_ps (msum1, 0));

    /* At most one remaining group of 4 (d < 8 here). */
    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
        d -= 4;
    }

    /* Final 1..3 elements. NOTE(review): masked_read is defined elsewhere in
     * this file; correctness relies on it loading exactly d floats and
     * zero-filling the unused lanes so they contribute nothing. Confirm. */
    if (d > 0) {
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
    }

    /* Two horizontal adds reduce the 4 lanes to a single sum in lane 0. */
    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}