void CudnnConvolution::InitCudnn(const Tensor &input)

in src/model/layer/cudnn_convolution.cc [55:159]

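Initializes (or refreshes) the cuDNN descriptors, selects the forward/backward convolution algorithms according to prefer_, and allocates the workspace for the given input batch.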

void CudnnConvolution::InitCudnn(const Tensor &input) {
  DataType dtype = input.data_type();
  auto dev = input.device();
  Context *ctx = dev->context(0);
  size_t batchsize = input.shape(0);
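  // Create the cuDNN descriptors only once; they are (re)configured on every
  // call below so that a change in batch size is picked up.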
  if (!has_init_cudnn_) {
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
    if (bias_term_)
      CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
    CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
    CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
  }

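  // Describe the input, output and bias tensors in NCHW layout for the
  // current batch size.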
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
                                         GetCudnnDataType(dtype), batchsize,
                                         channels_, height_, width_));
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(
                y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
                num_filters_, conv_height_, conv_width_));
  if (bias_term_)
    CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
                                           GetCudnnDataType(dtype), 1,
                                           num_filters_, 1, 1));
  // cuDNN 6 added the compute-type argument to cudnnSetConvolution2dDescriptor.
  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
              stride_h_, stride_w_, 1, 1,  // dilation h and w
              CUDNN_CROSS_CORRELATION
#if CUDNN_MAJOR >= 6
              , GetCudnnDataType(dtype)
#endif  // CUDNN_MAJOR
              ));
  CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
                                         CUDNN_TENSOR_NCHW, num_filters_,
                                         channels_, kernel_h_, kernel_w_));
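  // Pick the forward/backward algorithms, either from a static cuDNN
  // preference or by autotuning (benchmarking) each candidate.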
  if (prefer_ == "fastest" || prefer_ == "limited_workspace" ||
      prefer_ == "no_workspace") {
    cudnnConvolutionFwdPreference_t fwd_pref;
    cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
    cudnnConvolutionBwdDataPreference_t bwd_data_pref;
    if (prefer_ == "fastest") {
      fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
      bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
      bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
    } else if (prefer_ == "limited_workspace") {
      fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
      bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
      bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
    } else {
      fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
      bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
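      // Note: the backward-data pass keeps the workspace-limit preference
      // even in the "no_workspace" case.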
      bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
    }
    CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
                  ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
                  workspace_byte_limit_, &fp_alg_));
    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
                  ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
                  bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
    // Note: these preference-based cudnnGet*Algorithm calls were deprecated
    // in cuDNN 7 and removed in cuDNN 8.
    CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
                  ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
                  bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
  } else if (prefer_ == "autotune") {
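    // Autotune: benchmark the candidate algorithms and keep the fastest one
    // (topk = 1) for each of the three convolution passes.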
    const int topk = 1;
    int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
    cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk];
    cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
    cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
    CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
                  ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
                  &num_fp_alg, fp_alg_perf));
    fp_alg_ = fp_alg_perf[0].algo;
    CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
                  ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
                  &num_bp_filt_alg, bp_filt_perf));
    bp_filter_alg_ = bp_filt_perf[0].algo;
    CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
                  ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
                  &num_bp_data_alg, bp_data_perf));
    bp_data_alg_ = bp_data_perf[0].algo;
  } else {
    LOG(FATAL) << "Preferred algorithm is not available!";
  }

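  // Query the workspace bytes needed by the chosen algorithms; the workspace
  // tensor stores floats, so the largest requirement is converted to a float
  // count (rounded up) before allocation.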
  size_t fp_byte, bp_data_byte, bp_filter_byte;
  CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
                ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
                &fp_byte));
  CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
                ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
                bp_data_alg_, &bp_data_byte));
  CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
                ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
                bp_filter_alg_, &bp_filter_byte));
  workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
                     sizeof(float) +
                     1;
  if (workspace_count_ * sizeof(float) > workspace_byte_limit_)
    LOG(WARNING) << "The required workspace memory ("
                 << workspace_count_ * sizeof(float)
                 << " Bytes) exceeds the configured limit ("
                 << workspace_byte_limit_ << " Bytes)";
  workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
  has_init_cudnn_ = true;
}