void vif_statistic_8_avx512()

in libvmaf/src/feature/x86/vif_avx512.c [144:385]


/*
 * Helpers for vif_statistic_8_avx512(). Each processes 16 consecutive uint32
 * columns starting at `src`; `half` is fwidth / 2 and `filt` the 1-D filter.
 * Like the rest of the SIMD path they apply filt[k] to both the sample at
 * offset k - half and the one at offset half - k, i.e. they assume an
 * odd-length symmetric filter. They read `half` elements on either side of
 * the 16-column window (the caller runs PADDING_SQ_DATA on these rows first).
 */

/* Zero-extends 8 consecutive uint32 values into 8 x 64-bit lanes. */
static inline __m512i vif8_load8_epu32_avx512(const uint32_t *p)
{
    return _mm512_cvtepu32_epi64(_mm256_loadu_si256((const __m256i *)p));
}

/* Weighted horizontal sum of 16 uint32 lanes, accumulated in 32-bit lanes
 * without rounding. Every partial sum must fit in 32 bits. NOTE(review): for
 * the mu rows this relies on the vertical output being (sum + 128) >> 8 and on
 * the filter taps summing to at most 2^16 -- confirm if the filter changes. */
static inline __m512i vif8_hfilter_mu_avx512(const uint32_t *src,
                                             const uint16_t *filt,
                                             unsigned half)
{
    __m512i acc = _mm512_mullo_epi32(_mm512_loadu_si512((const void *)src),
                                     _mm512_set1_epi32(filt[half]));
    for (unsigned k = 0; k < half; ++k) {
        const __m512i fk = _mm512_set1_epi32(filt[k]);
        const __m512i back = _mm512_loadu_si512((const void *)(src - half + k));
        const __m512i fwd = _mm512_loadu_si512((const void *)(src + half - k));
        /* 32-bit adds keep each column's carry inside its own lane. */
        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(back, fk));
        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(fwd, fk));
    }
    return acc;
}

/* Lane-wise rounded high product ((a * b + 2^31) >> 32) of 16 uint32 lanes.
 * The products are formed in 64 bits. Because unpacklo/unpackhi work per
 * 128-bit lane, `lo` holds columns {0,1,4,5,8,9,12,13} and `hi` columns
 * {2,3,6,7,10,11,14,15}; the permutation restores column order. */
static inline __m512i vif8_mul_round_hi_avx512(__m512i a, __m512i b)
{
    const __m512i zero = _mm512_setzero_si512();
    const __m512i round_2_31 = _mm512_set1_epi64(0x80000000);
    const __m512i column_order = _mm512_set_epi32(30, 28, 14, 12, 26, 24, 10, 8,
                                                  22, 20, 6, 4, 18, 16, 2, 0);
    __m512i lo = _mm512_mul_epu32(_mm512_unpacklo_epi32(a, zero),
                                  _mm512_unpacklo_epi32(b, zero));
    __m512i hi = _mm512_mul_epu32(_mm512_unpackhi_epi32(a, zero),
                                  _mm512_unpackhi_epi32(b, zero));
    lo = _mm512_srli_epi64(_mm512_add_epi64(lo, round_2_31), 32);
    hi = _mm512_srli_epi64(_mm512_add_epi64(hi, round_2_31), 32);
    return _mm512_permutex2var_epi32(lo, column_order, hi);
}

/* Weighted horizontal sum of 16 uint32 lanes accumulated in 64 bits, rounded
 * as (sum + 2^15) >> 16 and narrowed to the low 32 bits of each result. */
