int compute_vif()

in libvmaf/src/feature/vif.c [45:271]


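/*
 * Computes the per-frame float VIF score: four scales, each filtered with a
 * Gaussian kernel selected by vif_kernelscale and decimated by 2 between
 * scales. Each scale contributes a numerator/denominator pair to scores[],
 * and the pooled score is sum(num) / sum(den).
 */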
int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride, int dis_stride,
        double *score, double *score_num, double *score_den, double *scores,
        double vif_enhn_gain_limit, double vif_kernelscale)
{
    float *data_buf = 0;
    char *data_top;

    float *ref_scale;
    float *dis_scale;

    float *mu1;
    float *mu2;
    float *ref_sq_filt;
    float *dis_sq_filt;
    float *ref_dis_filt;
    float *tmpbuf;

    const float *filter;
    int filter_width;

    /* Offset pointers to adjust for convolution border handling. */
    float *mu1_adj = 0;
    float *mu2_adj = 0;

#ifdef VIF_OPT_DEBUG_DUMP
    float *ref_sq_filt_adj;
    float *dis_sq_filt_adj;
    float *ref_dis_filt_adj = 0;
#endif

    /* Special handling of first scale. */
    const float *curr_ref_scale = ref;
    const float *curr_dis_scale = dis;
    int curr_ref_stride = ref_stride;
    int curr_dis_stride = dis_stride;

    int buf_stride = ALIGN_CEIL(w * sizeof(float));
    size_t buf_sz_one = (size_t)buf_stride * h;

    int scale;
    int ret = 1;

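    /* Map the requested kernel scale onto the index used to select the
     * precomputed filter in vif_filter1d_table_s / vif_filter1d_width. */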
    int kernelscale_index = -1;
    if (ALMOST_EQUAL(vif_kernelscale, 1.0)) {
        kernelscale_index = vif_kernelscale_1;
    } else if (ALMOST_EQUAL(vif_kernelscale, 1.0/2)) {
        kernelscale_index = vif_kernelscale_1o2;
    } else if (ALMOST_EQUAL(vif_kernelscale, 3.0/2)) {
        kernelscale_index = vif_kernelscale_3o2;
    } else if (ALMOST_EQUAL(vif_kernelscale, 2.0)) {
        kernelscale_index = vif_kernelscale_2;
    } else if (ALMOST_EQUAL(vif_kernelscale, 2.0/3)) {
        kernelscale_index = vif_kernelscale_2o3;
    } else if (ALMOST_EQUAL(vif_kernelscale, 2.4/1.0)) {
        kernelscale_index = vif_kernelscale_24o10;
    } else if (ALMOST_EQUAL(vif_kernelscale, 360/97.0)) {
        kernelscale_index = vif_kernelscale_360o97;
    } else if (ALMOST_EQUAL(vif_kernelscale, 4.0/3.0)) {
        kernelscale_index = vif_kernelscale_4o3;
    } else if (ALMOST_EQUAL(vif_kernelscale, 3.5/3.0)) {
        kernelscale_index = vif_kernelscale_3d5o3;
    } else if (ALMOST_EQUAL(vif_kernelscale, 3.75/3.0)) {
        kernelscale_index = vif_kernelscale_3d75o3;
    } else if (ALMOST_EQUAL(vif_kernelscale, 4.25/3.0)) {
        kernelscale_index = vif_kernelscale_4d25o3;
    } else {
        printf("error: vif_kernelscale can only be 0.5, 1.0, 1.5, 2.0, 2.0/3, 2.4, 360/97, 4.0/3.0, 3.5/3.0, 3.75/3.0, 4.25/3.0 for now, but is %f\n", vif_kernelscale);
        fflush(stdout);
        goto fail_or_end;
    }

    // Optimized to avoid redundant buffer copies, reducing the number of
    // working buffers from 15 to 8.
#define VIF_BUF_CNT 8
    if (SIZE_MAX / buf_sz_one < VIF_BUF_CNT)
    {
        printf("error: SIZE_MAX / buf_sz_one < VIF_BUF_CNT, buf_sz_one = %zu.\n", buf_sz_one);
        fflush(stdout);
        goto fail_or_end;
    }

    if (!(data_buf = aligned_malloc(buf_sz_one * VIF_BUF_CNT, MAX_ALIGN)))
    {
        printf("error: aligned_malloc failed for data_buf.\n");
        fflush(stdout);
        goto fail_or_end;
    }

    data_top = (char *)data_buf;

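    /* Carve the single aligned allocation into eight equally sized planes. */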
    ref_scale = (float *)data_top; data_top += buf_sz_one;
    dis_scale = (float *)data_top; data_top += buf_sz_one;
    mu1 = (float *)data_top; data_top += buf_sz_one;
    mu2 = (float *)data_top; data_top += buf_sz_one;
    ref_sq_filt = (float *)data_top; data_top += buf_sz_one;
    dis_sq_filt = (float *)data_top; data_top += buf_sz_one;
    ref_dis_filt = (float *)data_top; data_top += buf_sz_one;
    tmpbuf = (float *)data_top; data_top += buf_sz_one;

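    /* Process the four VIF scales; scales 1..3 reuse the decimated output of
     * the previous scale as their input. */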
    for (scale = 0; scale < 4; ++scale)
    {
#ifdef VIF_OPT_DEBUG_DUMP
        char pathbuf[256];
#endif

        filter = vif_filter1d_table_s[kernelscale_index][scale];
        filter_width = vif_filter1d_width[kernelscale_index][scale];

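        /* Without VIF_OPT_HANDLE_BORDERS, the valid region shrinks by half
         * the filter width on each side, and ADJUST() skips past the invalid
         * border rows and columns of each filtered plane. */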
#ifdef VIF_OPT_HANDLE_BORDERS
        int buf_valid_w = w;
        int buf_valid_h = h;

  #define ADJUST(x) x
#else
        int filter_adj  = filter_width / 2;
        int buf_valid_w = w - filter_adj * 2;
        int buf_valid_h = h - filter_adj * 2;

  #define ADJUST(x) ((float *)((char *)(x) + filter_adj * buf_stride + filter_adj * sizeof(float)))
#endif

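        /* For scales 1..3: low-pass the previous scale's planes, then
         * decimate the valid region by 2 to form this scale's inputs. */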
        if (scale > 0)
        {
            vif_filter1d_s(filter, curr_ref_scale, mu1, tmpbuf, w, h, curr_ref_stride, buf_stride, filter_width);
            vif_filter1d_s(filter, curr_dis_scale, mu2, tmpbuf, w, h, curr_dis_stride, buf_stride, filter_width);

            mu1_adj = ADJUST(mu1);
            mu2_adj = ADJUST(mu2);

            vif_dec2_s(mu1_adj, ref_scale, buf_valid_w, buf_valid_h, buf_stride, buf_stride);
            vif_dec2_s(mu2_adj, dis_scale, buf_valid_w, buf_valid_h, buf_stride, buf_stride);

            w  = buf_valid_w / 2;
            h  = buf_valid_h / 2;
#ifdef VIF_OPT_HANDLE_BORDERS
            buf_valid_w = w;
            buf_valid_h = h;
#else
            buf_valid_w = w - filter_adj * 2;
            buf_valid_h = h - filter_adj * 2;
#endif
            curr_ref_scale = ref_scale;
            curr_dis_scale = dis_scale;

            curr_ref_stride = buf_stride;
            curr_dis_stride = buf_stride;
        }

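        /* Filtered means (mu1, mu2) of the current ref and dis planes. */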
        vif_filter1d_s(filter, curr_ref_scale, mu1, tmpbuf, w, h, curr_ref_stride, buf_stride, filter_width);
        vif_filter1d_s(filter, curr_dis_scale, mu2, tmpbuf, w, h, curr_dis_stride, buf_stride, filter_width);

        // Optimized with intrinsic implementations of vif_filter1d_sq
        // and vif_filter1d_xy.
        vif_filter1d_sq_s(filter, curr_ref_scale, ref_sq_filt, tmpbuf, w, h, curr_ref_stride, buf_stride, filter_width);
        vif_filter1d_sq_s(filter, curr_dis_scale, dis_sq_filt, tmpbuf, w, h, curr_dis_stride, buf_stride, filter_width);
        vif_filter1d_xy_s(filter, curr_ref_scale, curr_dis_scale, ref_dis_filt, tmpbuf, w, h, curr_ref_stride, curr_dis_stride, buf_stride, filter_width);

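        /* Accumulate this scale's VIF numerator and denominator from the
         * filtered means and second-order moments. */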
        float num, den;
        vif_statistic_s(mu1, mu2, ref_sq_filt, dis_sq_filt, ref_dis_filt, &num, &den,
            w, h, buf_stride, buf_stride, buf_stride, buf_stride, buf_stride, vif_enhn_gain_limit);
        mu1_adj = ADJUST(mu1);
        mu2_adj = ADJUST(mu2);

#ifdef VIF_OPT_DEBUG_DUMP
        ref_sq_filt_adj  = ADJUST(ref_sq_filt);
        dis_sq_filt_adj  = ADJUST(dis_sq_filt);
        ref_dis_filt_adj = ADJUST(ref_dis_filt);
#endif

#undef ADJUST

#ifdef VIF_OPT_DEBUG_DUMP
        sprintf(pathbuf, "stage/ref[%d].bin", scale);
        write_image(pathbuf, curr_ref_scale, w, h, curr_ref_stride, sizeof(float));

        sprintf(pathbuf, "stage/dis[%d].bin", scale);
        write_image(pathbuf, curr_dis_scale, w, h, curr_dis_stride, sizeof(float));

        sprintf(pathbuf, "stage/mu1[%d].bin", scale);
        write_image(pathbuf, mu1_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));

        sprintf(pathbuf, "stage/mu2[%d].bin", scale);
        write_image(pathbuf, mu2_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));

        sprintf(pathbuf, "stage/ref_sq_filt[%d].bin", scale);
        write_image(pathbuf, ref_sq_filt_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));

        sprintf(pathbuf, "stage/dis_sq_filt[%d].bin", scale);
        write_image(pathbuf, dis_sq_filt_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));

        sprintf(pathbuf, "stage/ref_dis_filt[%d].bin", scale);
        write_image(pathbuf, ref_dis_filt_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));
#endif

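        /* Record the per-scale numerator/denominator pair. */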
        scores[2*scale] = num;
        scores[2*scale+1] = den;

#ifdef VIF_OPT_DEBUG_DUMP
        printf("num[%d]: %e\n", scale, num);
        printf("den[%d]: %e\n", scale, den);
#endif
    }

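    /* Pool across the four scales: sum the numerators and denominators,
     * guarding against a zero denominator. */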
    *score_num = 0.0;
    *score_den = 0.0;
    for (scale = 0; scale < 4; ++scale)
    {
        *score_num += scores[2*scale];
        *score_den += scores[2*scale+1];
    }
    if (*score_den == 0.0)
    {
        *score = 1.0;
    }
    else
    {
        *score = (*score_num) / (*score_den);
    }

    ret = 0;
fail_or_end:
    aligned_free(data_buf);
    return ret;
}
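
A minimal calling sketch (not part of vif.c) follows. The frame dimensions, the 100.0 enhancement-gain limit, and the 1.0 kernel scale are example values; the caller is responsible for filling the two float planes with single-channel pixel data before the call.

int width = 1920, height = 1080;                 /* example dimensions */
int stride = ALIGN_CEIL(width * sizeof(float));  /* stride in bytes, as compute_vif() expects */

float *ref_buf = aligned_malloc((size_t)stride * height, MAX_ALIGN);
float *dis_buf = aligned_malloc((size_t)stride * height, MAX_ALIGN);
/* ... fill ref_buf / dis_buf with single-channel float pixel data ... */

double score, score_num, score_den;
double scores[8];                                /* {num, den} per scale, 4 scales */

int err = compute_vif(ref_buf, dis_buf, width, height, stride, stride,
                      &score, &score_num, &score_den, scores,
                      100.0 /* vif_enhn_gain_limit, example value */,
                      1.0   /* vif_kernelscale */);
if (err) {
    /* allocation failure or unsupported kernel scale */
}

aligned_free(ref_buf);
aligned_free(dis_buf);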