static int init_fex_cuda()

in libvmaf/src/feature/cuda/integer_adm_cuda.c [1006:1144]


/**
 * Initialise the CUDA state of the ADM feature extractor.
 *
 * In order, this function:
 *   1. pushes the extractor's CUDA context, creates two non-blocking streams
 *      and three events, loads the five ADM PTX modules and resolves every
 *      kernel handle stored in the private state, then pops the context;
 *   2. derives strides/sizes from the frame dimensions and allocates the
 *      device buffers plus one host result buffer;
 *   3. carves the result and data buffers into the per-stage sub-buffers
 *      via the init_*_cuda helpers;
 *   4. allocates the score-writer parameter block and builds the feature
 *      name dictionary.
 *
 * @param fex      feature extractor whose `priv` holds the AdmStateCuda.
 * @param pix_fmt  unused.
 * @param bpc      unused.
 * @param w        frame width in pixels.
 * @param h        frame height in pixels.
 *
 * @return 0 on success, -ENOMEM if any allocation or setup step after the
 *         kernel lookup fails. Behaviour on a failing CUDA driver call is
 *         determined by the CHECK_CUDA macro (defined elsewhere).
 *
 * NOTE(review): the error path assumes the private state was zero-initialised
 * before this call (unallocated buffer pointers must read as NULL) -- the
 * original code relied on the same assumption; confirm against the caller.
 * NOTE(review): streams and events created here are not destroyed on the
 * error path; confirm whether the close callback is expected to handle that.
 */