static inline __m512i vif8_hfilter_sq_avx512(const uint32_t *src,
                                             const uint16_t *filt,
                                             unsigned half)
{
    const __m512i even_dwords = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16,
                                                 14, 12, 10, 8, 6, 4, 2, 0);
    const __m512i rounder = _mm512_set1_epi64(0x8000);
    const __m512i fc = _mm512_set1_epi64(filt[half]);
    /* lo covers columns 0..7, hi columns 8..15. */
    __m512i lo = _mm512_add_epi64(rounder, _mm512_mul_epu32(vif8_load8_epu32_avx512(src), fc));
    __m512i hi = _mm512_add_epi64(rounder, _mm512_mul_epu32(vif8_load8_epu32_avx512(src + 8), fc));
    for (unsigned k = 0; k < half; ++k) {
        const __m512i fk = _mm512_set1_epi64(filt[k]);
        const uint32_t *back = src - half + k;
        const uint32_t *fwd = src + half - k;
        lo = _mm512_add_epi64(lo, _mm512_mul_epu32(vif8_load8_epu32_avx512(back), fk));
        hi = _mm512_add_epi64(hi, _mm512_mul_epu32(vif8_load8_epu32_avx512(back + 8), fk));
        lo = _mm512_add_epi64(lo, _mm512_mul_epu32(vif8_load8_epu32_avx512(fwd), fk));
        hi = _mm512_add_epi64(hi, _mm512_mul_epu32(vif8_load8_epu32_avx512(fwd + 8), fk));
    }
    lo = _mm512_srli_epi64(lo, 16);
    hi = _mm512_srli_epi64(hi, 16);
    return _mm512_permutex2var_epi32(lo, even_dwords, hi);
}

/*
 * VIF statistics for 8-bit input, AVX-512 path, using filter 0
 * (vif_filter1d_table[0], width vif_filter1d_width[0]).
 *
 * For every row i in [0, h) the filter is applied:
 *   1. vertically to the ref/dis planes in s->buf. This produces per-column
 *      mu1 and mu2, each rounded as (sum + 128) >> 8, and unrounded filtered
 *      ref^2, dis^2 and ref*dis sums. All are stored in the buf.tmp rows.
 *   2. horizontally to those rows. This produces mu1^2, mu2^2 and mu1*mu2,
 *      each rounded as (x * y + 2^31) >> 32. It also produces the filtered
 *      second moments, rounded as (sum + 2^15) >> 16. From these it forms
 *      xx, yy (clamped at 0) and xy, and passes them to vif_statistic_avx512().
 * Columns [0, w & ~15) take the 16-wide SIMD paths. The remaining columns use
 * the scalar vertical loop and vif_compute_line_residuals().
 *
 * s    - supplies buf, log2_table and vif_enhn_gain_limit. The memory that
 *        s->buf.tmp points to is overwritten.
 * num  - num[0] receives the numerator.
 * den  - den[0] receives the denominator.
 * w, h - number of columns / rows to process.
 *
 * NOTE(review): rows i - fwidth/2 .. i + fwidth/2 are read for every i, so the
 * ref/dis planes must be readable fwidth/2 rows above row 0 and below row
 * h - 1. Presumably the caller pads them; confirm.
 */
