uint64_t construct()

in src/nv-wavenet/wavenet_infer_wrapper.cpp [31:106]


uint64_t construct(int sample_count,
                   int batch_size,
                   at::Tensor embed_prev_tensor,
                   at::Tensor embed_curr_tensor,
                   at::Tensor conv_init_tensor,
                   at::Tensor conv_init_bias_tensor,
                   at::Tensor conv_out_tensor,
                   at::Tensor conv_out_bias_tensor,
                   at::Tensor conv_end_tensor,
                   at::Tensor conv_end_bias_tensor,
                   std::vector<at::Tensor>& Wprev, std::vector<at::Tensor>& Wcur,std::vector<at::Tensor>& Bh,
                   std::vector<at::Tensor>& Wres, std::vector<at::Tensor>& Bres,
                   std::vector<at::Tensor>& Wskip, std::vector<at::Tensor>& Bskip,
                   int num_layers,
                   int use_embed_tanh,
                   int max_dilation,
                   int implementation) {
    float* embedding_prev = embed_prev_tensor.data<float>();
    float* embedding_curr = embed_curr_tensor.data<float>();
    float* conv_init = conv_init_tensor.data<float>();
    float* conv_init_bias = conv_init_bias_tensor.data<float>();
    float* conv_out = conv_out_tensor.data<float>();
    float* conv_out_bias = conv_out_bias_tensor.data<float>();
    float* conv_end = conv_end_tensor.data<float>();
    float* conv_end_bias = conv_end_bias_tensor.data<float>();

    float** in_layer_weights_prev = (float**) malloc(num_layers*sizeof(float*));
    float** in_layer_weights_curr = (float**) malloc(num_layers*sizeof(float*));
    float** in_layer_biases = (float**) malloc(num_layers*sizeof(float*));
    float** res_layer_weights = (float**) malloc(num_layers*sizeof(float*));
    float** res_layer_biases = (float**) malloc(num_layers*sizeof(float*));
    float** skip_layer_weights = (float**) malloc(num_layers*sizeof(float*));
    float** skip_layer_biases = (float**) malloc(num_layers*sizeof(float*));
    for (int i=0; i < num_layers; i++) {
        in_layer_weights_prev[i] = Wprev[i].data<float>();
        in_layer_weights_curr[i] = Wcur[i].data<float>();
        in_layer_biases[i] = Bh[i].data<float>();
        res_layer_weights[i] = Wres[i].data<float>();
        res_layer_biases[i] = Bres[i].data<float>();
        skip_layer_weights[i] = Wskip[i].data<float>();
        skip_layer_biases[i] = Bskip[i].data<float>();
    }

    void* wavenet = wavenet_construct(sample_count,
                                      batch_size,
                                      embedding_prev,
                                      embedding_curr,
                                      num_layers,
                                      max_dilation,
                                      in_layer_weights_prev,
                                      in_layer_weights_curr,
                                      in_layer_biases,
                                      res_layer_weights,
                                      res_layer_biases,
                                      skip_layer_weights,
                                      skip_layer_biases,
                                      conv_init,
                                      conv_init_bias,
                                      conv_out,
                                      conv_out_bias,
                                      conv_end,
                                      conv_end_bias,
                                      use_embed_tanh,
                                      implementation
                                     );
    
    free(in_layer_weights_prev);
    free(in_layer_weights_curr);
    free(in_layer_biases);
    free(res_layer_weights);
    free(res_layer_biases);
    free(skip_layer_weights);
    free(skip_layer_biases);
    
    return reinterpret_cast<uint64_t>(wavenet);
}