in libvmaf/src/feature/cuda/integer_adm_cuda.c [1006:1144]
/**
 * Initialise the CUDA state of the ADM feature extractor.
 *
 * Steps, in order:
 *   1. With the extractor's CUDA context pushed: create two non-blocking
 *      streams and three events, load the five ADM PTX modules and resolve
 *      every kernel entry point into the private state.
 *   2. Compute strides/sizes from the frame dimensions and allocate the
 *      device buffers plus the host result buffer.
 *   3. Carve the result buffer and the big data buffer into the individual
 *      sub-buffers via the init_* helpers.
 *   4. Allocate the write-score parameter block and build the feature name
 *      dictionary.
 *
 * @param fex      feature extractor; fex->priv is the AdmStateCuda to fill.
 * @param pix_fmt  unused.
 * @param bpc      unused.
 * @param w        frame width in pixels.
 * @param h        frame height in pixels.
 *
 * @return 0 on success, -ENOMEM if any allocation / lookup in steps 2-4
 *         fails (everything allocated so far in those steps is released).
 *         Failures of driver-API calls in step 1 are handled by CHECK_CUDA.
 *
 * NOTE(review): the cleanup path relies on the s->buf pointers being NULL
 * when not yet allocated, i.e. presumably fex->priv is zero-initialised by
 * the caller - confirm.
 * NOTE(review): streams, events and modules created in step 1 are not
 * released on the failure path; verify whether the close callback is
 * invoked after a failed init.
 */
