in libvmaf/src/feature/cuda/integer_adm_cuda.c [756:926]
/**
 * Enqueue the complete 4-scale integer ADM computation for one frame pair.
 *
 * The function fills an AdmFixedParametersCuda block on the host, clears the
 * device result accumulator, launches the per-scale kernel wrappers
 * (DWT -> decouple -> CSF denominator -> CSF -> CM) and finally queues an
 * asynchronous device-to-host copy of the accumulated results.
 *
 * @param fex                     feature extractor; only fex->cu_state->ctx
 *                                is read here (pushed around the event calls).
 * @param s                       ADM CUDA state: streams (str, host_stream)
 *                                and events (ref_event, dis_event, finished).
 * @param ref_pic                 reference picture; plane 0 is processed.
 * @param dis_pic                 distorted picture; plane 0 is processed.
 * @param buf                     working buffers for all intermediate bands
 *                                and the result accumulator.
 * @param adm_enhn_gain_limit     stored in the parameter block for the kernels.
 * @param adm_norm_view_dist      viewing distance used for the quantization
 *                                step computation.
 * @param adm_ref_display_height  display height used for the quantization
 *                                step computation.
 *
 * On return the work is only enqueued: s->finished is recorded on s->str
 * after the copy into buf->results_host, so callers have to wait on that
 * event before reading the results.
 * NOTE(review): error behaviour is whatever CHECK_CUDA does on failure; the
 * macro is defined outside this block.
 */
