in src/nv-wavenet/wavenet_infer_wrapper.cpp [31:106]
/// Builds an nv-wavenet inference object from PyTorch weight tensors.
///
/// Extracts a raw `float*` from every tensor (via `data<float>()`, which
/// raises if a tensor is not float) and forwards all of them, together with
/// the scalar configuration arguments, to `wavenet_construct`.
///
/// The per-layer pointer tables live in `std::vector`, so they are released on
/// every exit path. The previous malloc/free version leaked them whenever
/// `data<float>()` or `wavenet_construct` threw, and never checked malloc for
/// NULL.
///
/// NOTE(review): the returned object receives raw pointers into the tensors'
/// storage. Whether `wavenet_construct` copies the weights or keeps referring
/// to them is not visible here; if it keeps them, the caller must keep the
/// tensors alive — confirm against wavenet_construct.
/// NOTE(review): tensor contiguity is not checked here; presumably the caller
/// passes contiguous tensors — TODO confirm.
///
/// @param sample_count, batch_size, num_layers, use_embed_tanh, max_dilation,
///        implementation  Forwarded unchanged to wavenet_construct.
/// @param embed_*_tensor, conv_*_tensor  Global weights/biases; only their
///        float data pointers are forwarded.
/// @param Wprev, Wcur, Bh, Wres, Bres, Wskip, Bskip  Per-layer tensors; the
///        first num_layers entries of each are used.
/// @return The pointer returned by wavenet_construct, reinterpreted as
///         uint64_t (presumably an opaque handle for the caller).
/// @throws std::out_of_range if any per-layer vector holds fewer than
///         num_layers tensors (previously undefined behaviour).
uint64_t construct(int sample_count,
                   int batch_size,
                   at::Tensor embed_prev_tensor,
                   at::Tensor embed_curr_tensor,
                   at::Tensor conv_init_tensor,
                   at::Tensor conv_init_bias_tensor,
                   at::Tensor conv_out_tensor,
                   at::Tensor conv_out_bias_tensor,
                   at::Tensor conv_end_tensor,
                   at::Tensor conv_end_bias_tensor,
                   std::vector<at::Tensor>& Wprev, std::vector<at::Tensor>& Wcur, std::vector<at::Tensor>& Bh,
                   std::vector<at::Tensor>& Wres, std::vector<at::Tensor>& Bres,
                   std::vector<at::Tensor>& Wskip, std::vector<at::Tensor>& Bskip,
                   int num_layers,
                   int use_embed_tanh,
                   int max_dilation,
                   int implementation) {
    float* embedding_prev = embed_prev_tensor.data<float>();
    float* embedding_curr = embed_curr_tensor.data<float>();
    float* conv_init = conv_init_tensor.data<float>();
    float* conv_init_bias = conv_init_bias_tensor.data<float>();
    float* conv_out = conv_out_tensor.data<float>();
    float* conv_out_bias = conv_out_bias_tensor.data<float>();
    float* conv_end = conv_end_tensor.data<float>();
    float* conv_end_bias = conv_end_bias_tensor.data<float>();

    // Clamp so a negative layer count cannot wrap into a huge allocation size.
    const size_t layer_count = num_layers > 0 ? static_cast<size_t>(num_layers) : 0;

    // Pointer tables owned by std::vector: freed automatically, even on throw.
    std::vector<float*> in_layer_weights_prev(layer_count);
    std::vector<float*> in_layer_weights_curr(layer_count);
    std::vector<float*> in_layer_biases(layer_count);
    std::vector<float*> res_layer_weights(layer_count);
    std::vector<float*> res_layer_biases(layer_count);
    std::vector<float*> skip_layer_weights(layer_count);
    std::vector<float*> skip_layer_biases(layer_count);

    for (size_t i = 0; i < layer_count; ++i) {
        // .at() bounds-checks: a per-layer vector shorter than num_layers
        // raises std::out_of_range instead of reading past its end.
        in_layer_weights_prev[i] = Wprev.at(i).data<float>();
        in_layer_weights_curr[i] = Wcur.at(i).data<float>();
        in_layer_biases[i] = Bh.at(i).data<float>();
        res_layer_weights[i] = Wres.at(i).data<float>();
        res_layer_biases[i] = Bres.at(i).data<float>();
        skip_layer_weights[i] = Wskip.at(i).data<float>();
        skip_layer_biases[i] = Bskip.at(i).data<float>();
    }

    void* wavenet = wavenet_construct(sample_count,
                                      batch_size,
                                      embedding_prev,
                                      embedding_curr,
                                      num_layers,
                                      max_dilation,
                                      in_layer_weights_prev.data(),
                                      in_layer_weights_curr.data(),
                                      in_layer_biases.data(),
                                      res_layer_weights.data(),
                                      res_layer_biases.data(),
                                      skip_layer_weights.data(),
                                      skip_layer_biases.data(),
                                      conv_init,
                                      conv_init_bias,
                                      conv_out,
                                      conv_out_bias,
                                      conv_end,
                                      conv_end_bias,
                                      use_embed_tanh,
                                      implementation);
    return reinterpret_cast<uint64_t>(wavenet);
}