in libvmaf/src/feature/common/convolution_avx.c [1910:2058]
static void convolution_f32_avx_s_1d_v_xy_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end)
{
    /* Dispatch to the fully unrolled kernels for the filter widths used in practice. */
    if (N == 5)
    {
        convolution_f32_avx_s_1d_v_xy_scanline_5(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
    }
    else if (N == 9)
    {
        convolution_f32_avx_s_1d_v_xy_scanline_9(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
    }
    else if (N == 17)
    {
        convolution_f32_avx_s_1d_v_xy_scanline_17(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
    }
    else {
        /* Generic path: step back by the filter radius so the taps are
         * centered on the current scanline. */
        int radius = filter_width / 2;
        src1 -= radius * src1_stride;
        src2 -= radius * src2_stride;
        for (int y = 0; y < filter_width; y += 9) {
            /* Broadcast up to 9 filter taps for this block; the fall-through
             * switch handles a final block shorter than 9 taps. */
            __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
            f0 = _mm256_setzero_ps();
            f1 = _mm256_setzero_ps();
            f2 = _mm256_setzero_ps();
            f3 = _mm256_setzero_ps();
            f4 = _mm256_setzero_ps();
            f5 = _mm256_setzero_ps();
            f6 = _mm256_setzero_ps();
            f7 = _mm256_setzero_ps();
            f8 = _mm256_setzero_ps();
            switch (filter_width - y) {
            default:
                f8 = _mm256_broadcast_ss(filter + y + 8);
                // fall through
            case 8:
                f7 = _mm256_broadcast_ss(filter + y + 7);
                // fall through
            case 7:
                f6 = _mm256_broadcast_ss(filter + y + 6);
                // fall through
            case 6:
                f5 = _mm256_broadcast_ss(filter + y + 5);
                // fall through
            case 5:
                f4 = _mm256_broadcast_ss(filter + y + 4);
                // fall through
            case 4:
                f3 = _mm256_broadcast_ss(filter + y + 3);
                // fall through
            case 3:
                f2 = _mm256_broadcast_ss(filter + y + 2);
                // fall through
            case 2:
                f1 = _mm256_broadcast_ss(filter + y + 1);
                // fall through
            case 1:
                f0 = _mm256_broadcast_ss(filter + y + 0);
                // fall through
            }
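            /* Inner loop: 8 output samples per iteration.  The two source
             * scanlines are multiplied pointwise, weighted by the broadcast
             * taps, and accumulated in four partial sums (sum0..sum3) to
             * shorten the floating-point add dependency chain. */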
            for (int j = 0; j < j_end; j += 8) {
                __m256 accum = _mm256_setzero_ps();
                __m256 sum0, sum1, sum2, sum3;
                __m256 g, g2;
                sum0 = _mm256_setzero_ps();
                sum1 = _mm256_setzero_ps();
                sum2 = _mm256_setzero_ps();
                sum3 = _mm256_setzero_ps();
                switch (filter_width - y) {
                default:
                    g = _mm256_load_ps(src1 + (y + 8) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 8) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    sum0 = _mm256_mul_ps(f8, g);
                    // fall through
                case 8:
                    g = _mm256_load_ps(src1 + (y + 7) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 7) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    sum3 = _mm256_mul_ps(f7, g);
                    // fall through
                case 7:
                    g = _mm256_load_ps(src1 + (y + 6) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 6) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    sum2 = _mm256_mul_ps(f6, g);
                    // fall through
                case 6:
                    g = _mm256_load_ps(src1 + (y + 5) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 5) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    sum1 = _mm256_mul_ps(f5, g);
                    // fall through
                case 5:
                    g = _mm256_load_ps(src1 + (y + 4) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 4) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    g = _mm256_mul_ps(f4, g);
                    sum0 = _mm256_add_ps(sum0, g);
                    // fall through
                case 4:
                    g = _mm256_load_ps(src1 + (y + 3) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 3) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    g = _mm256_mul_ps(f3, g);
                    sum3 = _mm256_add_ps(sum3, g);
                    // fall through
                case 3:
                    g = _mm256_load_ps(src1 + (y + 2) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 2) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    g = _mm256_mul_ps(f2, g);
                    sum2 = _mm256_add_ps(sum2, g);
                    // fall through
                case 2:
                    g = _mm256_load_ps(src1 + (y + 1) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 1) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    g = _mm256_mul_ps(f1, g);
                    sum1 = _mm256_add_ps(sum1, g);
                    // fall through
                case 1:
                    g = _mm256_load_ps(src1 + (y + 0) * src1_stride + j);
                    g2 = _mm256_load_ps(src2 + (y + 0) * src2_stride + j);
                    g = _mm256_mul_ps(g, g2);
                    g = _mm256_mul_ps(f0, g);
                    sum0 = _mm256_add_ps(sum0, g);
                    // fall through
                }
                sum0 = _mm256_add_ps(sum0, sum2);
                sum1 = _mm256_add_ps(sum1, sum3);
                sum0 = _mm256_add_ps(sum0, sum1);
                accum = _mm256_add_ps(accum, sum0);
                if (y)
                    accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j));
                _mm256_store_ps(dst + j, accum);
            }
        }
    }
}
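
For reference, the generic path computes, for each output sample j, the vertically filtered pointwise product of the two sources, with later tap blocks accumulated on top of the partial results already stored in dst. Below is a minimal scalar sketch of the same computation; xy_scanline_ref is an illustrative name, not a libvmaf function, and it skips the 8-wide blocking and aligned-load details while assuming the caller provides the same rows above and below the current scanline that the AVX path reads.

static void xy_scanline_ref(const float *filter, int filter_width,
                            const float *src1, const float *src2, float *dst,
                            int src1_stride, int src2_stride, int j_end)
{
    int radius = filter_width / 2;
    for (int j = 0; j < j_end; ++j) {
        float accum = 0.0f;
        for (int k = 0; k < filter_width; ++k) {
            /* Same centering as the radius offset at the top of the generic
             * path: tap k reads the row (k - radius) relative to the current
             * scanline. */
            float a = src1[(k - radius) * src1_stride + j];
            float b = src2[(k - radius) * src2_stride + j];
            accum += filter[k] * a * b;
        }
        dst[j] = accum;
    }
}

The AVX version arrives at the same result by processing the taps in blocks of nine and the output samples in groups of eight, writing the first block directly to dst and adding the remaining blocks onto it (the `if (y)` branch).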