static int init_fex_cuda(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
        unsigned bpc, unsigned w, unsigned h)
{
    AdmStateCuda *s = fex->priv;

    (void) pix_fmt;
    (void) bpc;
    int ret = 0;

    CHECK_CUDA(cuCtxPushCurrent(fex->cu_state->ctx));
    CHECK_CUDA(cuStreamCreateWithPriority(&s->str, CU_STREAM_NON_BLOCKING, 0));
    CHECK_CUDA(cuStreamCreateWithPriority(&s->host_stream, CU_STREAM_NON_BLOCKING, 0));
    CHECK_CUDA(cuEventCreate(&s->finished, CU_EVENT_DEFAULT));
    CHECK_CUDA(cuEventCreate(&s->ref_event, CU_EVENT_DEFAULT));
    CHECK_CUDA(cuEventCreate(&s->dis_event, CU_EVENT_DEFAULT));


    CUmodule adm_cm_module, adm_csf_den_module, adm_csf_module, adm_decouple_module, adm_dwt_module;


    CHECK_CUDA(cuModuleLoadData(&adm_dwt_module, adm_dwt2_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_csf_module, adm_csf_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_decouple_module, adm_decouple_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_csf_den_module, adm_csf_den_ptx));
    CHECK_CUDA(cuModuleLoadData(&adm_cm_module, adm_cm_ptx));

    // Get DWT kernel function pointers; check adm_dwt2.cu for __global__ templated kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_vert_kernel_0_0_int32_t,  adm_dwt_module, "dwt_s123_combined_vert_kernel_0_0_int32_t"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_vert_kernel_32768_16_int32_t, adm_dwt_module, "dwt_s123_combined_vert_kernel_32768_16_int32_t"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_hori_kernel_16384_15, adm_dwt_module, "dwt_s123_combined_hori_kernel_16384_15"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_dwt_s123_combined_hori_kernel_32768_16, adm_dwt_module, "dwt_s123_combined_hori_kernel_32768_16"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint8_t, adm_dwt_module, "adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint8_t"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint16_t, adm_dwt_module, "adm_dwt2_8_vert_hori_kernel_4_16_32768_128_8_uint16_t"));


    // Get csf kernel function pointers; check adm_csf.cu for __global__ templated kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_csf_kernel_1_4, adm_csf_module, "adm_csf_kernel_1_4"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_i4_adm_csf_kernel_1_4, adm_csf_module, "i4_adm_csf_kernel_1_4"));

    // Decouple kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_decouple_kernel, adm_decouple_module, "adm_decouple_kernel"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_decouple_s123_kernel, adm_decouple_module, "adm_decouple_s123_kernel"));

    // CSF denominator kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_csf_den_scale_line_kernel, adm_csf_den_module, "adm_csf_den_scale_line_kernel_8_128"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_csf_den_s123_line_kernel, adm_csf_den_module, "adm_csf_den_s123_line_kernel_8_128"));

    // CM kernels
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_cm_reduce_line_kernel_4, adm_cm_module, "adm_cm_reduce_line_kernel_4"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_adm_cm_line_kernel_8, adm_cm_module, "adm_cm_line_kernel_8"));
    CHECK_CUDA(cuModuleGetFunction(&s->func_i4_adm_cm_line_kernel, adm_cm_module, "i4_adm_cm_line_kernel"));


    CHECK_CUDA(cuCtxPopCurrent(NULL));

    // s->dwt2_8 = dwt2_8_device;

    // Strides/sizes are computed on half-resolution (rounded up) dimensions
    // and padded with ALIGN_CEIL.
    s->integer_stride   = ALIGN_CEIL(w * sizeof(int32_t));
    s->buf.ind_size_x   = ALIGN_CEIL(((w + 1) / 2) * sizeof(int32_t));
    s->buf.ind_size_y   = ALIGN_CEIL(((h + 1) / 2) * sizeof(int32_t));
    size_t buf_sz_one   = s->buf.ind_size_x * ((h + 1) / 2);

    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.data_buf, buf_sz_one * NUM_BUFS_ADM);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_ref, (s->integer_stride * 4 * ((h + 1) / 2)));
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_dis, (s->integer_stride * 4 * ((h + 1) / 2)));
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_accum, sizeof(uint64_t) * 3 * w * h);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_accum_h, sizeof(uint64_t) * 3 * h);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_alloc(fex->cu_state, &s->buf.tmp_res, sizeof(uint64_t) * RES_BUFFER_SIZE);
    if (ret) goto free_ref;
    ret = vmaf_cuda_buffer_host_alloc(fex->cu_state, &s->buf.results_host, sizeof(uint64_t) * RES_BUFFER_SIZE);
    if (ret) goto free_ref;

    // Carve the result buffer: each helper returns the next free device address.
    CUdeviceptr cu_res_top;
    ret = vmaf_cuda_buffer_get_dptr(s->buf.tmp_res, &cu_res_top);
    if (ret) goto free_ref;

    cu_res_top = init_res_cm_cuda(fex->cu_state, s->buf.adm_cm, cu_res_top);
    cu_res_top = init_res_csf_cuda(fex->cu_state, s->buf.adm_csf_den, cu_res_top);

    // Carve the data buffer the same way. The lookup result is checked so an
    // indeterminate cu_data_top is never handed to the helpers below.
    CUdeviceptr cu_data_top;
    ret = vmaf_cuda_buffer_get_dptr(s->buf.data_buf, &cu_data_top);
    if (ret) goto free_ref;

    cu_data_top = init_dwt_band_cuda(fex->cu_state, &s->buf.ref_dwt2, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_cuda(fex->cu_state, &s->buf.dis_dwt2, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.decouple_r, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.decouple_a, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.csf_a, cu_data_top, buf_sz_one / 2);
    cu_data_top = init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.csf_f, cu_data_top, buf_sz_one / 2);

    cu_data_top = i4_init_dwt_band_cuda(fex->cu_state, &s->buf.i4_ref_dwt2, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_cuda(fex->cu_state, &s->buf.i4_dis_dwt2, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_decouple_r, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_decouple_a, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_csf_a, cu_data_top, buf_sz_one);
    cu_data_top = i4_init_dwt_band_hvd_cuda(fex->cu_state, &s->buf.i4_csf_f, cu_data_top, buf_sz_one);

    // Parameter block carrying a back-pointer to this state; bail out before
    // dereferencing if the host allocation failed.
    s->write_score_parameters = malloc(sizeof(write_score_parameters_adm));
    if (!s->write_score_parameters) goto free_ref;
    ((write_score_parameters_adm*)s->write_score_parameters)->s = s;


    s->feature_name_dict =
        vmaf_feature_name_dict_from_provided_features(fex->provided_features,
                fex->options, s);
    if (!s->feature_name_dict) goto free_params;

    return 0;

free_params:
    // Only reached once the parameter block has been allocated successfully.
    free(s->write_score_parameters);
    s->write_score_parameters = NULL;

free_ref:
    // Release whatever was allocated so far. Pointers are cleared after being
    // freed so a later cleanup pass cannot free them a second time.
    if (s->buf.data_buf) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.data_buf);
        free(s->buf.data_buf);
        s->buf.data_buf = NULL;
    }
    if (s->buf.tmp_ref) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_ref);
        free(s->buf.tmp_ref);
        s->buf.tmp_ref = NULL;
    }
    if (s->buf.tmp_dis) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_dis);
        free(s->buf.tmp_dis);
        s->buf.tmp_dis = NULL;
    }
    if (s->buf.tmp_accum) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_accum);
        free(s->buf.tmp_accum);
        s->buf.tmp_accum = NULL;
    }
    if (s->buf.tmp_accum_h) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_accum_h);
        free(s->buf.tmp_accum_h);
        s->buf.tmp_accum_h = NULL;
    }
    if (s->buf.tmp_res) {
        ret |= vmaf_cuda_buffer_free(fex->cu_state, s->buf.tmp_res);
        free(s->buf.tmp_res);
        s->buf.tmp_res = NULL;
    }
    if (s->buf.results_host) {
        ret |= vmaf_cuda_buffer_host_free(fex->cu_state, s->buf.results_host);
    }
    vmaf_dictionary_free(&s->feature_name_dict);

    // Every failure on this path is reported to the caller as -ENOMEM.
    return -ENOMEM;
}