static void integer_compute_adm_cuda()

in libvmaf/src/feature/cuda/integer_adm_cuda.c [756:926]


/*
 * Enqueues the complete 4-scale integer ADM pipeline for one
 * reference/distorted picture pair.
 *
 * Scale 0 reads plane 0 of the input pictures directly (8-bit or high
 * bit-depth DWT path, selected by ref_pic->bpc) on each picture's own stream.
 * Events are recorded on those streams and s->str waits on them, so every
 * later kernel (launched on s->str) runs after the inputs have been consumed.
 * Scales 1-3 operate on the int32 approximation band (band_a) produced by the
 * previous scale; w/h are halved (rounded up) after every DWT.
 *
 * The accumulator buffer buf->tmp_res is zeroed on s->str before any kernel
 * runs. At the end it is copied asynchronously into buf->results_host on
 * s->str, and s->finished is recorded after that copy. Anything reading
 * buf->results_host therefore has to wait on s->finished first.
 *
 * fex                    - extractor whose CUDA context (fex->cu_state->ctx)
 *                          is made current around the event/stream calls
 * s                      - ADM CUDA state: work stream s->str, s->host_stream,
 *                          and the ref/dis/finished events
 * ref_pic, dis_pic       - device pictures; only plane 0 is used
 * buf                    - device intermediates and host result buffer
 * adm_enhn_gain_limit    - forwarded to the kernels via AdmFixedParametersCuda
 * adm_norm_view_dist     - viewing distance used for the quantisation steps
 * adm_ref_display_height - display height used for the quantisation steps
 */