void vif_statistic_8_avx512(struct VifPublicState *s, float *num, float *den, unsigned w, unsigned h) {
    const unsigned fwidth = vif_filter1d_width[0];
    const uint16_t *vif_filt = vif_filter1d_table[0];
    VifBuffer buf = s->buf;
    const uint8_t *ref = (uint8_t*)buf.ref;
    const uint8_t *dis = (uint8_t*)buf.dis;
    const unsigned fwidth_half = fwidth >> 1;
    const uint16_t *log2_table = s->log2_table;
    double vif_enhn_gain_limit = s->vif_enhn_gain_limit;

    /* ALIGNED is not used in this function. The definition is kept because a
     * macro stays visible for the rest of the translation unit. */
#if defined __GNUC__
#define ALIGNED(x) __attribute__ ((aligned (x)))
#elif defined (_MSC_VER)  && (!defined UNDER_CE)
#define ALIGNED(x) __declspec (align(x))
#else
#define ALIGNED(x)
#endif

    int64_t accum_num_log = 0;
    int64_t accum_den_log = 0;
    int64_t accum_num_non_log = 0;
    int64_t accum_den_non_log = 0;

    const __m512i round_128 = _mm512_set1_epi32(128);
    const __m512i f_center = _mm512_set1_epi32(vif_filt[fwidth_half]);
    /* Columns [0, w16) are handled 16 at a time; [w16, w) individually. */
    const unsigned w16 = w & ~15u;

    Residuals512 residuals;
    residuals.maccum_den_log = _mm512_setzero_si512();
    residuals.maccum_num_log = _mm512_setzero_si512();
    residuals.maccum_den_non_log = _mm512_setzero_si512();
    residuals.maccum_num_non_log = _mm512_setzero_si512();

    for (unsigned i = 0; i < h; ++i) {
        /* Window rows, computed in int so negative values (top rows) are well
         * defined rather than an out-of-range unsigned->int conversion. */
        const int i_back = (int)i - (int)fwidth_half;
        const int i_forward = (int)i + (int)fwidth_half;

        /* VERTICAL, 16 columns at a time. */
        for (unsigned jj = 0; jj < w16; jj += 16) {
            const __m512i rc = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*)(ref + (buf.stride * i) + jj)));
            const __m512i dc = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*)(dis + (buf.stride * i) + jj)));

            __m512i accum_mu1 = _mm512_mullo_epi32(rc, f_center);
            __m512i accum_mu2 = _mm512_mullo_epi32(dc, f_center);
            __m512i accum_ref = _mm512_mullo_epi32(f_center, _mm512_mullo_epi32(rc, rc));
            __m512i accum_dis = _mm512_mullo_epi32(f_center, _mm512_mullo_epi32(dc, dc));
            __m512i accum_ref_dis = _mm512_mullo_epi32(f_center, _mm512_mullo_epi32(rc, dc));

            /* Symmetric filter: tap k weights rows i_back + k and i_forward - k. */
            for (unsigned tap = 0; tap < fwidth_half; tap++) {
                const int row_back = i_back + (int)tap;
                const int row_fwd = i_forward - (int)tap;
                const __m512i ft = _mm512_set1_epi32(vif_filt[tap]);
                const __m512i rb = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*)(ref + (buf.stride * row_back) + jj)));
                const __m512i db = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*)(dis + (buf.stride * row_back) + jj)));
                const __m512i rf = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*)(ref + (buf.stride * row_fwd) + jj)));
                const __m512i df = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*)(dis + (buf.stride * row_fwd) + jj)));

                accum_mu1 = _mm512_add_epi32(accum_mu1, _mm512_mullo_epi32(_mm512_add_epi32(rb, rf), ft));
                accum_mu2 = _mm512_add_epi32(accum_mu2, _mm512_mullo_epi32(_mm512_add_epi32(db, df), ft));
                accum_ref = _mm512_add_epi32(accum_ref, _mm512_mullo_epi32(ft, _mm512_add_epi32(_mm512_mullo_epi32(rb, rb), _mm512_mullo_epi32(rf, rf))));
                accum_dis = _mm512_add_epi32(accum_dis, _mm512_mullo_epi32(ft, _mm512_add_epi32(_mm512_mullo_epi32(db, db), _mm512_mullo_epi32(df, df))));
                accum_ref_dis = _mm512_add_epi32(accum_ref_dis, _mm512_mullo_epi32(ft, _mm512_add_epi32(_mm512_mullo_epi32(db, rb), _mm512_mullo_epi32(df, rf))));
            }
            /* mu = (sum + 128) >> 8; second moments are stored unrounded. */
            accum_mu1 = _mm512_srli_epi32(_mm512_add_epi32(accum_mu1, round_128), 8);
            accum_mu2 = _mm512_srli_epi32(_mm512_add_epi32(accum_mu2, round_128), 8);

            _mm512_storeu_si512((__m512i*)(buf.tmp.mu1 + jj), accum_mu1);
            _mm512_storeu_si512((__m512i*)(buf.tmp.mu2 + jj), accum_mu2);
            _mm512_storeu_si512((__m512i*)(buf.tmp.ref + jj), accum_ref);
            _mm512_storeu_si512((__m512i*)(buf.tmp.dis + jj), accum_dis);
            _mm512_storeu_si512((__m512i*)(buf.tmp.ref_dis + jj), accum_ref_dis);
        }

        /* VERTICAL, remaining columns one at a time (every tap indexed explicitly). */
        for (unsigned j = w16; j < w; ++j) {
            uint32_t accum_mu1 = 0;
            uint32_t accum_mu2 = 0;
            uint64_t accum_ref = 0;
            uint64_t accum_dis = 0;
            uint64_t accum_ref_dis = 0;

            for (unsigned fi = 0; fi < fwidth; ++fi) {
                const int ii_check = i_back + (int)fi;
                const uint16_t fcoeff = vif_filt[fi];
                uint16_t imgcoeff_ref = ref[ii_check * buf.stride + j];
                uint16_t imgcoeff_dis = dis[ii_check * buf.stride + j];
                uint32_t img_coeff_ref = fcoeff * (uint32_t)imgcoeff_ref;
                uint32_t img_coeff_dis = fcoeff * (uint32_t)imgcoeff_dis;
                accum_mu1 += img_coeff_ref;
                accum_mu2 += img_coeff_dis;
                accum_ref += img_coeff_ref * (uint64_t)imgcoeff_ref;
                accum_dis += img_coeff_dis * (uint64_t)imgcoeff_dis;
                accum_ref_dis += img_coeff_ref * (uint64_t)imgcoeff_dis;
            }

            buf.tmp.mu1[j] = (accum_mu1 + 128) >> 8;
            buf.tmp.mu2[j] = (accum_mu2 + 128) >> 8;
            buf.tmp.ref[j] = accum_ref;
            buf.tmp.dis[j] = accum_dis;
            buf.tmp.ref_dis[j] = accum_ref_dis;
        }

        PADDING_SQ_DATA(buf, w, fwidth_half);

        /* HORIZONTAL, 16 columns at a time. */
        for (unsigned j = 0; j < w16; j += 16) {
            const __m512i mu1 = vif8_hfilter_mu_avx512(buf.tmp.mu1 + j, vif_filt, fwidth_half);
            const __m512i mu2 = vif8_hfilter_mu_avx512(buf.tmp.mu2 + j, vif_filt, fwidth_half);
            const __m512i mu1sq = vif8_mul_round_hi_avx512(mu1, mu1);
            const __m512i mu2sq = vif8_mul_round_hi_avx512(mu2, mu2);
            const __m512i mu1mu2 = vif8_mul_round_hi_avx512(mu1, mu2);

            const __m512i refsq = vif8_hfilter_sq_avx512(buf.tmp.ref + j, vif_filt, fwidth_half);
            const __m512i dissq = vif8_hfilter_sq_avx512(buf.tmp.dis + j, vif_filt, fwidth_half);
            const __m512i refdis = vif8_hfilter_sq_avx512(buf.tmp.ref_dis + j, vif_filt, fwidth_half);

            const __m512i xx = _mm512_sub_epi32(refsq, mu1sq);
            const __m512i yy = _mm512_max_epi32(_mm512_sub_epi32(dissq, mu2sq), _mm512_setzero_si512());
            const __m512i xy = _mm512_sub_epi32(refdis, mu1mu2);

            vif_statistic_avx512(&residuals, xx, xy, yy, log2_table, vif_enhn_gain_limit);
        }

        /* HORIZONTAL, remaining columns via the shared scalar routine. */
        if (w16 != w) {
            const VifResiduals tail = vif_compute_line_residuals(s, w16, w, 0);
            accum_num_log += tail.accum_num_log;
            accum_den_log += tail.accum_den_log;
            accum_num_non_log += tail.accum_num_non_log;
            accum_den_non_log += tail.accum_den_non_log;
        }
    }

    accum_num_log += _mm512_reduce_add_epi64(residuals.maccum_num_log);
    accum_den_log += _mm512_reduce_add_epi64(residuals.maccum_den_log);
    accum_num_non_log += _mm512_reduce_add_epi64(residuals.maccum_num_non_log);
    accum_den_non_log += _mm512_reduce_add_epi64(residuals.maccum_den_non_log);
    num[0] = accum_num_log / 2048.0 + (accum_den_non_log - ((accum_num_non_log) / 16384.0) / (65025.0));
    den[0] = accum_den_log / 2048.0 + accum_den_non_log;
}