static int init_fex_cuda(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
                         unsigned bpc, unsigned w, unsigned h)
{
    AdmStateCuda *s = fex->priv;
    (void) pix_fmt;
    (void) bpc;
    int ret = 0;

    CHECK_CUDA(cuCtxPushCurrent(fex->cu_state->ctx));
    CHECK_CUDA(cuStreamCreateWithPriority(&s->str, CU_STREAM_NON_BLOCKING, 0));
    CHECK_CUDA(cuStreamCreateWithPriority(&s->host_stream, CU_STREAM_NON_BLOCKING, 0));
    CHECK_CUDA(cuEventCreate(&s->finished, CU_EVENT_DEFAULT));
    CHECK_CUDA(cuEventCreate(&s->ref_event, CU_EVENT_DEFAULT));
    CHECK_CUDA(cuEventCreate(&s->dis_event, CU_EVENT_DEFAULT));

    CUmodule adm_cm_module, adm_csf_den_module, adm_csf_module, adm_decouple_module, adm_dwt_module;
    CHECK_CUDA(cuModuleLoadData(&adm_dwt_module, adm_dwt2_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_csf_module, adm_csf_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_decouple_module, adm_decouple_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_csf_den_module, adm_csf_den_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_cm_module, adm_cm_ptx));

    // Get DWT kernel function pointers; check adm_dwt2.cu for the __global__ templated kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_vert_kernel_0_0_int32_t, adm_dwt_module, "dwt_s123_combined_vert_kernel_0_0_int32_t"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_vert_kernel_32768_16_int32_t, adm_dwt_module, "dwt_s123_combined_vert_kernel_32768_16_int32_t"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_hori_kernel_16384_15, adm_dwt_module, "dwt_s123_combined_hori_kernel_16384_15"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_hori_kernel_32768_16, adm_dwt_module, "dwt_s123_combined_hori_kernel_32768_16"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint8_t, adm_dwt_module, "adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint8_t"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint16_t, adm_dwt_module, "adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint16_t"));

    // Get CSF kernel function pointers; check adm_csf.cu for the __global__ templated kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_csf_kernel_1_4, adm_csf_module, "adm_csf_kernel_1_4"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_i4_adm_csf_kernel_1_4, adm_csf_module, "i4_adm_csf_kernel_1_4"));

    // Decouple kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_decouple_kernel, adm_decouple_module, "adm_decouple_kernel"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_decouple_s123_kernel, adm_decouple_module, "adm_decouple_s123_kernel"));

    // CSF denominator kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_csf_den_scale_line_kernel, adm_csf_den_module, "adm_csf_den_scale_line_kernel_8_128"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_csf_den_s123_line_kernel, adm_csf_den_module, "adm_csf_den_s123_line_kernel_8_128"));

    // CM kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_cm_reduce_line_kernel_4, adm_cm_module, "adm_cm_reduce_line_kernel_4"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_cm_line_kernel_8, adm_cm_module, "adm_cm_line_kernel_8"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_i4_adm_cm_line_kernel, adm_cm_module, "i4_adm_cm_line_kernel"));
    CHECK_CUDA(cuCtxPopCurrent(NULL));

    // Strides / index sizes derived from the full and half-resolution frame.
    s->integer_stride = ALIGN_CEIL(w * sizeof(int32_t));
    s->buf.ind_size_x = ALIGN_CEIL(((w + 1) / 2) * sizeof(int32_t));
    s->buf.ind_size_y = ALIGN_CEIL(((h + 1) / 2) * sizeof(int32_t));
    // Size of one half-resolution int32 plane (aligned row stride times rows).
    size_t buf_sz_one = s->buf.ind_size_x * ((h + 1) / 2);

    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.data_buf, buf_sz_one * NUM_BUFS_ADM);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_ref, (s->integer_stride * 4 * ((h + 1) / 2)));
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_dis, (s->integer_stride * 4 * ((h + 1) / 2)));
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_accum, sizeof(uint64_t) * 3 * w * h);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_accum_h, sizeof(uint64_t) * 3 * h);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_res, sizeof(uint64_t) * RES_BUFFER_SIZE);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_host_alloc(fex->cu_state, &s->buf.results_host, sizeof(uint64_t) * RES_BUFFER_SIZE);
    if (ret) goto free_ref;

    // Hand out slices of the result buffer; each helper returns the next free address.
    CUdeviceptr cu_res_top;
    ret = vmaf_cuda_buffer_get_dptr(s->buf.tmp_res, &cu_res_top);
    if (ret) goto free_ref;
    cu_res_top = init_res_cm_cuda(fex->cu_state, s->buf.adm_cm, cu_res_top);
    cu_res_top = init_res_csf_cuda(fex->cu_state, s->buf.adm_csf_den, cu_res_top);

    // Hand out slices of the data buffer the same way. The lookup result is
    // checked so that cu_data_top is never used uninitialised.
    CUdeviceptr cu_data_top;
    ret = vmaf_cuda_buffer_get_dptr(s->buf.data_buf, &cu_data_top);
    if (ret) goto free_ref;
    // Non-i4 bands use half of buf_sz_one per band (presumably 16-bit
    // coefficients - confirm against the helper definitions).
    cu_data_top = init_dwt_band_cuda(fex->cu_state, &s->buf.ref_dwt2, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_cuda(fex->cu_state, &s->buf.dis_dwt2, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.decouple_r, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.decouple_a, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.csf_a, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.csf_f, cu_data_top, buf_sz_one / 2);
    // i4 bands use the full buf_sz_one per band.
    cu_data_top = i4_init_dwt_band_cuda(fex->cu_state, &s->buf.i4_ref_dwt2, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_cuda(fex->cu_state, &s->buf.i4_dis_dwt2, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_decouple_r, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_decouple_a, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_csf_a, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_csf_f, cu_data_top, buf_sz_one);

    // Parameter block that carries a back-pointer to this state.
    s->write_score_parameters = malloc(sizeof(write_score_parameters_adm));
    if (!s->write_score_parameters) goto free_ref;
    ((write_score_parameters_adm*)s->write_score_parameters)->s = s;

    s->feature_name_dict =
        vmaf_feature_name_dict_from_provided_features(fex->provided_features,
                                                      fex->options, s);
    if (!s->feature_name_dict) goto free_ref;

    return 0;

free_ref:
    // Release everything allocated above. Pointers are reset to NULL after
    // being freed so that a later teardown cannot free them a second time.
    if (s->buf.data_buf) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.data_buf);
        free(s->buf.data_buf);
        s->buf.data_buf = NULL;
    }
    if (s->buf.tmp_ref) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_ref);
        free(s->buf.tmp_ref);
        s->buf.tmp_ref = NULL;
    }
    if (s->buf.tmp_dis) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_dis);
        free(s->buf.tmp_dis);
        s->buf.tmp_dis = NULL;
    }
    if (s->buf.tmp_accum) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_accum);
        free(s->buf.tmp_accum);
        s->buf.tmp_accum = NULL;
    }
    if (s->buf.tmp_accum_h) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_accum_h);
        free(s->buf.tmp_accum_h);
        s->buf.tmp_accum_h = NULL;
    }
    if (s->buf.tmp_res) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_res);
        free(s->buf.tmp_res);
        s->buf.tmp_res = NULL;
    }
    if (s->buf.results_host) {
        ret |= vmaf_cuda_buffer_host_free(fex->cu_state, s->buf.results_host);
    }
    // free(NULL) is a no-op, so this is safe whether or not malloc was reached
    // (given a zero-initialised state, see NOTE above).
    free(s->write_score_parameters);
    s->write_score_parameters = NULL;
    vmaf_dictionary_free(&s->feature_name_dict);

    // The individual cleanup codes accumulated in ret are not propagated;
    // callers always see -ENOMEM for a failed init.
    return -ENOMEM;
}