static void integer_compute_adm_cuda(VmafFeatureExtractor *fex, AdmStateCuda *s,
        VmafPicture *ref_pic, VmafPicture *dis_pic, AdmBufferCuda *buf,
        double adm_enhn_gain_limit,
        double adm_norm_view_dist,
        int adm_ref_display_height)
{
    int w = ref_pic->w[0];
    int h = ref_pic->h[0];

    // Integer db2 wavelet taps shared by all DWT kernels, plus the viewing
    // parameters the kernels need.
    AdmFixedParametersCuda p = {
        .dwt2_db2_coeffs_lo = {15826, 27411, 7345, -4240},
        .dwt2_db2_coeffs_hi = {-4240, -7345, 27411, -15826},
        .dwt2_db2_coeffs_lo_sum = 46342,
        .dwt2_db2_coeffs_hi_sum = 0,
        .log2_w = log2(w),
        .log2_h = log2(h),
        .adm_ref_display_height = adm_ref_display_height,
        .adm_norm_view_dist = adm_norm_view_dist,
        .adm_enhn_gain_limit = adm_enhn_gain_limit,
    };

    // Per-scale quantisation steps and their reciprocals. Index 3*scale+0/1
    // uses orientation-1 steps and 3*scale+2 uses orientation-2 steps.
    // i_rfactor holds fixed-point reciprocals: scale 0 uses 2^21 / 2^23
    // scaling (or hard-coded values for the default viewing condition);
    // scales 1-3 use 2^32 scaling.
    const double pow2_32 = pow(2, 32);
    const double pow2_21 = pow(2, 21);
    const double pow2_23 = pow(2, 23);
    for (unsigned scale = 0; scale < 4; ++scale) {
        float factor1 = dwt_quant_step(&dwt_7_9_YCbCr_threshold[0], scale, 1,
                adm_norm_view_dist, adm_ref_display_height);
        float factor2 = dwt_quant_step(&dwt_7_9_YCbCr_threshold[0], scale, 2,
                adm_norm_view_dist, adm_ref_display_height);
        p.factor1[scale] = factor1;
        p.factor2[scale] = factor2;
        p.rfactor[scale*3] = 1.0f / factor1;
        p.rfactor[scale*3+1] = 1.0f / factor1;
        p.rfactor[scale*3+2] = 1.0f / factor2;
        if (scale == 0) {
            if (fabs(p.adm_norm_view_dist * p.adm_ref_display_height -
                        DEFAULT_ADM_NORM_VIEW_DIST * DEFAULT_ADM_REF_DISPLAY_HEIGHT) <
                    1.0e-8) {
                p.i_rfactor[scale * 3] = 36453;
                p.i_rfactor[scale * 3 + 1] = 36453;
                p.i_rfactor[scale * 3 + 2] = 49417;
            } else {
                p.i_rfactor[scale * 3] = (uint32_t)(p.rfactor[scale * 3] * pow2_21);
                p.i_rfactor[scale * 3 + 1] =
                    (uint32_t)(p.rfactor[scale * 3 + 1] * pow2_21);
                p.i_rfactor[scale * 3 + 2] =
                    (uint32_t)(p.rfactor[scale * 3 + 2] * pow2_23);
            }
        } else {
            p.i_rfactor[scale * 3] = (uint32_t)(p.rfactor[scale * 3] * pow2_32);
            p.i_rfactor[scale * 3 + 1] =
                (uint32_t)(p.rfactor[scale * 3 + 1] * pow2_32);
            p.i_rfactor[scale * 3 + 2] =
                (uint32_t)(p.rfactor[scale * 3 + 2] * pow2_32);
        }
    }

    // Reset the device-side result accumulators before any kernel writes them.
    CHECK_CUDA(cuMemsetD8Async(buf->tmp_res->data, 0, sizeof(int64_t) * RES_BUFFER_SIZE, s->str));

    size_t curr_ref_stride;
    size_t curr_dis_stride;
    // Stride of the intermediate buffers; ind_size_x >> 2 presumably converts
    // a byte width into int32 elements -- TODO confirm against allocation.
    size_t buf_stride = buf->ind_size_x >> 2;

    int32_t *i4_curr_ref_scale = NULL;
    int32_t *i4_curr_dis_scale = NULL;

    // Input strides in elements: bytes for 8-bit input, uint16_t samples
    // otherwise. Each picture must be walked with its own stride.
    if (ref_pic->bpc == 8) {
        curr_ref_stride = ref_pic->stride[0];
        curr_dis_stride = dis_pic->stride[0];
    } else {
        curr_ref_stride = ref_pic->stride[0] >> 1;
        curr_dis_stride = dis_pic->stride[0] >> 1;
    }

    for (unsigned scale = 0; scale < 4; ++scale) {
        if (scale == 0) {
            // Run the first DWT kernels on the input pictures' own streams so
            // the inputs are consumed before s->str continues.
            // consumes: reference/distorted pictures
            // produces: buf->ref_dwt2, buf->dis_dwt2
            if (ref_pic->bpc == 8) {
                dwt2_8_device(s, (const uint8_t*)ref_pic->data[0], &buf->ref_dwt2, buf->i4_ref_dwt2, (short2*)buf->tmp_ref->data, buf, w, h, curr_ref_stride, buf_stride, &p, vmaf_cuda_picture_get_stream(ref_pic));

                dwt2_8_device(s, (const uint8_t*)dis_pic->data[0], &buf->dis_dwt2, buf->i4_dis_dwt2, (short2*)buf->tmp_dis->data, buf, w, h, curr_dis_stride, buf_stride, &p, vmaf_cuda_picture_get_stream(dis_pic));
            } else {
                adm_dwt2_16_device(s, (uint16_t*)ref_pic->data[0], &buf->ref_dwt2, buf->i4_ref_dwt2, (short2*)buf->tmp_ref->data, buf, w, h, curr_ref_stride, buf_stride, ref_pic->bpc, &p, vmaf_cuda_picture_get_stream(ref_pic));

                adm_dwt2_16_device(s, (uint16_t*)dis_pic->data[0], &buf->dis_dwt2, buf->i4_dis_dwt2, (short2*)buf->tmp_dis->data, buf, w, h, curr_dis_stride, buf_stride, dis_pic->bpc, &p, vmaf_cuda_picture_get_stream(dis_pic));
            }
            CHECK_CUDA(cuEventRecord(s->ref_event, vmaf_cuda_picture_get_stream(ref_pic)));
            CHECK_CUDA(cuEventRecord(s->dis_event, vmaf_cuda_picture_get_stream(dis_pic)));

            w = (w + 1) / 2;
            h = (h + 1) / 2;

            // Make s->str wait until both input pictures have been consumed;
            // the events are then recreated for the next call.
            CHECK_CUDA(cuCtxPushCurrent(fex->cu_state->ctx));

            CHECK_CUDA(cuStreamWaitEvent(s->str, s->dis_event, CU_EVENT_WAIT_DEFAULT));
            CHECK_CUDA(cuEventDestroy(s->dis_event));
            CHECK_CUDA(cuEventCreate(&s->dis_event, CU_EVENT_DEFAULT));

            CHECK_CUDA(cuStreamWaitEvent(s->str, s->ref_event, CU_EVENT_WAIT_DEFAULT));
            CHECK_CUDA(cuEventDestroy(s->ref_event));
            CHECK_CUDA(cuEventCreate(&s->ref_event, CU_EVENT_DEFAULT));

            CHECK_CUDA(cuCtxPopCurrent(NULL));

            // consumes: buf->ref_dwt2, buf->dis_dwt2
            // produces: buf->decouple_r, buf->decouple_a
            adm_decouple_device(s, buf, w, h, buf_stride, &p, s->str);

            // consumes: buf->ref_dwt2
            // produces: buf->adm_csf_den[0]
            adm_csf_den_scale_device(s, buf, w, h, buf_stride,
                    adm_norm_view_dist, adm_ref_display_height, s->str);

            // consumes: buf->decouple_a
            // produces: buf->csf_a, buf->csf_f
            adm_csf_device(s, buf, w, h, buf_stride, &p, s->str);

            // consumes: buf->decouple_r, buf->csf_a, buf->csf_f
            // produces: buf->adm_cm[0]
            adm_cm_device(s, buf, w, h, buf_stride, buf_stride, &p, s->str);
        } else {
            // consumes: buf->i4_ref_dwt2.band_a, buf->i4_dis_dwt2.band_a
            // produces: buf->i4_ref_dwt2.band_[ahvd], buf->i4_dis_dwt2.band_[ahvd]
            // uses:     buf->tmp_ref, buf->tmp_dis as scratch
            adm_dwt2_s123_combined_device(s, i4_curr_ref_scale, (int32_t*)buf->tmp_ref->data, buf->i4_ref_dwt2, buf, w, h,
                    curr_ref_stride, buf_stride, scale, &p, s->str);
            adm_dwt2_s123_combined_device(s, i4_curr_dis_scale, (int32_t*)buf->tmp_dis->data, buf->i4_dis_dwt2, buf, w, h,
                    curr_dis_stride, buf_stride, scale, &p, s->str);

            w = (w + 1) / 2;
            h = (h + 1) / 2;

            // consumes: buf->i4_ref_dwt2, buf->i4_dis_dwt2
            // produces: buf->i4_decouple_r, buf->i4_decouple_a
            adm_decouple_s123_device(s, buf, w, h, buf_stride, &p, s->str);

            // consumes: buf->i4_ref_dwt2
            // produces: buf->adm_csf_den[1,2,3]
            adm_csf_den_s123_device(
                    s, buf, scale, w, h, buf_stride,
                    adm_norm_view_dist, adm_ref_display_height, s->str);

            // consumes: buf->i4_decouple_a
            // produces: buf->i4_csf_a, buf->i4_csf_f
            i4_adm_csf_device(s, buf, scale, w, h, buf_stride, &p, s->str);

            // consumes: buf->i4_decouple_r, buf->i4_csf_a, buf->i4_csf_f
            // produces: buf->adm_cm[1,2,3]
            i4_adm_cm_device(s, buf, w, h, buf_stride, buf_stride, scale, &p, s->str);
        }

        // The next scale decomposes this scale's approximation band, which
        // lives in the intermediate buffers with their own stride.
        i4_curr_ref_scale = buf->i4_ref_dwt2.band_a;
        i4_curr_dis_scale = buf->i4_dis_dwt2.band_a;

        curr_ref_stride = buf_stride;
        curr_dis_stride = buf_stride;
    }

    // Wait for s->host_stream before overwriting buf->results_host; presumably
    // this ensures the previous results have been consumed -- TODO confirm.
    CHECK_CUDA(cuStreamSynchronize(s->host_stream));
    CHECK_CUDA(cuMemcpyDtoHAsync(buf->results_host, buf->tmp_res->data, sizeof(int64_t) * RES_BUFFER_SIZE, s->str));
    CHECK_CUDA(cuEventRecord(s->finished, s->str));
}