static void integer_compute_adm_cuda(VmafFeatureExtractor *fex, AdmStateCuda *s,
        VmafPicture *ref_pic, VmafPicture *dis_pic, AdmBufferCuda *buf,
        double adm_enhn_gain_limit,
        double adm_norm_view_dist,
        int adm_ref_display_height)
{
    int w = ref_pic->w[0];
    int h = ref_pic->h[0];

    AdmFixedParametersCuda p = {
        .dwt2_db2_coeffs_lo = {15826, 27411, 7345, -4240},
        .dwt2_db2_coeffs_hi = {-4240, -7345, 27411, -15826},
        .dwt2_db2_coeffs_lo_sum = 46342,
        .dwt2_db2_coeffs_hi_sum = 0,
        .log2_w = log2(w),
        .log2_h = log2(h),
        .adm_ref_display_height = adm_ref_display_height,
        .adm_norm_view_dist = adm_norm_view_dist,
        .adm_enhn_gain_limit = adm_enhn_gain_limit,
    };

    const double pow2_32 = pow(2, 32);
    const double pow2_21 = pow(2, 21);
    const double pow2_23 = pow(2, 23);

    // Per-scale quantization steps and their reciprocals. Each scale owns
    // three consecutive rfactor / i_rfactor slots: the first two are derived
    // from factor1, the third from factor2.
    for (unsigned scale = 0; scale < 4; ++scale) {
        float factor1 = dwt_quant_step(&dwt_7_9_YCbCr_threshold[0], scale, 1,
                                       adm_norm_view_dist, adm_ref_display_height);
        float factor2 = dwt_quant_step(&dwt_7_9_YCbCr_threshold[0], scale, 2,
                                       adm_norm_view_dist, adm_ref_display_height);
        p.factor1[scale] = factor1;
        p.factor2[scale] = factor2;
        p.rfactor[scale * 3] = 1.0f / factor1;
        p.rfactor[scale * 3 + 1] = 1.0f / factor1;
        p.rfactor[scale * 3 + 2] = 1.0f / factor2;

        if (scale == 0) {
            // Scale 0 uses a 2^21 / 2^21 / 2^23 fixed-point scaling. When the
            // viewing geometry equals the default one, hard-coded integer
            // reciprocals are used instead of the computed ones.
            if (fabs(p.adm_norm_view_dist * p.adm_ref_display_height -
                     DEFAULT_ADM_NORM_VIEW_DIST * DEFAULT_ADM_REF_DISPLAY_HEIGHT) <
                1.0e-8) {
                p.i_rfactor[scale * 3] = 36453;
                p.i_rfactor[scale * 3 + 1] = 36453;
                p.i_rfactor[scale * 3 + 2] = 49417;
            } else {
                p.i_rfactor[scale * 3] =
                    (uint32_t)(p.rfactor[scale * 3] * pow2_21);
                p.i_rfactor[scale * 3 + 1] =
                    (uint32_t)(p.rfactor[scale * 3 + 1] * pow2_21);
                p.i_rfactor[scale * 3 + 2] =
                    (uint32_t)(p.rfactor[scale * 3 + 2] * pow2_23);
            }
        } else {
            // Scales 1..3 use a uniform 2^32 fixed-point scaling.
            p.i_rfactor[scale * 3] =
                (uint32_t)(p.rfactor[scale * 3] * pow2_32);
            p.i_rfactor[scale * 3 + 1] =
                (uint32_t)(p.rfactor[scale * 3 + 1] * pow2_32);
            p.i_rfactor[scale * 3 + 2] =
                (uint32_t)(p.rfactor[scale * 3 + 2] * pow2_32);
        }
    }

    // Zero the device-side result accumulator before any kernel adds to it.
    CHECK_CUDA(cuMemsetD8Async(buf->tmp_res->data, 0,
                               sizeof(int64_t) * RES_BUFFER_SIZE, s->str));

    size_t curr_ref_stride;
    size_t curr_dis_stride;
    // presumably converts a byte pitch into a count of 32-bit elements
    size_t buf_stride = buf->ind_size_x >> 2;

    int32_t *i4_curr_ref_scale = NULL;
    int32_t *i4_curr_dis_scale = NULL;

    // Each picture must be addressed with its OWN stride. For more than
    // 8 bits per component the stride is halved (presumably bytes ->
    // 16-bit samples).
    if (ref_pic->bpc == 8) {
        curr_ref_stride = ref_pic->stride[0];
        curr_dis_stride = dis_pic->stride[0];
    } else {
        curr_ref_stride = ref_pic->stride[0] >> 1;
        curr_dis_stride = dis_pic->stride[0] >> 1;
    }

    for (unsigned scale = 0; scale < 4; ++scale) {
        if (scale == 0) {
            // Run the first DWT kernels on the input image streams so the
            // input pictures are consumed there; everything afterwards
            // continues on s->str.
            // consumes ref_pic, dis_pic
            // produces buf->ref_dwt2, buf->dis_dwt2
            if (ref_pic->bpc == 8) {
                dwt2_8_device(s, (const uint8_t*)ref_pic->data[0],
                              &buf->ref_dwt2, buf->i4_ref_dwt2,
                              (short2*)buf->tmp_ref->data, buf, w, h,
                              curr_ref_stride, buf_stride, &p,
                              vmaf_cuda_picture_get_stream(ref_pic));
                dwt2_8_device(s, (const uint8_t*)dis_pic->data[0],
                              &buf->dis_dwt2, buf->i4_dis_dwt2,
                              (short2*)buf->tmp_dis->data, buf, w, h,
                              curr_dis_stride, buf_stride, &p,
                              vmaf_cuda_picture_get_stream(dis_pic));
            } else {
                adm_dwt2_16_device(s, (uint16_t*)ref_pic->data[0],
                                   &buf->ref_dwt2, buf->i4_ref_dwt2,
                                   (short2*)buf->tmp_ref->data, buf, w, h,
                                   curr_ref_stride, buf_stride, ref_pic->bpc,
                                   &p, vmaf_cuda_picture_get_stream(ref_pic));
                adm_dwt2_16_device(s, (uint16_t*)dis_pic->data[0],
                                   &buf->dis_dwt2, buf->i4_dis_dwt2,
                                   (short2*)buf->tmp_dis->data, buf, w, h,
                                   curr_dis_stride, buf_stride, dis_pic->bpc,
                                   &p, vmaf_cuda_picture_get_stream(dis_pic));
            }

            // Mark the point on each picture stream where the input has been
            // consumed by the first DWT.
            CHECK_CUDA(cuEventRecord(s->ref_event, vmaf_cuda_picture_get_stream(ref_pic)));
            CHECK_CUDA(cuEventRecord(s->dis_event, vmaf_cuda_picture_get_stream(dis_pic)));

            w = (w + 1) / 2;
            h = (h + 1) / 2;

            // Make s->str wait for both picture streams before the remaining
            // kernels run. The events are destroyed and recreated after each
            // wait.
            // NOTE(review): the reason for recreating instead of reusing the
            // events is not visible from this function.
            CHECK_CUDA(cuCtxPushCurrent(fex->cu_state->ctx));
            CHECK_CUDA(cuStreamWaitEvent(s->str, s->dis_event, CU_EVENT_WAIT_DEFAULT));
            CHECK_CUDA(cuEventDestroy(s->dis_event));
            CHECK_CUDA(cuEventCreate(&s->dis_event, CU_EVENT_DEFAULT));
            CHECK_CUDA(cuStreamWaitEvent(s->str, s->ref_event, CU_EVENT_WAIT_DEFAULT));
            CHECK_CUDA(cuEventDestroy(s->ref_event));
            CHECK_CUDA(cuEventCreate(&s->ref_event, CU_EVENT_DEFAULT));
            CHECK_CUDA(cuCtxPopCurrent(NULL));

            // consumes buf->ref_dwt2 , buf->dis_dwt2
            // produces buf->decouple_r , buf->decouple_a
            adm_decouple_device(s, buf, w, h, buf_stride, &p, s->str);

            // consumes buf->ref_dwt2
            // produces buf->adm_csf_den[0]
            adm_csf_den_scale_device(s, buf, w, h, buf_stride,
                                     adm_norm_view_dist, adm_ref_display_height, s->str);

            // consumes buf->decouple_a
            // produces buf->csf_a , buf->csf_f
            adm_csf_device(s, buf, w, h, buf_stride, &p, s->str);

            // consumes buf->decouple_r, buf->csf_a, buf->csf_f
            // produces buf->adm_cm[0]
            adm_cm_device(s, buf, w, h, buf_stride, buf_stride, &p, s->str);
        } else {
            // consumes buf->i4_ref_dwt2.band_a , buf->i4_dis_dwt2.band_a
            // produces buf->i4_ref_dwt2.band_[ahvd] , buf->i4_dis_dwt2.band_[ahvd]
            // uses buf->tmp_ref , buf->tmp_dis
            adm_dwt2_s123_combined_device(s, i4_curr_ref_scale,
                                          (int32_t*)buf->tmp_ref->data,
                                          buf->i4_ref_dwt2, buf, w, h,
                                          curr_ref_stride, buf_stride, scale,
                                          &p, s->str);
            adm_dwt2_s123_combined_device(s, i4_curr_dis_scale,
                                          (int32_t*)buf->tmp_dis->data,
                                          buf->i4_dis_dwt2, buf, w, h,
                                          curr_dis_stride, buf_stride, scale,
                                          &p, s->str);

            w = (w + 1) / 2;
            h = (h + 1) / 2;

            // consumes buf->i4_ref_dwt2 , buf->i4_dis_dwt2
            // produces buf->i4_decouple_r , buf->i4_decouple_a
            adm_decouple_s123_device(s, buf, w, h, buf_stride, &p, s->str);

            // consumes buf->i4_ref_dwt2
            // produces buf->adm_csf_den[1,2,3]
            adm_csf_den_s123_device(
                    s, buf, scale, w, h, buf_stride,
                    adm_norm_view_dist, adm_ref_display_height, s->str);

            // consumes buf->i4_decouple_a
            // produces buf->i4_csf_a , buf->i4_csf_f
            i4_adm_csf_device(s, buf, scale, w, h, buf_stride, &p, s->str);

            // consumes buf->i4_decouple_r, buf->i4_csf_a, buf->i4_csf_f
            // produces buf->adm_cm[1,2,3]
            i4_adm_cm_device(s, buf, w, h, buf_stride, buf_stride, scale, &p, s->str);
        }

        // The approximation band of this scale is the input of the next one,
        // and from now on both inputs live in buffers with buf_stride.
        i4_curr_ref_scale = buf->i4_ref_dwt2.band_a;
        i4_curr_dis_scale = buf->i4_dis_dwt2.band_a;
        curr_ref_stride = buf_stride;
        curr_dis_stride = buf_stride;
    }

    // Block until s->host_stream is idle before queuing the copy that
    // overwrites buf->results_host.
    // NOTE(review): presumably protects a previous frame's results that are
    // still being read via host_stream - confirm against the caller.
    CHECK_CUDA(cuStreamSynchronize(s->host_stream));
    CHECK_CUDA(cuMemcpyDtoHAsync(buf->results_host, buf->tmp_res->data,
                                 sizeof(int64_t) * RES_BUFFER_SIZE, s->str));
    // Signals that the results for this frame have been copied to the host.
    CHECK_CUDA(cuEventRecord(s->finished, s->str));
}