void vif_statistic_16_neon()

in libvmaf/src/feature/arm64/vif_neon.c [774:1065]
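
NEON (AArch64) implementation of the per-scale VIF statistic accumulation for 16-bit data: a separable vertical-then-horizontal 1-D filter produces local means and second moments, from which per-pixel numerator and denominator contributions are accumulated in fixed point.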


void vif_statistic_16_neon(struct VifPublicState *s, float *num, float *den, unsigned w, unsigned h, int bpc, int scale)
{
    const unsigned int uiw7 = (w > 7 ? w - 7 : 0);
    const unsigned int fwidth = vif_filter1d_width[scale];
    const uint16_t *vif_filt_s = vif_filter1d_table[scale];
    VifBuffer buf = s->buf;
    uint16_t *log2_table = s->log2_table;
    double vif_enhn_gain_limit = s->vif_enhn_gain_limit;

    int32_t add_shift_round_HP, shift_HP;
    int32_t add_shift_round_VP, shift_VP;
    int32_t add_shift_round_VP_sq, shift_VP_sq;
    if (scale == 0)
    {
        shift_HP = 16;
        add_shift_round_HP = 32768;
        shift_VP = bpc;
        add_shift_round_VP = 1 << (bpc - 1);
        shift_VP_sq = (bpc - 8) * 2;
        add_shift_round_VP_sq = (bpc == 8) ? 0 : 1 << (shift_VP_sq - 1);
    }
    else
    {
        shift_HP = 16;
        add_shift_round_HP = 32768;
        shift_VP = 16;
        add_shift_round_VP = 32768;
        shift_VP_sq = 16;
        add_shift_round_VP_sq = 32768;
    }
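
    // Broadcast rounding offsets and shift amounts into NEON registers; the
    // shift counts are negated because vshlq shifts left by a signed per-lane
    // amount, so a negative count performs the right shift.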

    const uint32x4_t add_shift_round_VP_vec = vdupq_n_u32(add_shift_round_VP);
    const int32x4_t shift_VP_vec = vdupq_n_s32(-shift_VP);
    const uint64x2_t add_shift_round_VP_sq_vec = vdupq_n_u64(add_shift_round_VP_sq);
    const int64x2_t shift_VP_sq_vec = vdupq_n_s64(-shift_VP_sq);

    const uint64x2_t add_shift_round_HP_vec = vdupq_n_u64(add_shift_round_HP);
    const int64x2_t shift_vec_HP = vdupq_n_s64(-shift_HP);

    const uint16_t *ref = (uint16_t *)buf.ref;
    const uint16_t *dis = (uint16_t *)buf.dis;

    const ptrdiff_t stride_16 = buf.stride / sizeof(uint16_t);
    const ptrdiff_t stride_32 = buf.stride_32 / sizeof(uint32_t);
    ptrdiff_t i_dst_stride = 0;
    int32_t xx[8], yy[8], xy[8];
    int64_t accum_num_log = 0;
    int64_t accum_den_log = 0;
    int64_t accum_num_non_log = 0;
    int64_t accum_den_non_log = 0;
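    // sigma_nsq is 2.0 in floating point, scaled by 2^16 for fixed point.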
    static const int32_t sigma_nsq = 65536 << 1;

    for (unsigned i = 0; i < h; ++i, i_dst_stride += stride_32)
    {
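        // Center the fwidth-tap vertical filter window on output row i.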
        int ii = i - fwidth / 2;
        const uint16_t *p_ref = ref + ii * stride_16;
        const uint16_t *p_dis = dis + ii * stride_16;

        // VERTICAL pass: 1-D filter down the columns, 8 pixels per iteration.
        unsigned int j = 0;
        for (; j < uiw7; j += 8, p_ref += 8, p_dis += 8)
        {
            NEON_FILTER_LOAD_U16X8_AND_MOVE_TO_U32X4_AND_SQR(ref_vec_16u, ref_vec_32u, ref_ref_vec, p_ref);
            NEON_FILTER_LOAD_U16X8_AND_MOVE_TO_U32X4_AND_SQR(dis_vec_16u, dis_vec_32u, dis_dis_vec, p_dis);

            NEON_FILTER_INSTANCE_U32X4_NO_INIT_MULL_U16X4_WITH_CONST_LO_HI(accum_f_ref, ref_vec_16u, vif_filt_s[0]);
            NEON_FILTER_INSTANCE_U32X4_NO_INIT_MULL_U16X4_WITH_CONST_LO_HI(accum_f_dis, dis_vec_16u, vif_filt_s[0]);

            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_U32X2_WITH_CONST_LO_HI_LH(accum_f_ref_ref, add_shift_round_VP_sq_vec, ref_ref_vec, vif_filt_s[0]);
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_U32X2_WITH_CONST_LO_HI_LH(accum_f_dis_dis, add_shift_round_VP_sq_vec, dis_dis_vec, vif_filt_s[0]);
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_U32X2_LO_HI(accum_f_ref_dis_l, add_shift_round_VP_sq_vec, accum_f_dis_l, ref_vec_32u_l);
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_U32X2_LO_HI(accum_f_ref_dis_h, add_shift_round_VP_sq_vec, accum_f_dis_h, ref_vec_32u_h);

            const uint16_t *pp_ref = p_ref + stride_16;
            const uint16_t *pp_dis = p_dis + stride_16;
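            // Remaining fwidth - 1 taps, advancing one image row per iteration.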
            for (unsigned fi = 1; fi < fwidth; ++fi, pp_ref += stride_16, pp_dis += stride_16)
            {
                NEON_FILTER_LOAD_U16X8_AND_MOVE_TO_U32X4_AND_SQR(ref_vec_16u, ref_vec_32u, ref_ref_vec, pp_ref);
                NEON_FILTER_LOAD_U16X8_AND_MOVE_TO_U32X4_AND_SQR(dis_vec_16u, dis_vec_32u, dis_dis_vec, pp_dis);

                NEON_FILTER_INSTANCE_U32X4_NO_INIT_MULL_U16X4_WITH_CONST_LO_HI(f_dis, dis_vec_16u, vif_filt_s[fi]);

                NEON_FILTER_UPDATE_ACCUM_U32X4_WITH_CONST_LO_HI(accum_f_ref, ref_vec_16u, vif_filt_s[fi]);
                NEON_FILTER_UPDATE_ACCUM_U32X4_WITH_CONST_LO_HI(accum_f_dis, dis_vec_16u, vif_filt_s[fi]);

                NEON_FILTER_ACCUM_LO_HI_LH_U64X2_WITH_CONST_LH(accum_f_ref_ref, ref_ref_vec, vif_filt_s[fi]);
                NEON_FILTER_ACCUM_LO_HI_LH_U64X2_WITH_CONST_LH(accum_f_dis_dis, dis_dis_vec, vif_filt_s[fi]);

                NEON_FILTER_UPDATE_U64X2_ACCUM_LO_HI(accum_f_ref_dis_l, f_dis_l, ref_vec_32u_l);
                NEON_FILTER_UPDATE_U64X2_ACCUM_LO_HI(accum_f_ref_dis_h, f_dis_h, ref_vec_32u_h);
            }
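            // Round, shift back to working precision, and store this row's
            // vertically filtered means (mu1, mu2) and raw second moments.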

            NEON_FILTER_OFFSET_SHIFT_STORE_U32X4(accum_f_ref_l, add_shift_round_VP_vec, shift_VP_vec, buf.tmp.mu1 + j);
            NEON_FILTER_OFFSET_SHIFT_STORE_U32X4(accum_f_ref_h, add_shift_round_VP_vec, shift_VP_vec, buf.tmp.mu1 + j + 4);

            NEON_FILTER_OFFSET_SHIFT_STORE_U32X4(accum_f_dis_l, add_shift_round_VP_vec, shift_VP_vec, buf.tmp.mu2 + j);
            NEON_FILTER_OFFSET_SHIFT_STORE_U32X4(accum_f_dis_h, add_shift_round_VP_vec, shift_VP_vec, buf.tmp.mu2 + j + 4);

            NEON_FILTER_SHIFT_UNZIP_STORE_U64X2_TO_U32X4_LO_HI(accum_f_ref_ref_l, shift_VP_sq_vec, buf.tmp.ref + j);
            NEON_FILTER_SHIFT_UNZIP_STORE_U64X2_TO_U32X4_LO_HI(accum_f_ref_ref_h, shift_VP_sq_vec, buf.tmp.ref + j + 4);

            NEON_FILTER_SHIFT_UNZIP_STORE_U64X2_TO_U32X4_LO_HI(accum_f_dis_dis_l, shift_VP_sq_vec, buf.tmp.dis + j);
            NEON_FILTER_SHIFT_UNZIP_STORE_U64X2_TO_U32X4_LO_HI(accum_f_dis_dis_h, shift_VP_sq_vec, buf.tmp.dis + j + 4);

            NEON_FILTER_SHIFT_UNZIP_STORE_U64X2_TO_U32X4_LO_HI(accum_f_ref_dis_l, shift_VP_sq_vec, buf.tmp.ref_dis + j);
            NEON_FILTER_SHIFT_UNZIP_STORE_U64X2_TO_U32X4_LO_HI(accum_f_ref_dis_h, shift_VP_sq_vec, buf.tmp.ref_dis + j + 4);
        }

        // Scalar fallback for the vertical pass over the leftover columns.
        for (; j < w; ++j)
        {
            uint32_t accum_mu1 = add_shift_round_VP;
            uint32_t accum_mu2 = add_shift_round_VP;
            uint64_t accum_ref = add_shift_round_VP_sq;
            uint64_t accum_dis = add_shift_round_VP_sq;
            uint64_t accum_ref_dis = add_shift_round_VP_sq;
            for (unsigned fi = 0; fi < fwidth; ++fi)
            {
                int ii = i - fwidth / 2;
                int ii_check = ii + fi;
                const uint16_t fcoeff = vif_filt_s[fi];
                const ptrdiff_t stride = buf.stride / sizeof(uint16_t);
                uint16_t *ref = (uint16_t *)buf.ref;
                uint16_t *dis = (uint16_t *)buf.dis;
                uint16_t imgcoeff_ref = ref[ii_check * stride + j];
                uint16_t imgcoeff_dis = dis[ii_check * stride + j];
                uint32_t img_coeff_ref = fcoeff * (uint32_t)imgcoeff_ref;
                uint32_t img_coeff_dis = fcoeff * (uint32_t)imgcoeff_dis;
                accum_mu1 += img_coeff_ref;
                accum_mu2 += img_coeff_dis;
                accum_ref += img_coeff_ref * (uint64_t)imgcoeff_ref;
                accum_dis += img_coeff_dis * (uint64_t)imgcoeff_dis;
                accum_ref_dis += img_coeff_ref * (uint64_t)imgcoeff_dis;
            }
            buf.tmp.mu1[j] = (uint16_t)(accum_mu1 >> shift_VP);
            buf.tmp.mu2[j] = (uint16_t)(accum_mu2 >> shift_VP);
            buf.tmp.ref[j] = (uint32_t)(accum_ref >> shift_VP_sq);
            buf.tmp.dis[j] = (uint32_t)(accum_dis >> shift_VP_sq);
            buf.tmp.ref_dis[j] = (uint32_t)(accum_ref_dis >> shift_VP_sq);
        }
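        // Pad the filtered row beyond both edges so the horizontal filter can
        // read fwidth/2 samples past each end.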

        PADDING_SQ_DATA(buf, w, fwidth / 2);

        // HORIZONTAL pass: 1-D filter along each padded row.
        uint32_t *pMul1 = (uint32_t *)buf.tmp.mu1 - (fwidth / 2);
        uint32_t *pMul2 = (uint32_t *)buf.tmp.mu2 - (fwidth / 2);
        uint32_t *pRef = (uint32_t *)buf.tmp.ref - (fwidth / 2);
        uint32_t *pDis = (uint32_t *)buf.tmp.dis - (fwidth / 2);
        uint32_t *pRefDis = (uint32_t *)buf.tmp.ref_dis - (fwidth / 2);
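        // Each pointer is rewound by fwidth/2 so that tap fj of the horizontal
        // filter is read as p[fj]; the padding supplies the out-of-range samples.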

        j = 0;
        for (; j < uiw7; j += 8, pMul1 += 8, pMul2 += 8, pDis += 8, pRef += 8, pRefDis += 8)
        {
            uint32x4_t mul1_vec_u32_0 = vld1q_u32(pMul1);
            uint32x4_t mul2_vec_u32_0 = vld1q_u32(pMul2);
            uint32x4_t ref_vec_u32_0 = vld1q_u32(pRef);
            uint32x4_t dis_vec_u32_0 = vld1q_u32(pDis);
            uint32x4_t ref_dis_vec_u32_0 = vld1q_u32(pRefDis);

            uint32x4_t mul1_vec_u32_1 = vld1q_u32(pMul1 + 4);
            uint32x4_t mul2_vec_u32_1 = vld1q_u32(pMul2 + 4);
            uint32x4_t ref_vec_u32_1 = vld1q_u32(pRef + 4);
            uint32x4_t dis_vec_u32_1 = vld1q_u32(pDis + 4);
            uint32x4_t ref_dis_vec_u32_1 = vld1q_u32(pRefDis + 4);

            uint32x4_t accum_mu1_0 = vmulq_n_u32(mul1_vec_u32_0, vif_filt_s[0]);
            uint32x4_t accum_mu2_0 = vmulq_n_u32(mul2_vec_u32_0, vif_filt_s[0]);
            uint32x4_t accum_mu1_1 = vmulq_n_u32(mul1_vec_u32_1, vif_filt_s[0]);
            uint32x4_t accum_mu2_1 = vmulq_n_u32(mul2_vec_u32_1, vif_filt_s[0]);

            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_U32X2_WITH_CONST_LO_HI_LU2(accum_ref, add_shift_round_HP_vec, ref_vec_u32, vif_filt_s[0]);
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_U32X2_WITH_CONST_LO_HI_LU2(accum_dis, add_shift_round_HP_vec, dis_vec_u32, vif_filt_s[0]);
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_U32X2_WITH_CONST_LO_HI_LU2(accum_ref_dis, add_shift_round_HP_vec, ref_dis_vec_u32, vif_filt_s[0]);

            for (unsigned fj = 1; fj < fwidth; ++fj)
            {
                mul1_vec_u32_0 = vld1q_u32(pMul1 + fj);
                mul2_vec_u32_0 = vld1q_u32(pMul2 + fj);
                ref_vec_u32_0 = vld1q_u32(pRef + fj);
                dis_vec_u32_0 = vld1q_u32(pDis + fj);
                ref_dis_vec_u32_0 = vld1q_u32(pRefDis + fj);
                mul1_vec_u32_1 = vld1q_u32(pMul1 + 4 + fj);
                mul2_vec_u32_1 = vld1q_u32(pMul2 + 4 + fj);
                ref_vec_u32_1 = vld1q_u32(pRef + 4 + fj);
                dis_vec_u32_1 = vld1q_u32(pDis + 4 + fj);
                ref_dis_vec_u32_1 = vld1q_u32(pRefDis + 4 + fj);

                accum_mu1_0 = vmlaq_n_u32(accum_mu1_0, mul1_vec_u32_0, vif_filt_s[fj]);
                accum_mu2_0 = vmlaq_n_u32(accum_mu2_0, mul2_vec_u32_0, vif_filt_s[fj]);
                accum_mu1_1 = vmlaq_n_u32(accum_mu1_1, mul1_vec_u32_1, vif_filt_s[fj]);
                accum_mu2_1 = vmlaq_n_u32(accum_mu2_1, mul2_vec_u32_1, vif_filt_s[fj]);

                NEON_FILTER_UPDATE_ACCUM_U64X2_WITH_CONST_LO_HI_LU2(accum_ref, ref_vec_u32, vif_filt_s[fj]);
                NEON_FILTER_UPDATE_ACCUM_U64X2_WITH_CONST_LO_HI_LU2(accum_dis, dis_vec_u32, vif_filt_s[fj]);
                NEON_FILTER_UPDATE_ACCUM_U64X2_WITH_CONST_LO_HI_LU2(accum_ref_dis, ref_dis_vec_u32, vif_filt_s[fj]);
            }
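            // Form mu1^2, mu2^2 and mu1*mu2 in 64 bits, round with 2^31, and
            // shift right by 32 back to 32-bit fixed point.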

            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_SHIFT_UNZIP_U32X4_LO_HI(mu1_sq_vec_l, vdupq_n_u64(2147483648), accum_mu1_0, accum_mu1_0, vdupq_n_s64(-32));
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_SHIFT_UNZIP_U32X4_LO_HI(mu1_sq_vec_h, vdupq_n_u64(2147483648), accum_mu1_1, accum_mu1_1, vdupq_n_s64(-32));

            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_SHIFT_UNZIP_U32X4_LO_HI(mu2_sq_vec_l, vdupq_n_u64(2147483648), accum_mu2_0, accum_mu2_0, vdupq_n_s64(-32));
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_SHIFT_UNZIP_U32X4_LO_HI(mu2_sq_vec_h, vdupq_n_u64(2147483648), accum_mu2_1, accum_mu2_1, vdupq_n_s64(-32));

            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_SHIFT_UNZIP_U32X4_LO_HI(mu1_mu2_sq_vec_l, vdupq_n_u64(2147483648), accum_mu1_0, accum_mu2_0, vdupq_n_s64(-32));
            NEON_FILTER_INSTANCE_U64X2_INIT_MULL_SHIFT_UNZIP_U32X4_LO_HI(mu1_mu2_sq_vec_h, vdupq_n_u64(2147483648), accum_mu1_1, accum_mu2_1, vdupq_n_s64(-32));

            NEON_FILTER_SHIFT_UNZIP_U64X2_TO_U32X4_LO_HI(accum_ref_0, shift_vec_HP, xx_filt_vec_l);
            NEON_FILTER_SHIFT_UNZIP_U64X2_TO_U32X4_LO_HI(accum_ref_1, shift_vec_HP, xx_filt_vec_h);

            NEON_FILTER_SHIFT_UNZIP_U64X2_TO_U32X4_LO_HI(accum_dis_0, shift_vec_HP, yy_filt_vec_l);
            NEON_FILTER_SHIFT_UNZIP_U64X2_TO_U32X4_LO_HI(accum_dis_1, shift_vec_HP, yy_filt_vec_h);

            NEON_FILTER_SHIFT_UNZIP_U64X2_TO_U32X4_LO_HI(accum_ref_dis_0, shift_vec_HP, xy_filt_vec_l);
            NEON_FILTER_SHIFT_UNZIP_U64X2_TO_U32X4_LO_HI(accum_ref_dis_1, shift_vec_HP, xy_filt_vec_h);

            int32x4_t sigma1_sq_vec_l = vreinterpretq_s32_u32(vsubq_u32(xx_filt_vec_l, mu1_sq_vec_l));
            int32x4_t sigma1_sq_vec_h = vreinterpretq_s32_u32(vsubq_u32(xx_filt_vec_h, mu1_sq_vec_h));

            int32x4_t sigma2_sq_vec_l = vreinterpretq_s32_u32(vsubq_u32(yy_filt_vec_l, mu2_sq_vec_l));
            int32x4_t sigma2_sq_vec_h = vreinterpretq_s32_u32(vsubq_u32(yy_filt_vec_h, mu2_sq_vec_h));

            int32x4_t sigma12_vec_l = vreinterpretq_s32_u32(vsubq_u32(xy_filt_vec_l, mu1_mu2_sq_vec_l));
            int32x4_t sigma12_vec_h = vreinterpretq_s32_u32(vsubq_u32(xy_filt_vec_h, mu1_mu2_sq_vec_h));
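
            // Spill variances and covariance to scalar arrays; sigma2_sq is
            // clamped at zero on the way out.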

            vst1q_s32(xx,       sigma1_sq_vec_l);
            vst1q_s32(xx + 4,   sigma1_sq_vec_h);

            vst1q_s32(yy,       vmaxq_s32(vdupq_n_s32(0),sigma2_sq_vec_l));
            vst1q_s32(yy + 4,   vmaxq_s32(vdupq_n_s32(0),sigma2_sq_vec_h));

            vst1q_s32(xy,       sigma12_vec_l);
            vst1q_s32(xy + 4,   sigma12_vec_h);
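
            // Per-pixel contributions: log path when sigma1_sq >= sigma_nsq,
            // linear (non-log) path otherwise.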

            for (unsigned int b = 0; b < 8; b++) {
                int32_t sigma1_sq = xx[b];
                int32_t sigma2_sq = yy[b];
                int32_t sigma12 = xy[b];

                if (sigma1_sq >= sigma_nsq) {
                    /**
                    * The log values come from the look-up table built by
                    * log_generate(), which is called in integer_combo_threadfunc.
                    * In floating point, den_val = log2(1 + sigma1_sq / 2).
                    * With sigma1_sq scaled by 2^16 this becomes
                    * log2(2 * 65536 + sigma1_sq) - 17.
                    * The table stores log2(i) * 2048 for i = 16384..65535, so every
                    * log result carries a 2048x scale and only the top 16 bits of
                    * the argument index the table.
                    */
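                    // e.g. fixed-point sigma1_sq = 65536 (1.0 in float):
                    // log2(2 * 65536 + 65536) - 17 = log2(1.5) ~= 0.585,
                    // accumulated as ~0.585 * 2048 ~= 1198.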
                    accum_den_log += log2_32(log2_table, sigma_nsq + sigma1_sq) - 2048 * 17;

                    if (sigma12 > 0 && sigma2_sq > 0)
                    {
                        /**
                        * In floating point:
                        * num_val = log2f(1.0f + (g * g * sigma1_sq) / (sv_sq + sigma_nsq));
                        *
                        * In fixed point this is rewritten as
                        * num_val = log2((sv_sq + sigma_nsq) + (g * g * sigma1_sq)) - log2(sv_sq + sigma_nsq)
                        */
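                        // g is the regression slope sigma12 / sigma1_sq; sv_sq is
                        // the residual variance, clamped at zero; g is capped by
                        // the enhancement-gain limit.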

                        const double eps = 65536 * 1.0e-10;
                        double g = sigma12 / (sigma1_sq + eps); // this epsilon can go away
                        int32_t sv_sq = sigma2_sq - g * sigma12;

                        sv_sq = (uint32_t)(MAX(sv_sq, 0));

                        g = MIN(g, vif_enhn_gain_limit);

                        uint32_t numer1 = (sv_sq + sigma_nsq);
                        int64_t numer1_tmp = (int64_t)(g * g * sigma1_sq) + numer1; // numerator
                        accum_num_log += log2_64(log2_table, numer1_tmp) - log2_64(log2_table, numer1);
                    }
                }
                else {
                    accum_num_non_log += sigma2_sq;
                    accum_den_non_log++;
                }
            }
        }
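        // Finish any columns the vectorized loop did not cover.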

        if (j != w)
        {
            VifResiduals residuals =
                vif_compute_line_residuals(s, j, w, scale);
            accum_num_log += residuals.accum_num_log;
            accum_den_log += residuals.accum_den_log;
            accum_num_non_log += residuals.accum_num_non_log;
            accum_den_non_log += residuals.accum_den_non_log;
        }
    }
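
    // Combine both paths: the log-domain sums carry the 2048x table scale; the
    // non-log sums use 16384 (= 2^16 / 4) and 65025 (= 255^2), matching the
    // floating-point num_val = 1.0 - sigma2_sq * 4.0 / (255.0 * 255.0).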
    num[0] = accum_num_log / 2048.0 + (accum_den_non_log - ((accum_num_non_log) / 16384.0) / (65025.0));
    den[0] = accum_den_log / 2048.0 + accum_den_non_log;
}