libvmaf/src/feature/common/convolution_avx.c
/**
*
* Copyright 2016-2020 Netflix, Inc.
*
* Licensed under the BSD+Patent License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSDplusPatent
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#include <immintrin.h>
#include "alignment.h"
#include "convolution.h"
#include "convolution_internal.h"
void convolution_f32_avx_s_1d_h_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_h_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_h_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_v_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
void convolution_f32_avx_s_1d_v_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
void convolution_f32_avx_s_1d_v_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
void convolution_f32_avx_s_1d_h_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_h_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_h_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_v_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
void convolution_f32_avx_s_1d_v_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
void convolution_f32_avx_s_1d_v_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
void convolution_f32_avx_s_1d_h_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_h_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_h_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end);
void convolution_f32_avx_s_1d_v_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end);
void convolution_f32_avx_s_1d_v_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end);
void convolution_f32_avx_s_1d_v_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end);
// Filter a single scanline (horizontal pass). Dispatches to a width-specialized
// kernel for 5-, 9- and 17-tap filters; otherwise falls back to a generic loop
// that processes the filter in groups of up to 9 taps and accumulates the
// partial results into dst.
static void convolution_f32_avx_s_1d_h_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end)
{
if (N == 5)
{
convolution_f32_avx_s_1d_h_scanline_5(filter, filter_width, src, dst, j_end);
}
else if (N == 9)
{
convolution_f32_avx_s_1d_h_scanline_9(filter, filter_width, src, dst, j_end);
}
else if (N == 17)
{
convolution_f32_avx_s_1d_h_scanline_17(filter, filter_width, src, dst, j_end);
}
else {
int radius = filter_width / 2;
for (int x = 0; x < filter_width; x += 9) {
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
f0 = _mm256_setzero_ps();
f1 = _mm256_setzero_ps();
f2 = _mm256_setzero_ps();
f3 = _mm256_setzero_ps();
f5 = _mm256_setzero_ps();
f6 = _mm256_setzero_ps();
f7 = _mm256_setzero_ps();
f8 = _mm256_setzero_ps();
switch (filter_width - x) {
default:
f8 = _mm256_broadcast_ss(filter + x + 8);
// fall through
case 8:
f7 = _mm256_broadcast_ss(filter + x + 7);
// fall through
case 7:
f6 = _mm256_broadcast_ss(filter + x + 6);
// fall through
case 6:
f5 = _mm256_broadcast_ss(filter + x + 5);
// fall through
case 5:
f4 = _mm256_broadcast_ss(filter + x + 4);
// fall through
case 4:
f3 = _mm256_broadcast_ss(filter + x + 3);
// fall through
case 3:
f2 = _mm256_broadcast_ss(filter + x + 2);
// fall through
case 2:
f1 = _mm256_broadcast_ss(filter + x + 1);
// fall through
case 1:
f0 = _mm256_broadcast_ss(filter + x + 0);
// fall through
}
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
sum0 = _mm256_setzero_ps();
sum1 = _mm256_setzero_ps();
sum2 = _mm256_setzero_ps();
sum3 = _mm256_setzero_ps();
switch (filter_width - x) {
default:
g = _mm256_loadu_ps(src + j + x + 8);
sum0 = _mm256_mul_ps(f8, g);
// fall through
case 8:
g = _mm256_loadu_ps(src + j + x + 7);
sum3 = _mm256_mul_ps(f7, g);
// fall through
case 7:
g = _mm256_loadu_ps(src + j + x + 6);
sum2 = _mm256_mul_ps(f6, g);
// fall through
case 6:
g = _mm256_loadu_ps(src + j + x + 5);
sum1 = _mm256_mul_ps(f5, g);
// fall through
case 5:
g = _mm256_loadu_ps(src + j + x + 4);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
case 4:
g = _mm256_loadu_ps(src + j + x + 3);
g = _mm256_mul_ps(f3, g);
sum3 = _mm256_add_ps(sum3, g);
// fall through
case 3:
g = _mm256_loadu_ps(src + j + x + 2);
g = _mm256_mul_ps(f2, g);
sum2 = _mm256_add_ps(sum2, g);
// fall through
case 2:
g = _mm256_loadu_ps(src + j + x + 1);
g = _mm256_mul_ps(f1, g);
sum1 = _mm256_add_ps(sum1, g);
// fall through
case 1:
g = _mm256_loadu_ps(src + j + x + 0);
g = _mm256_mul_ps(f0, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
}
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
if (x)
accum = _mm256_add_ps(accum, _mm256_loadu_ps(dst + j + radius));
_mm256_storeu_ps(dst + j + radius, accum);
}
}
}
}
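// 17-tap horizontal kernel. The taps are processed in two passes (0-8, then
// 9-16) so that at most nine broadcast registers are live at a time; the
// second pass accumulates onto the partial sums already stored in dst.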
void convolution_f32_avx_s_1d_h_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_loadu_ps(src + j + 0);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 1);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 3);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 4);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src + j + 5);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src + j + 6);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src + j + 7);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_loadu_ps(src + j + 8);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_store_ps(dst + j + 8, accum); // radius = 8
}
// Evaluate filter taps 9-16
f0 = _mm256_broadcast_ss(filter + 9);
f1 = _mm256_broadcast_ss(filter + 10);
f2 = _mm256_broadcast_ss(filter + 11);
f3 = _mm256_broadcast_ss(filter + 12);
f4 = _mm256_broadcast_ss(filter + 13);
f5 = _mm256_broadcast_ss(filter + 14);
f6 = _mm256_broadcast_ss(filter + 15);
f7 = _mm256_broadcast_ss(filter + 16);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
float *dst_ptr = dst + j + 8; // radius = 8
g = _mm256_loadu_ps(src + j + 9);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 10);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 11);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 12);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 13);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src + j + 14);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src + j + 15);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src + j + 16);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
sum0 = _mm256_add_ps(_mm256_load_ps(dst_ptr), sum0);
_mm256_store_ps(dst_ptr, sum0);
}
}
void convolution_f32_avx_s_1d_h_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_loadu_ps(src + j + 0);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 1);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 3);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 4);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src + j + 5);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src + j + 6);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src + j + 7);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_loadu_ps(src + j + 8);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_storeu_ps(dst + j + 4, accum); // radius = 4
}
}
void convolution_f32_avx_s_1d_h_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4;
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_loadu_ps(src + j + 0);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 1);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 3);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 4);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_storeu_ps(dst + j + 2, accum); // radius = 2
}
}
// Filter a single scanline (vertical pass). Dispatches to a width-specialized
// kernel for 5-, 9- and 17-tap filters; otherwise falls back to a generic loop
// over groups of up to 9 taps.
static void convolution_f32_avx_s_1d_v_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
if (N == 5)
{
convolution_f32_avx_s_1d_v_scanline_5(filter, filter_width, src, dst, src_stride, j_end);
}
else if (N == 9)
{
convolution_f32_avx_s_1d_v_scanline_9(filter, filter_width, src, dst, src_stride, j_end);
}
else if (N == 17)
{
convolution_f32_avx_s_1d_v_scanline_17(filter, filter_width, src, dst, src_stride, j_end);
}
else {
int radius = filter_width / 2;
src -= radius * src_stride;
for (int y = 0; y < filter_width; y += 9) {
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
f0 = _mm256_setzero_ps();
f1 = _mm256_setzero_ps();
f2 = _mm256_setzero_ps();
f3 = _mm256_setzero_ps();
f5 = _mm256_setzero_ps();
f6 = _mm256_setzero_ps();
f7 = _mm256_setzero_ps();
f8 = _mm256_setzero_ps();
switch (filter_width - y) {
default:
f8 = _mm256_broadcast_ss(filter + y + 8);
// fall through
case 8:
f7 = _mm256_broadcast_ss(filter + y + 7);
// fall through
case 7:
f6 = _mm256_broadcast_ss(filter + y + 6);
// fall through
case 6:
f5 = _mm256_broadcast_ss(filter + y + 5);
// fall through
case 5:
f4 = _mm256_broadcast_ss(filter + y + 4);
// fall through
case 4:
f3 = _mm256_broadcast_ss(filter + y + 3);
// fall through
case 3:
f2 = _mm256_broadcast_ss(filter + y + 2);
// fall through
case 2:
f1 = _mm256_broadcast_ss(filter + y + 1);
// fall through
case 1:
f0 = _mm256_broadcast_ss(filter + y + 0);
// fall through
}
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
sum0 = _mm256_setzero_ps();
sum1 = _mm256_setzero_ps();
sum2 = _mm256_setzero_ps();
sum3 = _mm256_setzero_ps();
switch (filter_width - y) {
default:
g = _mm256_load_ps(src + (y + 8) * src_stride + j);
sum0 = _mm256_mul_ps(f8, g);
// fall through
case 8:
g = _mm256_load_ps(src + (y + 7) * src_stride + j);
sum3 = _mm256_mul_ps(f7, g);
// fall through
case 7:
g = _mm256_load_ps(src + (y + 6) * src_stride + j);
sum2 = _mm256_mul_ps(f6, g);
// fall through
case 6:
g = _mm256_load_ps(src + (y + 5) * src_stride + j);
sum1 = _mm256_mul_ps(f5, g);
// fall through
case 5:
g = _mm256_load_ps(src + (y + 4) * src_stride + j);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
case 4:
g = _mm256_load_ps(src + (y + 3) * src_stride + j);
g = _mm256_mul_ps(f3, g);
sum3 = _mm256_add_ps(sum3, g);
// fall through
case 3:
g = _mm256_load_ps(src + (y + 2) * src_stride + j);
g = _mm256_mul_ps(f2, g);
sum2 = _mm256_add_ps(sum2, g);
// fall through
case 2:
g = _mm256_load_ps(src + (y + 1) * src_stride + j);
g = _mm256_mul_ps(f1, g);
sum1 = _mm256_add_ps(sum1, g);
// fall through
case 1:
g = _mm256_load_ps(src + (y + 0) * src_stride + j);
g = _mm256_mul_ps(f0, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
}
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
if (y)
accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j));
_mm256_store_ps(dst + j, accum);
}
}
}
}
void convolution_f32_avx_s_1d_v_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
src -= 8 * src_stride; // radius = 8
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 0 * src_stride + j);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 1 * src_stride + j);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 2 * src_stride + j);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 3 * src_stride + j);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 4 * src_stride + j);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src + 5 * src_stride + j);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src + 6 * src_stride + j);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src + 7 * src_stride + j);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_load_ps(src + 8 * src_stride + j);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
// Evaluate filter taps 9-16
f0 = _mm256_broadcast_ss(filter + 9);
f1 = _mm256_broadcast_ss(filter + 10);
f2 = _mm256_broadcast_ss(filter + 11);
f3 = _mm256_broadcast_ss(filter + 12);
f4 = _mm256_broadcast_ss(filter + 13);
f5 = _mm256_broadcast_ss(filter + 14);
f6 = _mm256_broadcast_ss(filter + 15);
f7 = _mm256_broadcast_ss(filter + 16);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 9 * src_stride + j);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 10 * src_stride + j);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 11 * src_stride + j);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 12 * src_stride + j);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 13 * src_stride + j);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src + 14 * src_stride + j);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src + 15 * src_stride + j);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src + 16 * src_stride + j);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
sum0 = _mm256_add_ps(_mm256_load_ps(dst + j), sum0);
_mm256_store_ps(dst + j, sum0);
}
}
void convolution_f32_avx_s_1d_v_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
src -= 4 * src_stride; // radius = 4
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 0 * src_stride + j);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 1 * src_stride + j);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 2 * src_stride + j);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 3 * src_stride + j);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 4 * src_stride + j);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src + 5 * src_stride + j);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src + 6 * src_stride + j);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src + 7 * src_stride + j);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_load_ps(src + 8 * src_stride + j);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
}
void convolution_f32_avx_s_1d_v_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4;
src -= 2 * src_stride; // radius = 2
// Evaluate filter taps 0-4
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 0 * src_stride + j);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 1 * src_stride + j);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 2 * src_stride + j);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 3 * src_stride + j);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 4 * src_stride + j);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
}
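// Separable 2D convolution: a vertical pass writes intermediate rows into tmp
// (row stride rounded up to a multiple of 8 floats), then a horizontal pass
// writes the final result into dst. Rows and columns within `radius` of the
// image border, and the right-edge columns that do not fill a full 8-float
// vector, are handled by the scalar convolution_edge_s() helper.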
void convolution_f32_avx_s_1d(
int N,
const float * RESTRICT filter,
int filter_width,
const float * RESTRICT src,
float * RESTRICT dst,
float * RESTRICT tmp,
int width,
int height,
int src_stride,
int dst_stride)
{
int radius = filter_width / 2;
int width_mod8 = vmaf_floorn(width, 8);
int tmp_stride = vmaf_ceiln(width, 8);
int i_vec_end = height - radius;
int j_vec_end = width_mod8 - vmaf_ceiln(radius + 1, 8);
// Vertical pass.
for (int i = 0; i < radius; ++i) {
for (int j = 0; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j);
}
}
for (int i = radius; i < i_vec_end; ++i) {
convolution_f32_avx_s_1d_v_scanline(N, filter, filter_width, src + i * src_stride, tmp + i * tmp_stride, src_stride, width_mod8);
for (int j = width_mod8; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j);
}
}
for (int i = i_vec_end; i < height; ++i) {
for (int j = 0; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j);
}
}
// Horizontal pass.
for (int i = 0; i < height; ++i) {
for (int j = 0; j < radius; ++j) {
dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
}
convolution_f32_avx_s_1d_h_scanline(N, filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end);
for (int j = j_vec_end + radius; j < width; ++j) {
dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
}
}
}
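// Public entry point: dispatch on filter_width so that the common filter
// sizes (17, 9, 5 and 3 taps) reach the specialized kernels, and any other
// width takes the generic fallback path.
//
// A minimal usage sketch (assumptions: aligned_malloc() and vmaf_ceiln() come
// from the headers included above, and buffer names and sizes are
// illustrative). The aligned _mm256_load_ps/_mm256_store_ps calls in the
// kernels imply that src, dst and tmp should be 32-byte aligned and that the
// strides, measured in floats, should be multiples of 8:
//
//   float *src = aligned_malloc(h * stride * sizeof(float), 32);
//   float *dst = aligned_malloc(h * stride * sizeof(float), 32);
//   float *tmp = aligned_malloc(h * vmaf_ceiln(w, 8) * sizeof(float), 32);
//   convolution_f32_avx_s(filter, 17, src, dst, tmp, w, h, stride, stride);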
void convolution_f32_avx_s(const float *filter, int filter_width, const float *src, float *dst, float *tmp, int width, int height, int src_stride, int dst_stride)
{
switch (filter_width) {
case 17:
convolution_f32_avx_s_1d(17, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
case 9:
convolution_f32_avx_s_1d(9, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
case 5:
convolution_f32_avx_s_1d(5, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
case 3:
convolution_f32_avx_s_1d(3, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
default:
convolution_f32_avx_s_1d(0, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
}
}
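// The _sq kernels below mirror the plain kernels, but square each source
// sample (g * g) before applying the filter tap, i.e. they convolve src^2.
// Only the vertical pass differs; the horizontal pass of the squared
// convolution reuses the plain h_scanline kernels on the already-squared
// intermediate rows (see convolution_f32_avx_s_1d_sq below).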
void convolution_f32_avx_s_1d_h_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_loadu_ps(src + j + 0);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 1);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 2);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 3);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 4);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src + j + 5);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src + j + 6);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src + j + 7);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_loadu_ps(src + j + 8);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_store_ps(dst + j + 8, accum); // radius = 8
}
// Evaluate filter taps 9-16
f0 = _mm256_broadcast_ss(filter + 9);
f1 = _mm256_broadcast_ss(filter + 10);
f2 = _mm256_broadcast_ss(filter + 11);
f3 = _mm256_broadcast_ss(filter + 12);
f4 = _mm256_broadcast_ss(filter + 13);
f5 = _mm256_broadcast_ss(filter + 14);
f6 = _mm256_broadcast_ss(filter + 15);
f7 = _mm256_broadcast_ss(filter + 16);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
float *dst_ptr = dst + j + 8; // radius = 8
g = _mm256_loadu_ps(src + j + 9);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 10);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 11);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 12);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 13);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src + j + 14);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src + j + 15);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src + j + 16);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
sum0 = _mm256_add_ps(_mm256_load_ps(dst_ptr), sum0);
_mm256_store_ps(dst_ptr, sum0);
}
}
void convolution_f32_avx_s_1d_h_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_loadu_ps(src + j + 0);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 1);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 2);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 3);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 4);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src + j + 5);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src + j + 6);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src + j + 7);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_loadu_ps(src + j + 8);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_storeu_ps(dst + j + 4, accum); // radius = 4
}
}
void convolution_f32_avx_s_1d_h_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4;
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_loadu_ps(src + j + 0);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src + j + 1);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src + j + 2);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src + j + 3);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src + j + 4);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_storeu_ps(dst + j + 2, accum); // radius = 2
}
}
// Filter a single scanline (vertical pass), convolving the element-wise square
// of the source (src * src).
static void convolution_f32_avx_s_1d_v_sq_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
if (N == 5)
{
convolution_f32_avx_s_1d_v_sq_scanline_5(filter, filter_width, src, dst, src_stride, j_end);
}
else if (N == 9)
{
convolution_f32_avx_s_1d_v_sq_scanline_9(filter, filter_width, src, dst, src_stride, j_end);
}
else if (N == 17)
{
convolution_f32_avx_s_1d_v_sq_scanline_17(filter, filter_width, src, dst, src_stride, j_end);
}
else {
int radius = filter_width / 2;
src -= radius * src_stride;
for (int y = 0; y < filter_width; y += 9) {
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
f0 = _mm256_setzero_ps();
f1 = _mm256_setzero_ps();
f2 = _mm256_setzero_ps();
f3 = _mm256_setzero_ps();
f5 = _mm256_setzero_ps();
f6 = _mm256_setzero_ps();
f7 = _mm256_setzero_ps();
f8 = _mm256_setzero_ps();
switch (filter_width - y) {
default:
f8 = _mm256_broadcast_ss(filter + y + 8);
// fall through
case 8:
f7 = _mm256_broadcast_ss(filter + y + 7);
// fall through
case 7:
f6 = _mm256_broadcast_ss(filter + y + 6);
// fall through
case 6:
f5 = _mm256_broadcast_ss(filter + y + 5);
// fall through
case 5:
f4 = _mm256_broadcast_ss(filter + y + 4);
// fall through
case 4:
f3 = _mm256_broadcast_ss(filter + y + 3);
// fall through
case 3:
f2 = _mm256_broadcast_ss(filter + y + 2);
// fall through
case 2:
f1 = _mm256_broadcast_ss(filter + y + 1);
// fall through
case 1:
f0 = _mm256_broadcast_ss(filter + y + 0);
// fall through
}
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g;
sum0 = _mm256_setzero_ps();
sum1 = _mm256_setzero_ps();
sum2 = _mm256_setzero_ps();
sum3 = _mm256_setzero_ps();
switch (filter_width - y) {
default:
g = _mm256_load_ps(src + (y + 8) * src_stride + j);
g = _mm256_mul_ps(g, g);
sum0 = _mm256_mul_ps(f8, g);
// fall through
case 8:
g = _mm256_load_ps(src + (y + 7) * src_stride + j);
g = _mm256_mul_ps(g, g);
sum3 = _mm256_mul_ps(f7, g);
// fall through
case 7:
g = _mm256_load_ps(src + (y + 6) * src_stride + j);
g = _mm256_mul_ps(g, g);
sum2 = _mm256_mul_ps(f6, g);
// fall through
case 6:
g = _mm256_load_ps(src + (y + 5) * src_stride + j);
g = _mm256_mul_ps(g, g);
sum1 = _mm256_mul_ps(f5, g);
// fall through
case 5:
g = _mm256_load_ps(src + (y + 4) * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
case 4:
g = _mm256_load_ps(src + (y + 3) * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = _mm256_add_ps(sum3, g);
// fall through
case 3:
g = _mm256_load_ps(src + (y + 2) * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = _mm256_add_ps(sum2, g);
// fall through
case 2:
g = _mm256_load_ps(src + (y + 1) * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = _mm256_add_ps(sum1, g);
// fall through
case 1:
g = _mm256_load_ps(src + (y + 0) * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
}
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
if (y)
accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j));
_mm256_store_ps(dst + j, accum);
}
}
}
}
void convolution_f32_avx_s_1d_v_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
src -= 8 * src_stride; // radius = 8
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 0 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 1 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 2 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 3 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 4 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src + 5 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src + 6 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src + 7 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_load_ps(src + 8 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
// Evaluate filter taps 9-16
f0 = _mm256_broadcast_ss(filter + 9);
f1 = _mm256_broadcast_ss(filter + 10);
f2 = _mm256_broadcast_ss(filter + 11);
f3 = _mm256_broadcast_ss(filter + 12);
f4 = _mm256_broadcast_ss(filter + 13);
f5 = _mm256_broadcast_ss(filter + 14);
f6 = _mm256_broadcast_ss(filter + 15);
f7 = _mm256_broadcast_ss(filter + 16);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 9 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 10 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 11 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 12 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 13 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src + 14 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src + 15 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src + 16 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
sum0 = _mm256_add_ps(_mm256_load_ps(dst + j), sum0);
_mm256_store_ps(dst + j, sum0);
}
}
void convolution_f32_avx_s_1d_v_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
src -= 4 * src_stride; // radius = 4
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 0 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 1 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 2 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 3 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 4 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src + 5 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src + 6 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src + 7 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_load_ps(src + 8 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
}
void convolution_f32_avx_s_1d_v_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4;
src -= 2 * src_stride; // radius = 2
// Evaluate filter taps 0-4
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g;
g = _mm256_load_ps(src + 0 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src + 1 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src + 2 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src + 3 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src + 4 * src_stride + j);
g = _mm256_mul_ps(g, g);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
}
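// Separable convolution of the squared source: the vertical pass convolves
// src^2 into tmp (using convolution_edge_sq_s for border samples), then the
// horizontal pass is a plain convolution of tmp into dst.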
void convolution_f32_avx_s_1d_sq(
int N,
const float * RESTRICT filter,
int filter_width,
const float * RESTRICT src,
float * RESTRICT dst,
float * RESTRICT tmp,
int width,
int height,
int src_stride,
int dst_stride)
{
int radius = filter_width / 2;
int width_mod8 = vmaf_floorn(width, 8);
int tmp_stride = vmaf_ceiln(width, 8);
int i_vec_end = height - radius;
int j_vec_end = width_mod8 - vmaf_ceiln(radius + 1, 8);
// Vertical pass.
for (int i = 0; i < radius; ++i) {
for (int j = 0; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j);
}
}
for (int i = radius; i < i_vec_end; ++i) {
convolution_f32_avx_s_1d_v_sq_scanline(N, filter, filter_width, src + i * src_stride, tmp + i * tmp_stride, src_stride, width_mod8);
for (int j = width_mod8; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j);
}
}
for (int i = i_vec_end; i < height; ++i) {
for (int j = 0; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j);
}
}
// Horizontal pass.
for (int i = 0; i < height; ++i) {
for (int j = 0; j < radius; ++j) {
dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
}
convolution_f32_avx_s_1d_h_scanline(N, filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end);
for (int j = j_vec_end + radius; j < width; ++j) {
dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
}
}
}
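// Public entry point for the squared-source convolution; same dispatch on
// filter_width as convolution_f32_avx_s.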
void convolution_f32_avx_sq_s(const float *filter, int filter_width, const float *src, float *dst, float *tmp, int width, int height, int src_stride, int dst_stride)
{
switch (filter_width) {
case 17:
convolution_f32_avx_s_1d_sq(17, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
case 9:
convolution_f32_avx_s_1d_sq(9, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
case 5:
convolution_f32_avx_s_1d_sq(5, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
case 3:
convolution_f32_avx_s_1d_sq(3, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
default:
convolution_f32_avx_s_1d_sq(0, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride);
break;
}
}
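// The _xy kernels below convolve the element-wise product of two sources
// (src1 * src2), each with its own stride. Together with the plain and _sq
// variants, this provides the cross term typically needed for local
// covariance-style statistics.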
void convolution_f32_avx_s_1d_h_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
g = _mm256_loadu_ps(src1 + j + 0);
g2 = _mm256_loadu_ps(src2 + j + 0);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src1 + j + 1);
g2 = _mm256_loadu_ps(src2 + j + 1);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src1 + j + 2);
g2 = _mm256_loadu_ps(src2 + j + 2);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src1 + j + 3);
g2 = _mm256_loadu_ps(src2 + j + 3);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src1 + j + 4);
g2 = _mm256_loadu_ps(src2 + j + 4);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src1 + j + 5);
g2 = _mm256_loadu_ps(src2 + j + 5);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src1 + j + 6);
g2 = _mm256_loadu_ps(src2 + j + 6);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src1 + j + 7);
g2 = _mm256_loadu_ps(src2 + j + 7);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_loadu_ps(src1 + j + 8);
g2 = _mm256_loadu_ps(src2 + j + 8);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_store_ps(dst + j + 8, accum); // radius = 8
}
// Evaluate filter taps 9-16
f0 = _mm256_broadcast_ss(filter + 9);
f1 = _mm256_broadcast_ss(filter + 10);
f2 = _mm256_broadcast_ss(filter + 11);
f3 = _mm256_broadcast_ss(filter + 12);
f4 = _mm256_broadcast_ss(filter + 13);
f5 = _mm256_broadcast_ss(filter + 14);
f6 = _mm256_broadcast_ss(filter + 15);
f7 = _mm256_broadcast_ss(filter + 16);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
float *dst_ptr = dst + j + 8; // radius = 8
g = _mm256_loadu_ps(src1 + j + 9);
g2 = _mm256_loadu_ps(src2 + j + 9);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src1 + j + 10);
g2 = _mm256_loadu_ps(src2 + j + 10);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src1 + j + 11);
g2 = _mm256_loadu_ps(src2 + j + 11);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src1 + j + 12);
g2 = _mm256_loadu_ps(src2 + j + 12);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src1 + j + 13);
g2 = _mm256_loadu_ps(src2 + j + 13);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src1 + j + 14);
g2 = _mm256_loadu_ps(src2 + j + 14);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src1 + j + 15);
g2 = _mm256_loadu_ps(src2 + j + 15);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src1 + j + 16);
g2 = _mm256_loadu_ps(src2 + j + 16);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
sum0 = _mm256_add_ps(_mm256_load_ps(dst_ptr), sum0);
_mm256_store_ps(dst_ptr, sum0);
}
}
void convolution_f32_avx_s_1d_h_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
g = _mm256_loadu_ps(src1 + j + 0);
g2 = _mm256_loadu_ps(src2 + j + 0);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src1 + j + 1);
g2 = _mm256_loadu_ps(src2 + j + 1);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src1 + j + 2);
g2 = _mm256_loadu_ps(src2 + j + 2);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src1 + j + 3);
g2 = _mm256_loadu_ps(src2 + j + 3);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src1 + j + 4);
g2 = _mm256_loadu_ps(src2 + j + 4);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_loadu_ps(src1 + j + 5);
g2 = _mm256_loadu_ps(src2 + j + 5);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_loadu_ps(src1 + j + 6);
g2 = _mm256_loadu_ps(src2 + j + 6);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_loadu_ps(src1 + j + 7);
g2 = _mm256_loadu_ps(src2 + j + 7);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_loadu_ps(src1 + j + 8);
g2 = _mm256_loadu_ps(src2 + j + 8);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_storeu_ps(dst + j + 4, accum); // radius = 4
}
}
void convolution_f32_avx_s_1d_h_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4;
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
g = _mm256_loadu_ps(src1 + j + 0);
g2 = _mm256_loadu_ps(src2 + j + 0);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_loadu_ps(src1 + j + 1);
g2 = _mm256_loadu_ps(src2 + j + 1);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_loadu_ps(src1 + j + 2);
g2 = _mm256_loadu_ps(src2 + j + 2);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_loadu_ps(src1 + j + 3);
g2 = _mm256_loadu_ps(src2 + j + 3);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_loadu_ps(src1 + j + 4);
g2 = _mm256_loadu_ps(src2 + j + 4);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
_mm256_storeu_ps(dst + j + 2, accum); // radius = 2
}
}
// Filter a single scanline (vertical pass), convolving the element-wise product
// of the two sources (src1 * src2).
static void convolution_f32_avx_s_1d_v_xy_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end)
{
if (N == 5)
{
convolution_f32_avx_s_1d_v_xy_scanline_5(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
}
else if (N == 9)
{
convolution_f32_avx_s_1d_v_xy_scanline_9(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
}
else if (N == 17)
{
convolution_f32_avx_s_1d_v_xy_scanline_17(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end);
}
else {
int radius = filter_width / 2;
src1 -= radius * src1_stride;
src2 -= radius * src2_stride;
for (int y = 0; y < filter_width; y += 9) {
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
f0 = _mm256_setzero_ps();
f1 = _mm256_setzero_ps();
f2 = _mm256_setzero_ps();
f3 = _mm256_setzero_ps();
f5 = _mm256_setzero_ps();
f6 = _mm256_setzero_ps();
f7 = _mm256_setzero_ps();
f8 = _mm256_setzero_ps();
switch (filter_width - y) {
default:
f8 = _mm256_broadcast_ss(filter + y + 8);
// fall through
case 8:
f7 = _mm256_broadcast_ss(filter + y + 7);
// fall through
case 7:
f6 = _mm256_broadcast_ss(filter + y + 6);
// fall through
case 6:
f5 = _mm256_broadcast_ss(filter + y + 5);
// fall through
case 5:
f4 = _mm256_broadcast_ss(filter + y + 4);
// fall through
case 4:
f3 = _mm256_broadcast_ss(filter + y + 3);
// fall through
case 3:
f2 = _mm256_broadcast_ss(filter + y + 2);
// fall through
case 2:
f1 = _mm256_broadcast_ss(filter + y + 1);
// fall through
case 1:
f0 = _mm256_broadcast_ss(filter + y + 0);
// fall through
}
for (int j = 0; j < j_end; j += 8) {
__m256 accum = _mm256_setzero_ps();
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
sum0 = _mm256_setzero_ps();
sum1 = _mm256_setzero_ps();
sum2 = _mm256_setzero_ps();
sum3 = _mm256_setzero_ps();
switch (filter_width - y) {
default:
g = _mm256_load_ps(src1 + (y + 8) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 8) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
sum0 = _mm256_mul_ps(f8, g);
// fall through
case 8:
g = _mm256_load_ps(src1 + (y + 7) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 7) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
sum3 = _mm256_mul_ps(f7, g);
// fall through
case 7:
g = _mm256_load_ps(src1 + (y + 6) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 6) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
sum2 = _mm256_mul_ps(f6, g);
// fall through
case 6:
g = _mm256_load_ps(src1 + (y + 5) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 5) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
sum1 = _mm256_mul_ps(f5, g);
// fall through
case 5:
g = _mm256_load_ps(src1 + (y + 4) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 4) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
case 4:
g = _mm256_load_ps(src1 + (y + 3) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 3) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = _mm256_add_ps(sum3, g);
// fall through
case 3:
g = _mm256_load_ps(src1 + (y + 2) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 2) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = _mm256_add_ps(sum2, g);
// fall through
case 2:
g = _mm256_load_ps(src1 + (y + 1) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 1) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = _mm256_add_ps(sum1, g);
// fall through
case 1:
g = _mm256_load_ps(src1 + (y + 0) * src1_stride + j);
g2 = _mm256_load_ps(src2 + (y + 0) * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = _mm256_add_ps(sum0, g);
// fall through
}
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
accum = _mm256_add_ps(accum, sum0);
if (y)
accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j));
_mm256_store_ps(dst + j, accum);
}
}
}
}
void convolution_f32_avx_s_1d_v_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
src1 -= 8 * src1_stride; // radius = 8
src2 -= 8 * src2_stride; // radius = 8
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
g = _mm256_load_ps(src1 + 0 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 0 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src1 + 1 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 1 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src1 + 2 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 2 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src1 + 3 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 3 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src1 + 4 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 4 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src1 + 5 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 5 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src1 + 6 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 6 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src1 + 7 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 7 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_load_ps(src1 + 8 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 8 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
// Evaluate filter taps 9-16
f0 = _mm256_broadcast_ss(filter + 9);
f1 = _mm256_broadcast_ss(filter + 10);
f2 = _mm256_broadcast_ss(filter + 11);
f3 = _mm256_broadcast_ss(filter + 12);
f4 = _mm256_broadcast_ss(filter + 13);
f5 = _mm256_broadcast_ss(filter + 14);
f6 = _mm256_broadcast_ss(filter + 15);
f7 = _mm256_broadcast_ss(filter + 16);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
g = _mm256_load_ps(src1 + 9 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 9 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src1 + 10 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 10 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src1 + 11 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 11 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src1 + 12 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 12 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src1 + 13 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 13 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src1 + 14 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 14 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src1 + 15 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 15 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src1 + 16 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 16 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
sum0 = _mm256_add_ps(_mm256_load_ps(dst + j), sum0);
_mm256_store_ps(dst + j, sum0);
}
}
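// Vertical 9-tap pass over the elementwise product src1*src2; same
// structure as the 17-tap kernel, but all nine taps are handled in a
// single sweep over the scanline.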
void convolution_f32_avx_s_1d_v_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4, f5, f6, f7, f8;
src1 -= 4 * src1_stride; // radius = 4
src2 -= 4 * src2_stride; // radius = 4
// Evaluate filter taps 0-8
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
f5 = _mm256_broadcast_ss(filter + 5);
f6 = _mm256_broadcast_ss(filter + 6);
f7 = _mm256_broadcast_ss(filter + 7);
f8 = _mm256_broadcast_ss(filter + 8);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
g = _mm256_load_ps(src1 + 0 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 0 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src1 + 1 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 1 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src1 + 2 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 2 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src1 + 3 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 3 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src1 + 4 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 4 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
g = _mm256_load_ps(src1 + 5 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 5 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f5, g);
sum1 = _mm256_add_ps(sum1, g);
g = _mm256_load_ps(src1 + 6 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 6 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f6, g);
sum2 = _mm256_add_ps(sum2, g);
g = _mm256_load_ps(src1 + 7 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 7 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f7, g);
sum3 = _mm256_add_ps(sum3, g);
g = _mm256_load_ps(src1 + 8 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 8 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f8, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
}
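// Vertical 5-tap pass over the elementwise product src1*src2.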
void convolution_f32_avx_s_1d_v_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end)
{
(void) filter_width;
__m256 f0, f1, f2, f3, f4;
src1 -= 2 * src1_stride; // radius = 2
src2 -= 2 * src2_stride; // radius = 2
// Evaluate filter taps 0-4
f0 = _mm256_broadcast_ss(filter + 0);
f1 = _mm256_broadcast_ss(filter + 1);
f2 = _mm256_broadcast_ss(filter + 2);
f3 = _mm256_broadcast_ss(filter + 3);
f4 = _mm256_broadcast_ss(filter + 4);
for (int j = 0; j < j_end; j += 8) {
__m256 sum0, sum1, sum2, sum3;
__m256 g, g2;
g = _mm256_load_ps(src1 + 0 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 0 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f0, g);
sum0 = g;
g = _mm256_load_ps(src1 + 1 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 1 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f1, g);
sum1 = g;
g = _mm256_load_ps(src1 + 2 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 2 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f2, g);
sum2 = g;
g = _mm256_load_ps(src1 + 3 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 3 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f3, g);
sum3 = g;
g = _mm256_load_ps(src1 + 4 * src1_stride + j);
g2 = _mm256_load_ps(src2 + 4 * src2_stride + j);
g = _mm256_mul_ps(g, g2);
g = _mm256_mul_ps(f4, g);
sum0 = _mm256_add_ps(sum0, g);
sum0 = _mm256_add_ps(sum0, sum2);
sum1 = _mm256_add_ps(sum1, sum3);
sum0 = _mm256_add_ps(sum0, sum1);
_mm256_store_ps(dst + j, sum0);
}
}
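// Separable convolution of the product src1*src2 with a 1D filter: a
// vertical pass into the aligned scratch buffer tmp, then a horizontal
// pass from tmp into dst. Interior samples use the AVX scanline kernels
// selected by N; border rows/columns and the right-hand remainder that
// is not a multiple of 8 fall back to the scalar edge helpers.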
void convolution_f32_avx_s_1d_xy(
int N,
const float * RESTRICT filter,
int filter_width,
const float * RESTRICT src1,
const float * RESTRICT src2,
float * RESTRICT dst,
float * RESTRICT tmp,
int width,
int height,
int src1_stride,
int src2_stride,
int dst_stride)
{
int radius = filter_width / 2;
int width_mod8 = vmaf_floorn(width, 8);
int tmp_stride = vmaf_ceiln(width, 8);
int i_vec_end = height - radius;
int j_vec_end = width_mod8 - vmaf_ceiln(radius + 1, 8);
// Vertical pass: filter the columns of src1*src2 into the tmp scratch buffer.
for (int i = 0; i < radius; ++i) {
for (int j = 0; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j);
}
}
for (int i = radius; i < i_vec_end; ++i) {
convolution_f32_avx_s_1d_v_xy_scanline(N, filter, filter_width, src1 + i * src1_stride, src2 + i * src2_stride, tmp + i * tmp_stride, src1_stride, src2_stride, width_mod8);
for (int j = width_mod8; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j);
}
}
for (int i = i_vec_end; i < height; ++i) {
for (int j = 0; j < width; ++j) {
tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j);
}
}
// Horizontal pass: filter the rows of tmp into dst.
for (int i = 0; i < height; ++i) {
for (int j = 0; j < radius; ++j) {
dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
}
convolution_f32_avx_s_1d_h_scanline(N, filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end);
for (int j = j_vec_end + radius; j < width; ++j) {
dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
}
}
}
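// Public entry point for the product (xy) convolution. The switch maps
// the recognized filter widths onto explicit N values, which the
// scanline helpers use to select the unrolled kernels (17, 9, 5); any
// other width falls back to the generic path (N == 0). A hypothetical
// call (buffer names are illustrative only):
//   convolution_f32_avx_xy_s(filter, 17, ref, dis, out, scratch,
//                            w, h, stride, stride, stride);
// Strides are in floats, and rows are expected to be 32-byte aligned
// for the aligned AVX loads and stores.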
void convolution_f32_avx_xy_s(const float *filter, int filter_width, const float *src1, const float *src2, float *dst, float *tmp, int width, int height, int src1_stride, int src2_stride, int dst_stride)
{
switch (filter_width) {
case 17:
convolution_f32_avx_s_1d_xy(17, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride);
break;
case 9:
convolution_f32_avx_s_1d_xy(9, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride);
break;
case 5:
convolution_f32_avx_s_1d_xy(5, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride);
break;
case 3:
convolution_f32_avx_s_1d_xy(3, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride);
break;
default:
convolution_f32_avx_s_1d_xy(0, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride);
break;
}
}