static void convolution_f32_avx_s_1d_v_sq_scanline()

in libvmaf/src/feature/common/convolution_avx.c [1128:1266]

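Computes one scanline of a vertical 1D convolution over squared source samples (each sample is squared before being weighted by the filter). Filter widths of 5, 9, and 17 taps are dispatched to unrolled kernels; any other width takes a generic AVX path that processes the filter in blocks of up to 9 taps, accumulating partial results in dst between blocks.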

static void convolution_f32_avx_s_1d_v_sq_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{

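	// Common filter widths (5, 9, 17) get fully unrolled specialized kernels.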
	if (N == 5)
	{
		convolution_f32_avx_s_1d_v_sq_scanline_5(filter, filter_width, src, dst, src_stride, j_end);
	}
	else if (N == 9)
	{
		convolution_f32_avx_s_1d_v_sq_scanline_9(filter, filter_width, src, dst, src_stride, j_end);
	}
	else if (N == 17)
	{
		convolution_f32_avx_s_1d_v_sq_scanline_17(filter, filter_width, src, dst, src_stride, j_end);
	}
	else {

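		// Generic path: back src up by the filter radius so that tap y reads row (y - radius) relative to the current scanline.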
		int radius = filter_width / 2;
		src -= radius * src_stride;

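		// Process the filter in blocks of up to 9 taps; f0..f8 each hold one coefficient broadcast across all 8 lanes.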
		for (int y = 0; y < filter_width; y += 9) {
			__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;

			f0 = _mm256_setzero_ps();
			f1 = _mm256_setzero_ps();
			f2 = _mm256_setzero_ps();
			f3 = _mm256_setzero_ps();
			f4 = _mm256_setzero_ps();
			f5 = _mm256_setzero_ps();
			f6 = _mm256_setzero_ps();
			f7 = _mm256_setzero_ps();
			f8 = _mm256_setzero_ps();

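			// Broadcast this block's coefficients; when fewer than 9 taps remain, the fall-through skips the unused registers, which stay zero.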
			switch (filter_width - y) {
			default:
				f8 = _mm256_broadcast_ss(filter + y + 8);
				// fall through
			case 8:
				f7 = _mm256_broadcast_ss(filter + y + 7);
				// fall through
			case 7:
				f6 = _mm256_broadcast_ss(filter + y + 6);
				// fall through
			case 6:
				f5 = _mm256_broadcast_ss(filter + y + 5);
				// fall through
			case 5:
				f4 = _mm256_broadcast_ss(filter + y + 4);
				// fall through
			case 4:
				f3 = _mm256_broadcast_ss(filter + y + 3);
				// fall through
			case 3:
				f2 = _mm256_broadcast_ss(filter + y + 2);
				// fall through
			case 2:
				f1 = _mm256_broadcast_ss(filter + y + 1);
				// fall through
			case 1:
				f0 = _mm256_broadcast_ss(filter + y + 0);
				// fall through
			}

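			// Produce 8 output floats per iteration (one 256-bit AVX register).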
			for (int j = 0; j < j_end; j += 8) {
				__m256 accum = _mm256_setzero_ps();
				__m256 sum0, sum1, sum2, sum3;
				__m256 g;

				sum0 = _mm256_setzero_ps();
				sum1 = _mm256_setzero_ps();
				sum2 = _mm256_setzero_ps();
				sum3 = _mm256_setzero_ps();

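				// For each remaining tap: load 8 samples, square them, and multiply-accumulate into four independent partial sums to shorten the dependency chain.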
				switch (filter_width - y) {
				default:
					g = _mm256_load_ps(src + (y + 8) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					sum0 = _mm256_mul_ps(f8, g);
					// fall through
				case 8:
					g = _mm256_load_ps(src + (y + 7) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					sum3 = _mm256_mul_ps(f7, g);
					// fall through
				case 7:
					g = _mm256_load_ps(src + (y + 6) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					sum2 = _mm256_mul_ps(f6, g);
					// fall through
				case 6:
					g = _mm256_load_ps(src + (y + 5) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					sum1 = _mm256_mul_ps(f5, g);
					// fall through
				case 5:
					g = _mm256_load_ps(src + (y + 4) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					g = _mm256_mul_ps(f4, g);
					sum0 = _mm256_add_ps(sum0, g);
					// fall through
				case 4:
					g = _mm256_load_ps(src + (y + 3) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					g = _mm256_mul_ps(f3, g);
					sum3 = _mm256_add_ps(sum3, g);
					// fall through
				case 3:
					g = _mm256_load_ps(src + (y + 2) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					g = _mm256_mul_ps(f2, g);
					sum2 = _mm256_add_ps(sum2, g);
					// fall through
				case 2:
					g = _mm256_load_ps(src + (y + 1) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					g = _mm256_mul_ps(f1, g);
					sum1 = _mm256_add_ps(sum1, g);
					// fall through
				case 1:
					g = _mm256_load_ps(src + (y + 0) * src_stride + j);
					g = _mm256_mul_ps(g, g);
					g = _mm256_mul_ps(f0, g);
					sum0 = _mm256_add_ps(sum0, g);
					// fall through
				}

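				// Pairwise reduction of the four partial sums.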
				sum0 = _mm256_add_ps(sum0, sum2);
				sum1 = _mm256_add_ps(sum1, sum3);

				sum0 = _mm256_add_ps(sum0, sum1);
				accum = _mm256_add_ps(accum, sum0);

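				// Blocks after the first accumulate onto the partial result already stored in dst.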
				if (y)
					accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j));

				_mm256_store_ps(dst + j, accum);
			}
		}
	}
}
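
For reference, the generic path is equivalent to the following scalar sketch (the helper name is hypothetical, not part of libvmaf). The AVX version additionally appears to assume 32-byte-aligned src/dst and buffers padded to a multiple of 8 floats, since it steps j by 8 and uses the aligned _mm256_load_ps/_mm256_store_ps intrinsics.

static void convolution_f32_scalar_s_1d_v_sq_scanline_ref(const float *filter, int filter_width, const float *src, float *dst, int src_stride, int j_end)
{
	// Same addressing as the generic AVX path above: tap k reads the
	// row (k - radius) rows away from the current scanline.
	int radius = filter_width / 2;

	for (int j = 0; j < j_end; ++j) {
		float accum = 0.0f;
		for (int k = 0; k < filter_width; ++k) {
			float g = src[(k - radius) * src_stride + j];
			accum += filter[k] * g * g;	// convolve the squared sample
		}
		dst[j] = accum;
	}
}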