static void convolution_f32_avx_s_1d_v_xy_scanline()

in libvmaf/src/feature/common/convolution_avx.c [1910:2058]

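Vertical 1D convolution kernel for the product of two planes: each output pixel is the filter-weighted column sum of the elementwise product of the src1 and src2 scanlines (the "xy" cross term suggested by the name). Filter widths of 5, 9, and 17 dispatch to hand-unrolled kernels; any other width takes the generic path below, which consumes the filter in chunks of up to nine taps.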

static void convolution_f32_avx_s_1d_v_xy_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end)
{

	if (N == 5)
	{
		convolution_f32_avx_s_1d_v_xy_scanline_5(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
	}
	else if (N == 9)
	{
		convolution_f32_avx_s_1d_v_xy_scanline_9(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
	}
	else if (N == 17)
	{
		convolution_f32_avx_s_1d_v_xy_scanline_17(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
	}
	else {

		// Generic path: back the source pointers up by the filter radius
		// so the filter window is centered on the current scanline.
		int radius = filter_width / 2;
		src1 -= radius * src1_stride;
		src2 -= radius * src2_stride;

		// Consume the filter in chunks of up to nine taps; each chunk's
		// partial sums are accumulated into dst (see "if (y)" below).
		for (int y = 0; y < filter_width; y += 9) {
			__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;

			// Zero all nine tap registers; the switch below loads only
			// the taps that remain in this chunk.
			f0 = _mm256_setzero_ps();
			f1 = _mm256_setzero_ps();
			f2 = _mm256_setzero_ps();
			f3 = _mm256_setzero_ps();
			f4 = _mm256_setzero_ps();
			f5 = _mm256_setzero_ps();
			f6 = _mm256_setzero_ps();
			f7 = _mm256_setzero_ps();
			f8 = _mm256_setzero_ps();

			// Broadcast the chunk's remaining taps (at most nine) into
			// all eight lanes; the cases intentionally fall through.
			switch (filter_width - y) {
			default:
				f8 = _mm256_broadcast_ss(filter + y + 8);
				// fall through
			case 8:
				f7 = _mm256_broadcast_ss(filter + y + 7);
				// fall through
			case 7:
				f6 = _mm256_broadcast_ss(filter + y + 6);
				// fall through
			case 6:
				f5 = _mm256_broadcast_ss(filter + y + 5);
				// fall through
			case 5:
				f4 = _mm256_broadcast_ss(filter + y + 4);
				// fall through
			case 4:
				f3 = _mm256_broadcast_ss(filter + y + 3);
				// fall through
			case 3:
				f2 = _mm256_broadcast_ss(filter + y + 2);
				// fall through
			case 2:
				f1 = _mm256_broadcast_ss(filter + y + 1);
				// fall through
			case 1:
				f0 = _mm256_broadcast_ss(filter + y + 0);
				// fall through
			}

			// Filter eight output pixels per iteration. Four independent
			// partial sums shorten the floating-point add dependency chain.
			for (int j = 0; j < j_end; j += 8) {
				__m256 accum = _mm256_setzero_ps();
				__m256 sum0, sum1, sum2, sum3;
				__m256 g, g2;

				sum0 = _mm256_setzero_ps();
				sum1 = _mm256_setzero_ps();
				sum2 = _mm256_setzero_ps();
				sum3 = _mm256_setzero_ps();

				// For each tap still in range: load a row from each source,
				// multiply elementwise, scale by the tap, and accumulate.
				switch (filter_width - y) {
				default:
					g = _mm256_load_ps(src1 + (y + 8) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 8) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					sum0 = _mm256_mul_ps(f8, g);
					// fall through
				case 8:
					g = _mm256_load_ps(src1 + (y + 7) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 7) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					sum3 = _mm256_mul_ps(f7, g);
					// fall through
				case 7:
					g = _mm256_load_ps(src1 + (y + 6) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 6) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					sum2 = _mm256_mul_ps(f6, g);
					// fall through
				case 6:
					g = _mm256_load_ps(src1 + (y + 5) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 5) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					sum1 = _mm256_mul_ps(f5, g);
					// fall through
				case 5:
					g = _mm256_load_ps(src1 + (y + 4) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 4) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					g = _mm256_mul_ps(f4, g);
					sum0 = _mm256_add_ps(sum0, g);
					// fall through
				case 4:
					g = _mm256_load_ps(src1 + (y + 3) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 3) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					g = _mm256_mul_ps(f3, g);
					sum3 = _mm256_add_ps(sum3, g);
					// fall through
				case 3:
					g = _mm256_load_ps(src1 + (y + 2) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 2) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					g = _mm256_mul_ps(f2, g);
					sum2 = _mm256_add_ps(sum2, g);
					// fall through
				case 2:
					g = _mm256_load_ps(src1 + (y + 1) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 1) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					g = _mm256_mul_ps(f1, g);
					sum1 = _mm256_add_ps(sum1, g);
					// fall through
				case 1:
					g = _mm256_load_ps(src1 + (y + 0) * src1_stride + j);
					g2 = _mm256_load_ps(src2 + (y + 0) * src2_stride + j);
					g = _mm256_mul_ps(g, g2);
					g = _mm256_mul_ps(f0, g);
					sum0 = _mm256_add_ps(sum0, g);
					// fall through
				}

				// Pairwise reduction of the four partial sums.
				sum0 = _mm256_add_ps(sum0, sum2);
				sum1 = _mm256_add_ps(sum1, sum3);

				sum0 = _mm256_add_ps(sum0, sum1);
				accum = _mm256_add_ps(accum, sum0);

				// After the first chunk, fold in the partial result the
				// previous chunk stored in dst.
				if (y)
					accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j));

				_mm256_store_ps(dst + j, accum);
			}
		}
	}
}
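
For context, a minimal driver sketch showing how a scanline kernel like this might be applied to a whole plane. The wrapper below is an assumption for illustration, not libvmaf's actual API: it presumes rows are 32-byte aligned with a stride padded to a multiple of 8 floats (the kernel uses aligned _mm256_load_ps/_mm256_store_ps and steps j by 8), that dst is distinct from both sources, and that the first and last radius rows get separate edge handling (not shown).

/* Hypothetical driver, not part of libvmaf: runs the vertical xy kernel
 * over the interior rows of a plane. Assumes 32-byte-aligned rows, a
 * stride that is a multiple of 8 floats, and dst distinct from both
 * sources; border rows need separate edge handling. */
static void convolution_xy_plane_sketch(int N, const float * RESTRICT filter,
	const float * RESTRICT src1, const float * RESTRICT src2,
	float * RESTRICT dst, int width, int height, int stride)
{
	int radius = N / 2;

	/* The kernel reads radius rows above and below row i, so only
	 * interior rows are safe to filter this way. */
	for (int i = radius; i < height - radius; ++i) {
		convolution_f32_avx_s_1d_v_xy_scanline(N, filter, N,
			src1 + i * stride, src2 + i * stride, dst + i * stride,
			stride, stride, width);
	}
}

The sketch passes the same value for N and filter_width, matching the dispatch at the top of the function, which compares N against the widths of the specialized kernels; a caller whose runtime filter_width differs from N would pass the two independently.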