in libvmaf/src/feature/common/convolution_avx.c [417:546]
static void convolution_f32_avx_s_1d_v_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
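    // Specialized kernels handle the common filter widths (5, 9, 17 taps);
    // any other width falls through to the generic AVX path below.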
    if (N == 5)
    {
        convolution_f32_avx_s_1d_v_scanline_5(filter, filter_width, src, dst, src_stride, j_end);
    }
    else if (N == 9)
    {
        convolution_f32_avx_s_1d_v_scanline_9(filter, filter_width, src, dst, src_stride, j_end);
    }
    else if (N == 17)
    {
        convolution_f32_avx_s_1d_v_scanline_17(filter, filter_width, src, dst, src_stride, j_end);
    }
    else {
        int radius = filter_width / 2;
        src -= radius * src_stride;
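        // Walk the filter taps in groups of up to 9, accumulating each
        // group's contribution into dst.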
        for (int y = 0; y < filter_width; y += 9) {
            __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
            f0 = _mm256_setzero_ps();
            f1 = _mm256_setzero_ps();
            f2 = _mm256_setzero_ps();
            f3 = _mm256_setzero_ps();
            f4 = _mm256_setzero_ps();
            f5 = _mm256_setzero_ps();
            f6 = _mm256_setzero_ps();
            f7 = _mm256_setzero_ps();
            f8 = _mm256_setzero_ps();
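            // Broadcast the coefficients of this tap group into vector
            // registers; taps past the end of the filter keep their zero
            // initialization.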
            switch (filter_width - y) {
            default:
                f8 = _mm256_broadcast_ss(filter + y + 8);
                // fall through
            case 8:
                f7 = _mm256_broadcast_ss(filter + y + 7);
                // fall through
            case 7:
                f6 = _mm256_broadcast_ss(filter + y + 6);
                // fall through
            case 6:
                f5 = _mm256_broadcast_ss(filter + y + 5);
                // fall through
            case 5:
                f4 = _mm256_broadcast_ss(filter + y + 4);
                // fall through
            case 4:
                f3 = _mm256_broadcast_ss(filter + y + 3);
                // fall through
            case 3:
                f2 = _mm256_broadcast_ss(filter + y + 2);
                // fall through
            case 2:
                f1 = _mm256_broadcast_ss(filter + y + 1);
                // fall through
            case 1:
                f0 = _mm256_broadcast_ss(filter + y + 0);
                // fall through
            }
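            // Produce 8 output samples (one __m256) per iteration along the
            // scanline.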
            for (int j = 0; j < j_end; j += 8) {
                __m256 accum = _mm256_setzero_ps();
                __m256 sum0, sum1, sum2, sum3;
                __m256 g;
                sum0 = _mm256_setzero_ps();
                sum1 = _mm256_setzero_ps();
                sum2 = _mm256_setzero_ps();
                sum3 = _mm256_setzero_ps();
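                // Multiply each loaded source row by its broadcast
                // coefficient, spreading the products over four independent
                // partial sums (sum0..sum3).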
                switch (filter_width - y) {
                default:
                    g = _mm256_load_ps(src + (y + 8) * src_stride + j);
                    sum0 = _mm256_mul_ps(f8, g);
                    // fall through
                case 8:
                    g = _mm256_load_ps(src + (y + 7) * src_stride + j);
                    sum3 = _mm256_mul_ps(f7, g);
                    // fall through
                case 7:
                    g = _mm256_load_ps(src + (y + 6) * src_stride + j);
                    sum2 = _mm256_mul_ps(f6, g);
                    // fall through
                case 6:
                    g = _mm256_load_ps(src + (y + 5) * src_stride + j);
                    sum1 = _mm256_mul_ps(f5, g);
                    // fall through
                case 5:
                    g = _mm256_load_ps(src + (y + 4) * src_stride + j);
                    g = _mm256_mul_ps(f4, g);
                    sum0 = _mm256_add_ps(sum0, g);
                    // fall through
                case 4:
                    g = _mm256_load_ps(src + (y + 3) * src_stride + j);
                    g = _mm256_mul_ps(f3, g);
                    sum3 = _mm256_add_ps(sum3, g);
                    // fall through
                case 3:
                    g = _mm256_load_ps(src + (y + 2) * src_stride + j);
                    g = _mm256_mul_ps(f2, g);
                    sum2 = _mm256_add_ps(sum2, g);
                    // fall through
                case 2:
                    g = _mm256_load_ps(src + (y + 1) * src_stride + j);
                    g = _mm256_mul_ps(f1, g);
                    sum1 = _mm256_add_ps(sum1, g);
                    // fall through
                case 1:
                    g = _mm256_load_ps(src + (y + 0) * src_stride + j);
                    g = _mm256_mul_ps(f0, g);
                    sum0 = _mm256_add_ps(sum0, g);
                    // fall through
                }
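                // Reduce the partial sums; for every tap group after the
                // first, add the running total already stored in dst.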
                sum0 = _mm256_add_ps(sum0, sum2);
                sum1 = _mm256_add_ps(sum1, sum3);
                sum0 = _mm256_add_ps(sum0, sum1);
                accum = _mm256_add_ps(accum, sum0);
                if (y)
                    accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j));
                _mm256_store_ps(dst + j, accum);
            }
        }
    }
}