in caffe2/operators/conv_op_cudnn.cc [898:1470]
template <typename T_X, typename T_DY, typename T_W, typename T_B, typename T_DX, typename T_DW, typename T_DB>
bool CudnnConvGradientOp::DoRunWithType() {
auto& X = Input(INPUT);
auto& filter = Input(FILTER);
auto& dY = Input(OUTPUT_GRAD);
CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
CAFFE_ENFORCE(filter.dim() >= 3 && filter.dim() <= 5);
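// Decode the problem geometry for the given storage order. M is the number
// of output channels; the group offsets index into X and dY for the
// per-group loop used with pre-7.0 cuDNN at the bottom of this function.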
const int M = filter.dim32(0);
int N = 0, C = 0, H = 0, W = 0, D = 0, H_out = 0, W_out = 0, D_out = 0;
int group_offset_X = 0, group_offset_Y = 0;
switch (order_) {
case StorageOrder::NHWC:
N = X.dim32(0);
H = X.dim32(1);
W = X.dim() > 3 ? X.dim32(2) : 1;
D = X.dim() > 4 ? X.dim32(3) : 1;
C = X.dim32(X.dim() - 1);
H_out = dY.dim32(1);
W_out = dY.dim() > 3 ? dY.dim32(2) : 1;
D_out = dY.dim() > 4 ? dY.dim32(3) : 1;
for (int i = 0; i < kernel_.size(); ++i) {
CAFFE_ENFORCE_EQ(filter.dim32(i + 1), kernel_[i]);
}
CAFFE_ENFORCE_EQ(filter.dim32(filter.dim() - 1), C / group_);
group_offset_X = C / group_;
group_offset_Y = M / group_;
break;
case StorageOrder::NCHW:
N = X.dim32(0);
C = X.dim32(1);
H = X.dim32(2);
W = X.dim() > 3 ? X.dim32(3) : 1;
D = X.dim() > 4 ? X.dim32(4) : 1;
H_out = dY.dim32(2);
W_out = dY.dim() > 3 ? dY.dim32(3) : 1;
D_out = dY.dim() > 4 ? dY.dim32(4) : 1;
CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_);
for (int i = 0; i < kernel_.size(); ++i) {
CAFFE_ENFORCE_EQ(filter.dim32(i + 2), kernel_[i]);
}
group_offset_X = C / group_ * H * W * D;
group_offset_Y = M / group_ * H_out * W_out * D_out;
break;
default:
LOG(FATAL) << "Unknown storage order: " << order_;
}
CAFFE_ENFORCE(
C % group_ == 0,
"If you set group, the number of input channels should be divisible "
"by group.");
CAFFE_ENFORCE(
M % group_ == 0,
"If you set group, the number of output channels should be divisible "
"by group.");
#if !CUDNN_VERSION_MIN(7, 0, 0)
int group_offset_filter = filter.numel() / group_;
#endif
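// Recompute the padding from the current spatial input size; only 1-D,
// 2-D, and 3-D kernels are supported.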
if (kernel_.size() == 1) {
ConvPoolOpBase<CUDAContext>::ComputePads({H});
} else if (kernel_.size() == 2) {
ConvPoolOpBase<CUDAContext>::ComputePads({H, W});
} else if (kernel_.size() == 3) {
ConvPoolOpBase<CUDAContext>::ComputePads({H, W, D});
} else {
CAFFE_THROW("Unsupported kernel size:", kernel_.size());
}
auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T_DW>());
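// Empty batch: write zero-filled gradients and skip cuDNN entirely.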
if (N == 0) {
math::Set<T_DW, CUDAContext>(
dfilter->numel(),
T_DW(0),
dfilter->template mutable_data<T_DW>(),
&context_);
if (!no_bias_) {
auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<T_DB>());
math::Set<T_DB, CUDAContext>(
dbias->numel(),
T_DB(0),
dbias->template mutable_data<T_DB>(),
&context_);
}
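// dX is requested when there are three outputs, or two outputs in the
// no-bias case.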
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
auto* dX = Output(
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
X.sizes(),
at::dtype<T_DX>());
dX->template mutable_data<T_DX>();
}
return true;
}
// Set up the cudnn algorithms & workspace if necessary
bool input_changed = (X.sizes() != cudnn_input_dims_);
bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
if (input_changed || filter_changed) {
VLOG(1) << "Changing the cudnn descriptor configurations.";
if (input_changed) {
cudnn_input_dims_ = X.sizes().vec();
SetTensorNdDescriptorWithGroup<T_X>(X.dim(), bottom_desc_, N, C, H, W, D);
}
if (filter_changed) {
cudnn_filter_dims_ = filter.sizes().vec();
if (kernel_.size() == 1 || kernel_.size() == 2) {
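// Before cuDNN 7, each descriptor describes a single group, so the
// output-channel count must be divided by group_.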
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
#else
const int MM = M / group_;
#endif
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T_W>::type,
GetCudnnTensorFormat(order_),
MM,
C / group_,
kernel_h(),
kernel_.size() == 1 ? 1 : kernel_w()));
} else {
vector<int> dims(filter.sizes().begin(), filter.sizes().end());
#if !CUDNN_VERSION_MIN(7, 0, 0)
// We only need to divide dims by group_ when the cuDNN version is < 7.0;
// see the group convolution doc: https://fburl.com/dgj6dvpd
order_ == StorageOrder::NCHW ? dims[1] /= group_
: dims[filter.ndim() - 1] /= group_;
#endif
CUDNN_ENFORCE(cudnnSetFilterNdDescriptor(
filter_desc_,
cudnnTypeWrapper<T_W>::type,
GetCudnnTensorFormat(order_),
dims.size(),
dims.data()));
}
if (!no_bias_) {
if (kernel_.size() == 1 || kernel_.size() == 2) {
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
bias_desc_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T_B>::type,
1,
M,
1,
1));
} else {
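// 3-D conv: build an N-d bias descriptor of shape (1, M, 1, 1, 1)
// explicitly.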
std::vector<int> bias_dims(X.dim(), 1);
bias_dims[1] = M;
std::vector<int> strides = {M, 1, 1, 1, 1, 1};
CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
bias_desc_,
cudnnTypeWrapper<T_B>::type,
X.dim() > 3 ? X.dim() : 4,
bias_dims.data(),
strides.data()));
}
}
}
// Set the output (dY) tensor descriptor.
SetTensorNdDescriptorWithGroup<T_DX>(
X.dim(), top_desc_, N, M, H_out, W_out, D_out);
// Also set a full, non-grouped output descriptor, used to compute the bias
// gradient in a single call.
if (kernel_.size() == 1 || kernel_.size() == 2) {
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
top_desc_for_bias_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T_B>::type,
N,
M,
H_out,
W_out));
} else {
vector<int> dims = {N, M, H_out, W_out, D_out};
vector<int> strides = {M * H_out * W_out * D_out,
H_out * W_out * D_out,
W_out * D_out,
D_out,
1};
CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
top_desc_for_bias_,
cudnnTypeWrapper<T_B>::type,
X.dim() > 3 ? X.dim() : 4,
dims.data(),
strides.data()));
}
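// Determine the compute type from the input dtype, then rebuild the conv
// descriptors. Backward-filter and backward-data get independent copies so
// that autotuning below can pick a different compute type for each pass.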
compute_type_ = DetermineComputeTypeFromInput(X);
SetConvDescFromArguments();
DuplicateConvDesc(
conv_desc_, kernel_.size(), dilation_.size(), bwd_filter_conv_desc_);
DuplicateConvDesc(
conv_desc_, kernel_.size(), dilation_.size(), bwd_data_conv_desc_);
#if CUDNN_VERSION_MIN(7, 0, 0)
if (enable_tensor_core_) {
CUDNN_ENFORCE(cudnnSetConvolutionMathType(
bwd_filter_conv_desc_, CUDNN_TENSOR_OP_MATH));
CUDNN_ENFORCE(cudnnSetConvolutionMathType(
bwd_data_conv_desc_, CUDNN_TENSOR_OP_MATH));
}
// set cuDNN groups if appropriate
CUDNN_CHECK(cudnnSetConvolutionGroupCount(bwd_filter_conv_desc_, group_));
CUDNN_CHECK(cudnnSetConvolutionGroupCount(bwd_data_conv_desc_, group_));
#endif
// Choose dW algorithm
if (force_algo_[ALGO_WGRAD] >= 0) {
bwd_filter_algo_ =
(cudnnConvolutionBwdFilterAlgo_t)force_algo_[ALGO_WGRAD];
} else if (deterministic_) {
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else if (exhaustive_search_) {
// Even when FP16 compute is supported and requested, try FP32
// because it may be faster. However, if FP32 compute is specified,
// FP16 is not a suitable alternative - early out from the loop.
std::array<ConvBwdFilterAlgorithmWithCost, 2> algosToCompare;
for (int i = 0; i < 2; i++) {
SetConvDescComputeType(bwd_filter_conv_desc_, kComputeTypesToTry[i]);
algosToCompare[i] = filter_algo_cache_.getAlgorithm(
X.sizes(), filter.sizes(), kComputeTypesToTry[i], [&]() {
VLOG(1) << "CUDNN Convolution bwd: doing filter exhaustive"
<< "search for " << kComputePassNames[i];
// When we do an exhaustive search, we benchmark every algorithm that
// fits within the workspace size limit and simply go for the fastest
// one. If you happen to run out of memory later, you will be on your
// own...
int returned_algo_count;
// Actually run the search.
std::array<
cudnnConvolutionBwdFilterAlgoPerf_t,
kNUM_CUDNN_BWD_FILTER_ALGS>
filter_perf_stat;
cudnn_wrapper_.with_cudnn_state(
cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnFindConvolutionBackwardFilterAlgorithmEx(
state->cudnn_handle(),
bottom_desc_,
X.template data<T_X>(),
top_desc_,
dY.template data<T_DY>(),
bwd_filter_conv_desc_,
filter_desc_,
dfilter->template mutable_data<T_DW>(),
kNUM_CUDNN_BWD_FILTER_ALGS,
&returned_algo_count,
filter_perf_stat.data(),
state->workspace().get(cudnn_ws_nbytes_limit_),
cudnn_ws_nbytes_limit_));
});
LogCuDNNPerfStats(filter_perf_stat, returned_algo_count);
float algo_time =
filter_perf_stat[0].status == CUDNN_STATUS_SUCCESS
? filter_perf_stat[0].time
: 1e10;
return ConvBwdFilterAlgorithmWithCost(
filter_perf_stat[0].algo, algo_time);
});
// When set to fp32 compute, don't try fp16
if (compute_type_ == CUDNN_DATA_FLOAT) {
break;
}
}
if (compute_type_ == CUDNN_DATA_FLOAT) {
// For FP32 compute, just use the best FP32 algorithm
bwd_filter_algo_ = std::get<0>(algosToCompare[0]);
} else {
// For FP16 compute, choose algo with fastest execution
int bestAlgoIndex =
(std::get<1>(algosToCompare[0]) < std::get<1>(algosToCompare[1]))
? 0
: 1;
bwd_filter_algo_ = std::get<0>(algosToCompare[bestAlgoIndex]);
SetConvDescComputeType(
bwd_filter_conv_desc_, kComputeTypesToTry[bestAlgoIndex]);
}
} else {
// Heuristic path: query cuDNN for ranked algorithms and take the fastest
// backward-filter algorithm that fits the workspace limit.
constexpr int nalgo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
int valid_algos;
cudnnConvolutionBwdFilterAlgoPerf_t algos[nalgo];
CUDNN_ENFORCE(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
cudnn_wrapper_.inline_cudnn_handle(),
bottom_desc_,
top_desc_,
bwd_filter_conv_desc_,
filter_desc_,
nalgo,
&valid_algos,
algos));
bool found = false;
for (int i = 0; i < valid_algos; i++) {
auto a = algos[i];
if (a.memory <= cudnn_ws_nbytes_limit_) {
bwd_filter_algo_ = a.algo;
found = true;
break;
}
}
CAFFE_ENFORCE(found, "Unable to find algorithms for cuDNN backward filter");
}
// Pick dX algo if needed
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
if (force_algo_[ALGO_DGRAD] >= 0) {
bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
} else if (deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else if (exhaustive_search_) {
// Even when FP16 compute is supported and requested, try FP32
// because it may be faster. However, if FP32 compute is specified,
// FP16 is not a suitable alternative - early out from the loop.
std::array<ConvBwdDataAlgorithmWithCost, 2> algosToCompare;
for (int i = 0; i < 2; i++) {
SetConvDescComputeType(bwd_data_conv_desc_, kComputeTypesToTry[i]);
algosToCompare[i] = data_algo_cache_.getAlgorithm(
X.sizes(), filter.sizes(), kComputeTypesToTry[i], [&]() {
VLOG(1) << "CUDNN Convolution bwd: doing data exhaustive"
<< "search for " << kComputePassNames[i];
int returned_algo_count;
std::array<
cudnnConvolutionBwdDataAlgoPerf_t,
kNUM_CUDNN_BWD_DATA_ALGS>
data_perf_stat;
cudnn_wrapper_.with_cudnn_state(
cudnn_state_, [&](CuDNNState* state) {
auto* dX = Output(
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
X.sizes(),
at::dtype<T_DX>());
const T_W* filter_data = filter.template data<T_W>();
const T_DY* dYdata = dY.template data<T_DY>();
T_DX* dXdata = dX->template mutable_data<T_DX>();
CUDNN_ENFORCE(cudnnFindConvolutionBackwardDataAlgorithmEx(
state->cudnn_handle(),
filter_desc_,
filter_data,
top_desc_,
dYdata,
bwd_data_conv_desc_,
bottom_desc_,
dXdata,
kNUM_CUDNN_BWD_DATA_ALGS,
&returned_algo_count,
data_perf_stat.data(),
state->workspace().get(cudnn_ws_nbytes_limit_),
cudnn_ws_nbytes_limit_));
});
LogCuDNNPerfStats(data_perf_stat, returned_algo_count);
float algo_time =
data_perf_stat[0].status == CUDNN_STATUS_SUCCESS
? data_perf_stat[0].time
: 1e10;
return ConvBwdDataAlgorithmWithCost(
data_perf_stat[0].algo, algo_time);
});
// When set to fp32 compute, don't try fp16
if (compute_type_ == CUDNN_DATA_FLOAT) {
break;
}
}
if (compute_type_ == CUDNN_DATA_FLOAT) {
// For FP32 compute, just use the best FP32 algorithm
bwd_data_algo_ = std::get<0>(algosToCompare[0]);
} else {
// For FP16 compute, choose algo with fastest execution
int bestAlgoIndex =
(std::get<1>(algosToCompare[0]) < std::get<1>(algosToCompare[1]))
? 0
: 1;
bwd_data_algo_ = std::get<0>(algosToCompare[bestAlgoIndex]);
SetConvDescComputeType(
bwd_data_conv_desc_, kComputeTypesToTry[bestAlgoIndex]);
}
} else {
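// Heuristic path: take the fastest backward-data algorithm that fits the
// workspace limit.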
constexpr int nalgo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
int valid_algos;
cudnnConvolutionBwdDataAlgoPerf_t algos[nalgo];
CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataAlgorithm_v7(
cudnn_wrapper_.inline_cudnn_handle(),
filter_desc_,
top_desc_,
bwd_data_conv_desc_,
bottom_desc_,
nalgo,
&valid_algos,
algos));
bool found = false;
for (int i = 0; i < valid_algos; i++) {
auto a = algos[i];
if (a.memory <= cudnn_ws_nbytes_limit_) {
bwd_data_algo_ = a.algo;
found = true;
break;
}
}
CAFFE_ENFORCE(found, "Unable to find algorithms for cuDNN backward data");
}
}
// Get the workspace size for the chosen backward-filter algorithm.
size_t bwd_filter_ws_size = 0, bwd_data_ws_size = 0;
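// Two passes: if the chosen algorithm turns out not to be supported for
// these descriptors, fall back to a default algorithm on the second pass.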
for (int step = 0; step < 2; ++step) {
cudnnStatus_t _status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
cudnn_wrapper_.inline_cudnn_handle(),
bottom_desc_,
top_desc_,
bwd_filter_conv_desc_,
filter_desc_,
bwd_filter_algo_,
&bwd_filter_ws_size);
if (step == 0) {
if (_status == CUDNN_STATUS_SUCCESS) {
break;
}
if (_status == CUDNN_STATUS_NOT_SUPPORTED) {
cudnnConvolutionBwdFilterAlgo_t new_algo = deterministic_
? CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
: CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
VLOG(1) << "Backward Filter algorithm " << (int)bwd_filter_algo_
<< " is not currently supported for given parameters."
<< " Trying the default algorithm " << (int)new_algo;
bwd_filter_algo_ = new_algo;
continue;
}
}
CUDNN_ENFORCE(_status);
}
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
// Get the workspace size for the chosen backward-data algorithm.
for (int step = 0; step < 2; ++step) {
cudnnStatus_t _status = cudnnGetConvolutionBackwardDataWorkspaceSize(
cudnn_wrapper_.inline_cudnn_handle(),
filter_desc_,
top_desc_,
bwd_data_conv_desc_,
bottom_desc_,
bwd_data_algo_,
&bwd_data_ws_size);
if (step == 0) {
if (_status == CUDNN_STATUS_SUCCESS) {
break;
}
if (_status == CUDNN_STATUS_NOT_SUPPORTED) {
cudnnConvolutionBwdDataAlgo_t new_algo = deterministic_
? CUDNN_CONVOLUTION_BWD_DATA_ALGO_1
: CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
VLOG(1) << "Backward Data algorithm " << (int)bwd_data_algo_
<< " is not currently supported for given parameters."
<< " Trying the default algorithm " << (int)new_algo;
bwd_data_algo_ = new_algo;
continue;
}
}
CUDNN_ENFORCE(_status);
}
} else {
bwd_data_ws_size = 0;
}
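// A single workspace serves both backward passes, so size it for the
// larger of the two.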
cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, bwd_data_ws_size);
VLOG(1) << "CuDNN bwd data & filter algorithm: " << bwd_data_algo_ << ", "
<< bwd_filter_algo_;
VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
}
// Now, actually run the computation.
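// The bias gradient is a reduction of dY over every dimension except the
// channel dimension.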
if (!no_bias_) {
auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<T_DB>());
CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
cudnn_wrapper_.inline_cudnn_handle(),
cudnnTypeWrapper<T_DY>::kOne(),
top_desc_for_bias_,
dY.template data<T_DY>(),
cudnnTypeWrapper<T_DB>::kZero(),
bias_desc_,
dbias->template mutable_data<T_DB>()));
}
#if CUDNN_VERSION_MIN(7, 0, 0)
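// cuDNN 7+ handles grouped convolution natively (the group count was set
// on the conv descriptors above), so a single call covers all groups.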
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
cudnnTypeWrapper<T_X>::kOne(),
bottom_desc_,
X.template data<T_X>(),
top_desc_,
dY.template data<T_DY>(),
bwd_filter_conv_desc_,
bwd_filter_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_DW>::kZero(),
filter_desc_,
dfilter->template mutable_data<T_DW>()));
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
// Compute the gradient w.r.t. the input.
auto* dX = Output(
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
X.sizes(),
at::dtype<T_DX>());
CUDNN_ENFORCE(cudnnConvolutionBackwardData(
state->cudnn_handle(),
cudnnTypeWrapper<T_W>::kOne(),
filter_desc_,
filter.template data<T_W>(),
top_desc_,
dY.template data<T_DY>(),
bwd_data_conv_desc_,
bwd_data_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_DX>::kZero(),
bottom_desc_,
dX->template mutable_data<T_DX>()));
}
});
#else
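// Pre-7.0 cuDNN: emulate grouped convolution by looping over the groups
// and offsetting the data pointers into X, dY, dX, and the filter.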
for (int i = 0; i < group_; ++i) {
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
cudnnTypeWrapper<T_X>::kOne(),
bottom_desc_,
X.template data<T_X>() + i * group_offset_X,
top_desc_,
dY.template data<T_DY>() + i * group_offset_Y,
bwd_filter_conv_desc_,
bwd_filter_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_DW>::kZero(),
filter_desc_,
dfilter->template mutable_data<T_DW>() + i * group_offset_filter));
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
// Compute the gradient w.r.t. the input.
auto* dX = Output(
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T_DX>());
CUDNN_ENFORCE(cudnnConvolutionBackwardData(
state->cudnn_handle(),
cudnnTypeWrapper<T_W>::kOne(),
filter_desc_,
filter.template data<T_W>() + i * group_offset_filter,
top_desc_,
dY.template data<T_DY>() + i * group_offset_Y,
bwd_data_conv_desc_,
bwd_data_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_DX>::kZero(),
bottom_desc_,
dX->template mutable_data<T_DX>() + i * group_offset_X));
}
});
}
#endif
return true;
}