in src/model/layer/cudnn_convolution.cc [55:159]
// Lazily creates (on first call) and (re)configures the cuDNN descriptors,
// picks forward/backward algorithms, and allocates the scratch workspace for
// this convolution layer, sized from `input` (NCHW; batch taken from dim 0).
// May be called again when the input batch size changes: descriptor handles
// are created only once and merely re-Set afterwards.
void CudnnConvolution::InitCudnn(const Tensor &input) {
DataType dtype = input.data_type();
auto dev = input.device();
Context *ctx = dev->context(0);
size_t batchsize = input.shape(0);
// Create descriptor handles exactly once; on later calls only the
// cudnnSet* calls below refresh their contents.
if (!has_init_cudnn_) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
if (bias_term_)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
}
// Input tensor: N x channels_ x height_ x width_.
CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
GetCudnnDataType(dtype), batchsize,
channels_, height_, width_));
// Output tensor: N x num_filters_ x conv_height_ x conv_width_.
CUDNN_CHECK(cudnnSetTensor4dDescriptor(
y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
num_filters_, conv_height_, conv_width_));
// Bias is broadcast over N/H/W, hence the 1 x num_filters_ x 1 x 1 shape.
if (bias_term_)
CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
GetCudnnDataType(dtype), 1,
num_filters_, 1, 1));
// Dilation is fixed at 1x1 (no dilated convolution support here).
// cuDNN 7 added a compute-type argument to this setter, hence the #if.
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
stride_h_, stride_w_, 1, 1, // dilation x and y
CUDNN_CROSS_CORRELATION
#if CUDNN_MAJOR >= 7
, GetCudnnDataType(dtype)
#endif // CUDNN_MAJOR
));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
CUDNN_TENSOR_NCHW, num_filters_,
channels_, kernel_h_, kernel_w_));
// Algorithm selection: either heuristic (Get*, guided by a preference and
// workspace_byte_limit_) or exhaustive benchmarking ("autotune", Find*).
if (prefer_ == "fastest" || prefer_ == "limited_workspace" ||
prefer_ == "no_workspace") {
cudnnConvolutionFwdPreference_t fwd_pref;
cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
cudnnConvolutionBwdDataPreference_t bwd_data_pref;
if (prefer_ == "fastest") {
fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
} else if (prefer_ == "limited_workspace") {
fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
} else {
fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
// NOTE(review): "no_workspace" mode still uses SPECIFY_WORKSPACE_LIMIT
// for the backward-data pass instead of BWD_DATA_NO_WORKSPACE. This may
// be a deliberate workaround (the zero-workspace bwd-data algo is known
// to be problematic on some cuDNN versions) or a copy-paste slip —
// confirm against the original intent before "fixing" it.
bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
}
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
workspace_byte_limit_, &fp_alg_));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
// deprecated in cudnn v7
// NOTE(review): the whole Get*Algorithm family above was removed in
// cuDNN 8 (replaced by Get*Algorithm_v7 / Find*); this branch will not
// compile against cuDNN >= 8 — verify the supported cuDNN range.
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
} else if (prefer_ == "autotune") {
// Benchmark candidate algorithms on the device and keep the fastest one.
// NOTE(review): num_*_alg are never checked to be >= 1 before indexing
// perf[0]; cuDNN should always return at least one entry after a
// successful call, but an explicit CHECK would be safer.
const int topk = 1;
int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk];
cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
&num_fp_alg, fp_alg_perf));
fp_alg_ = fp_alg_perf[0].algo;
CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
&num_bp_filt_alg, bp_filt_perf));
bp_filter_alg_ = bp_filt_perf[0].algo;
CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
&num_bp_data_alg, bp_data_perf));
bp_data_alg_ = bp_data_perf[0].algo;
} else {
LOG(FATAL) << "Preferred algorithm is not available!";
}
// One workspace is shared by forward, backward-data and backward-filter,
// so size it for the largest of the three requirements.
size_t fp_byte, bp_data_byte, bp_filter_byte;
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
&fp_byte));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
bp_data_alg_, &bp_data_byte));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
bp_filter_alg_, &bp_filter_byte));
// Element count, rounded up: max_bytes / 4 + 1 elements of sizeof(float).
workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
sizeof(float) +
1;
// The workspace limit is advisory only: exceeding it logs a warning but the
// full required size is still allocated below.
if (workspace_count_ * sizeof(float) > workspace_byte_limit_)
LOG(WARNING) << "The required memory for workspace ("
<< workspace_count_ * sizeof(float)
<< ") is larger than the expected Bytes ("
<< workspace_byte_limit_ << ")";
// NOTE(review): workspace_count_ assumes 4-byte (float) elements, but the
// tensor is allocated with element type `dtype`. For a smaller dtype
// (e.g. fp16, 2 bytes) this under-allocates the byte size cuDNN requires —
// confirm whether non-float dtypes can reach this path.
workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
has_init_cudnn_ = true;
}