in lattices/lattice_utils.cpp [293:328]
/**
 * Squared L2 distance between two float vectors, accumulated with AVX/SSE.
 *
 * Computes the sum of (x[i] - y[i])^2. Full groups of 8 and 4 components
 * are read with unaligned loads; the last 1..3 components are fetched
 * through masked_read().
 *
 * @param x  first input vector (d floats)
 * @param y  second input vector (d floats)
 * @param d  number of components in each vector
 * @return   the accumulated sum of squared component differences
 */
float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    __m256 acc8 = _mm256_setzero_ps ();

    /* Main loop: consume eight components per iteration. */
    for (; d >= 8; d -= 8, x += 8, y += 8) {
        const __m256 diff8 = _mm256_sub_ps (_mm256_loadu_ps (x),
                                            _mm256_loadu_ps (y));
        acc8 = _mm256_add_ps (acc8, _mm256_mul_ps (diff8, diff8));
    }

    /* Fold the 256-bit accumulator into 128 bits: upper half + lower half. */
    __m128 acc4 = _mm_add_ps (_mm256_extractf128_ps (acc8, 1),
                              _mm256_extractf128_ps (acc8, 0));

    /* At most one full group of four components can remain. */
    if (d >= 4) {
        const __m128 diff4 = _mm_sub_ps (_mm_loadu_ps (x), _mm_loadu_ps (y));
        acc4 = _mm_add_ps (acc4, _mm_mul_ps (diff4, diff4));
        x += 4;
        y += 4;
        d -= 4;
    }

    /* Tail of 1..3 components.
     * NOTE(review): masked_read is defined elsewhere; presumably it loads
     * the first d floats and zero-fills the remaining lanes, so the unused
     * lanes contribute nothing to the sum -- confirm against its definition. */
    if (d != 0) {
        const __m128 tail_x = masked_read (d, x);
        const __m128 tail_y = masked_read (d, y);
        const __m128 tail_diff = _mm_sub_ps (tail_x, tail_y);
        acc4 = _mm_add_ps (acc4, _mm_mul_ps (tail_diff, tail_diff));
    }

    /* Two horizontal adds leave the sum of all four lanes in lane 0. */
    acc4 = _mm_hadd_ps (acc4, acc4);
    acc4 = _mm_hadd_ps (acc4, acc4);
    return _mm_cvtss_f32 (acc4);
}