void vif_subsample_rd_16_avx512()

in libvmaf/src/feature/x86/vif_avx512.c [1105:1267]


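/* Separable 1-D VIF filtering of the 16-bit ref/dis planes for one scale:
 * a vertical pass into buf.tmp.ref_convol/buf.tmp.dis_convol, a horizontal
 * pass into buf.mu1/buf.mu2, then decimation for the next scale. */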
void vif_subsample_rd_16_avx512(VifBuffer buf, unsigned w, unsigned h, int scale,
                                int bpc)
{
    const unsigned fwidth = vif_filter1d_width[scale + 1];
    const uint16_t *vif_filt = vif_filter1d_table[scale + 1];
    int32_t add_shift_round_VP, shift_VP;
    int fwidth_half = fwidth >> 1;
    const ptrdiff_t stride = buf.stride / sizeof(uint16_t);
    const ptrdiff_t stride16 = buf.stride_16 / sizeof(uint16_t);
    uint16_t *ref = buf.ref;
    uint16_t *dis = buf.dis;

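    // Pick the vertical-pass normalization so results fit back into 16 bits:
    // scale 0 filters the bpc-bit input planes, while every later scale
    // filters the 16-bit output of the previous scale.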
    if (scale == 0)
    {
        add_shift_round_VP = 1 << (bpc - 1);
        shift_VP = bpc;
    }
    else
    {
        add_shift_round_VP = 32768;
        shift_VP = 16;
    }

    for (unsigned i = 0; i < h; ++i)
    {
        //VERTICAL pass: 1-D filter along the columns, 32 pixels per iteration

        int n = w >> 4;
        int ii = i - fwidth_half;
        for (int j = 0; j < n << 4; j = j + 32)
        {
            int ii_check = ii;
            __m512i accumr_lo, accumr_hi, accumd_lo, accumd_hi, rmul1, rmul2, dmul1, dmul2;
            accumr_lo = accumr_hi = accumd_lo = accumd_hi = rmul1 = rmul2 = dmul1 = dmul2 = _mm512_setzero_si512();
            // Qword permutes used after the tap loop to restore linear column
            // order: mask3 gathers results 0-15, mask4 results 16-31 from the
            // interleaved lo/hi accumulator pairs.
            __m512i mask3 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
            __m512i mask4 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
            for (unsigned fi = 0; fi < fwidth; ++fi, ii_check = ii + fi)
            {

                const uint16_t fcoeff = vif_filt[fi];
                __m512i f1 = _mm512_set1_epi16(fcoeff);
                __m512i ref1 = _mm512_loadu_si512((__m512i *)(ref + (ii_check * stride) + j));
                __m512i dis1 = _mm512_loadu_si512((__m512i *)(dis + (ii_check * stride) + j));
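                // Widen the 16x16-bit multiplies to 32 bits: mullo/mulhi
                // yield the low/high product halves, and unpacklo/unpackhi
                // interleave them into dword products within each 128-bit lane.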
                __m512i result2 = _mm512_mulhi_epu16(ref1, f1);
                __m512i result2lo = _mm512_mullo_epi16(ref1, f1);
                rmul1 = _mm512_unpacklo_epi16(result2lo, result2);
                rmul2 = _mm512_unpackhi_epi16(result2lo, result2);
                accumr_lo = _mm512_add_epi32(accumr_lo, rmul1);
                accumr_hi = _mm512_add_epi32(accumr_hi, rmul2);

                __m512i d0 = _mm512_mulhi_epu16(dis1, f1);
                __m512i d0lo = _mm512_mullo_epi16(dis1, f1);
                dmul1 = _mm512_unpacklo_epi16(d0lo, d0);
                dmul2 = _mm512_unpackhi_epi16(d0lo, d0);
                accumd_lo = _mm512_add_epi32(accumd_lo, dmul1);
                accumd_hi = _mm512_add_epi32(accumd_hi, dmul2);
            }
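            // Round to nearest and drop the shift_VP fraction bits.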
            __m512i addnum = _mm512_set1_epi32(add_shift_round_VP);
            accumr_lo = _mm512_add_epi32(accumr_lo, addnum);
            accumr_hi = _mm512_add_epi32(accumr_hi, addnum);
            accumr_lo = _mm512_srli_epi32(accumr_lo, shift_VP);
            accumr_hi = _mm512_srli_epi32(accumr_hi, shift_VP);

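            // Store the 32 vertically filtered dwords of this block in
            // linear column order (see mask3/mask4 above).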
            _mm512_storeu_si512((__m512i *)(buf.tmp.ref_convol + j),
                                _mm512_permutex2var_epi64(accumr_lo, mask3, accumr_hi));
            _mm512_storeu_si512((__m512i *)(buf.tmp.ref_convol + j + 16),
                                _mm512_permutex2var_epi64(accumr_lo, mask4, accumr_hi));

            accumd_lo = _mm512_add_epi32(accumd_lo, addnum);
            accumd_hi = _mm512_add_epi32(accumd_hi, addnum);
            accumd_lo = _mm512_srli_epi32(accumd_lo, shift_VP);
            accumd_hi = _mm512_srli_epi32(accumd_hi, shift_VP);
            _mm512_storeu_si512((__m512i *)(buf.tmp.dis_convol + j),
                                _mm512_permutex2var_epi64(accumd_lo, mask3, accumd_hi));
            _mm512_storeu_si512((__m512i *)(buf.tmp.dis_convol + j + 16),
                                _mm512_permutex2var_epi64(accumd_lo, mask4, accumd_hi));
        }

        //VERTICAL scalar tail for the remaining w % 16 columns
        for (unsigned j = n << 4; j < w; ++j)
        {
            uint32_t accum_ref = 0;
            uint32_t accum_dis = 0;
            int ii_check = ii;
            for (unsigned fi = 0; fi < fwidth; ++fi, ii_check = ii + fi)
            {
                const uint16_t fcoeff = vif_filt[fi];
                accum_ref += fcoeff * ((uint32_t)ref[ii_check * stride + j]);
                accum_dis += fcoeff * ((uint32_t)dis[ii_check * stride + j]);
            }
            buf.tmp.ref_convol[j] = (uint16_t)((accum_ref + add_shift_round_VP) >> shift_VP);
            buf.tmp.dis_convol[j] = (uint16_t)((accum_dis + add_shift_round_VP) >> shift_VP);
        }

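        // Pad the temp rows on both sides so the horizontal pass can read
        // fwidth_half samples outside [0, w) near the borders.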
        PADDING_SQ_DATA_2(buf, w, fwidth_half);

        //HORIZONTAL pass: 1-D filter along the padded temp rows, 16 outputs per iteration
        n = w >> 4;
        for (int j = 0; j < n << 4; j = j + 16)
        {
            int jj = j - fwidth_half;
            int jj_check = jj;
            __m512i accumrlo, accumdlo, accumrhi, accumdhi;
            accumrlo = accumdlo = accumrhi = accumdhi = _mm512_setzero_si512();
            for (unsigned fj = 0; fj < fwidth; ++fj, jj_check = jj + fj)
            {

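                // Each temp sample sits in a 32-bit slot with a zero high
                // word, so the mullo/mulhi + unpack trick leaves a full
                // 32-bit product in every other dword of the accumulators.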
                __m512i refconvol = _mm512_loadu_si512((__m512i *)(buf.tmp.ref_convol + jj_check));
                __m512i fcoeff = _mm512_set1_epi16(vif_filt[fj]);
                __m512i result2 = _mm512_mulhi_epu16(refconvol, fcoeff);
                __m512i result2lo = _mm512_mullo_epi16(refconvol, fcoeff);
                accumrlo = _mm512_add_epi32(accumrlo, _mm512_unpacklo_epi16(result2lo, result2));
                accumrhi = _mm512_add_epi32(accumrhi, _mm512_unpackhi_epi16(result2lo, result2));
                __m512i disconvol = _mm512_loadu_si512((__m512i *)(buf.tmp.dis_convol + jj_check));
                result2 = _mm512_mulhi_epu16(disconvol, fcoeff);
                result2lo = _mm512_mullo_epi16(disconvol, fcoeff);
                accumdlo = _mm512_add_epi32(accumdlo, _mm512_unpacklo_epi16(result2lo, result2));
                accumdhi = _mm512_add_epi32(accumdhi, _mm512_unpackhi_epi16(result2lo, result2));
            }

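            // Round to nearest and shift out the 16 fraction bits carried by
            // the fixed-point filter taps.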
            __m512i addnum = _mm512_set1_epi32(32768);
            accumdlo = _mm512_add_epi32(accumdlo, addnum);
            accumdhi = _mm512_add_epi32(accumdhi, addnum);
            accumrlo = _mm512_add_epi32(accumrlo, addnum);
            accumrhi = _mm512_add_epi32(accumrhi, addnum);
            accumdlo = _mm512_srli_epi32(accumdlo, 0x10);
            accumdhi = _mm512_srli_epi32(accumdhi, 0x10);
            accumrlo = _mm512_srli_epi32(accumrlo, 0x10);
            accumrhi = _mm512_srli_epi32(accumrhi, 0x10);

            // Word permute: gather the low word of each product dword from
            // both accumulators, restoring linear column order; the index
            // pattern repeats, so the low 256 bits collect all 16 results.
            __m512i mask2 = _mm512_set_epi16(60, 56, 28, 24, 52, 48, 20, 16,
                                             44, 40, 12, 8, 36, 32, 4, 0,
                                             60, 56, 28, 24, 52, 48, 20, 16,
                                             44, 40, 12, 8, 36, 32, 4, 0);

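            // Pack the 16 result words per plane into the low 256 bits and
            // store them as uint16.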
            _mm256_storeu_si256((__m256i *)(buf.mu1 + (stride16 * i) + j),
                                _mm512_castsi512_si256(_mm512_permutex2var_epi16(accumrlo, mask2, accumrhi)));
            _mm256_storeu_si256((__m256i *)(buf.mu2 + (stride16 * i) + j),
                                _mm512_castsi512_si256(_mm512_permutex2var_epi16(accumdlo, mask2, accumdhi)));
        }

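        //HORIZONTAL scalar tail for the remaining w % 16 columns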
        for (unsigned j = n << 4; j < w; ++j)
        {
            uint32_t accum_ref = 0;
            uint32_t accum_dis = 0;
            int jj = j - fwidth_half;
            int jj_check = jj;
            for (unsigned fj = 0; fj < fwidth; ++fj, jj_check = jj + fj)
            {
                const uint16_t fcoeff = vif_filt[fj];
                accum_ref += fcoeff * ((uint32_t)buf.tmp.ref_convol[jj_check]);
                accum_dis += fcoeff * ((uint32_t)buf.tmp.dis_convol[jj_check]);
            }
            buf.mu1[i * stride16 + j] = (uint16_t)((accum_ref + 32768) >> 16);
            buf.mu2[i * stride16 + j] = (uint16_t)((accum_dis + 32768) >> 16);
        }
    }
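    // Decimate the filtered planes and pad them for the next scale.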
    decimate_and_pad(buf, w, h, scale);
}