in src/operator/cudnn_deconvolution-inl.h [536:662]
// Chooses the three cuDNN algorithms used by this operator and stores them in
// algo_, back_algo_ and back_algo_w_.
//
// A registry key is built from param_, the input/output shapes, dtype_ and the
// two compute types. If CuDNNAlgoReg already has an entry for that key (Find is
// handed pointers to the three algorithm members), nothing else is done.
// Otherwise the selection runs as a synchronous engine task on `ctx`, and this
// function blocks until that task has finished.
//
// Role of each member in a deconvolution (per the inline notes below):
//   algo_        - cuDNN *forward* convolution, used for backprop-to-data.
//   back_algo_   - cuDNN *backward-data* convolution, used for forward inference.
//   back_algo_w_ - cuDNN *backward-filter* convolution, used for backprop-to-filter.
//
// \param ctx device context the selection task is pushed to.
// \param in_shape input shapes; only used to build the registry key.
// \param out_shape output shapes; only used to build the registry key.
// \param cudnn_forward_compute_type forward compute type; part of the registry key.
// \param cudnn_backward_compute_type backward compute type; part of the registry key.
void SelectAlgo(const Context& ctx,
                const std::vector<TShape>& in_shape,
                const std::vector<TShape>& out_shape,
                cudnnDataType_t cudnn_forward_compute_type,
                cudnnDataType_t cudnn_backward_compute_type) {
  std::string key = CuDNNAlgoReg::Get()->GetKey(param_, in_shape, out_shape, dtype_,
                                                cudnn_forward_compute_type,
                                                cudnn_backward_compute_type);
  if (CuDNNAlgoReg::Get()->Find(key, &algo_, &back_algo_, &back_algo_w_))
    return;
  // `var` is passed as the task's mutable variable so WaitForVar below can be
  // used to block until the selection task has run.
  Engine::VarHandle var = Engine::Get()->NewVariable();
  Engine::Get()->PushSync([=](RunContext rctx) {
    mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
    // Workspace budget in bytes: param_.workspace scaled by sizeof(DType).
    const size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
    if (!param_.cudnn_tune.value()) {
      // No autotuning requested: ask cuDNN's heuristic "Get" functions for an
      // algorithm that fits within the workspace budget.
      // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
      // supported. Hard-coded this since the algo find() or get() throws an FPE.
      if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
        algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
      } else {
        CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
                 out_desc_,
                 filter_desc_,
                 backward_conv_desc_,  // forward algorithm used to backprop-to-data
                 in_desc_,
                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                 workspace_byte,
                 &(this->algo_)));
      }
      CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
               out_desc_,
               in_desc_,
               backward_conv_desc_,
               filter_desc_,
               CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
               workspace_byte,
               &(this->back_algo_w_)));
      CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
               filter_desc_,
               in_desc_,
               forward_conv_desc_,  // this backward algorithm used for inference
               out_desc_,
               CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
               workspace_byte,
               &(this->back_algo_)));
      // NOTE(review): results of this heuristic branch are not added to
      // CuDNNAlgoReg; only the autotune branch below registers. Confirm intended.
    } else {
      // Autotuning requested: run cuDNN's "Find" functions and, for each of the
      // three operations, take the first returned entry that succeeded and, when
      // cudnn_tune is deconv::kLimited, also fits in the workspace budget.
      const int kMaxAlgos = 10;
      int nalgo = kMaxAlgos;
      int i = 0;
      // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
      // supported. Hard-coded this since the algo find() or get() throws an FPE.
      if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
        algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
      } else {
        cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
        CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
                 out_desc_,
                 filter_desc_,
                 backward_conv_desc_,  // forward algorithm used to backprop-to-data
                 in_desc_,
                 kMaxAlgos,
                 &nalgo,
                 fwd_algo));
        // Skip entries that failed or (in kLimited mode) need too much workspace.
        i = 0;
        while (i < nalgo
               && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
                   || (param_.cudnn_tune.value() == deconv::kLimited
                       && fwd_algo[i].memory > workspace_byte))) ++i;
        if (i == nalgo) {
          LOG(FATAL) << "Failed to find a 'forward' convolution algorithm " <<
                     "(for use in deconvolution operator backprop-to-data).";
        } else {
          this->algo_ = fwd_algo[i].algo;
        }
      }
      cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
      CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
               out_desc_,
               in_desc_,
               backward_conv_desc_,
               filter_desc_,
               kMaxAlgos,
               &nalgo,
               bwd_filter_algo));
      // Same filtering rule as above, applied to the backward-filter candidates.
      i = 0;
      while (i < nalgo
             && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
                 || (param_.cudnn_tune.value() == deconv::kLimited
                     && bwd_filter_algo[i].memory > workspace_byte))) ++i;
      if (i == nalgo) {
        LOG(FATAL) << "Failed to find a backward filter convolution algorithm " <<
                   "(for use in deconvolution operator backprop-to-filter).";
      } else {
        this->back_algo_w_ = bwd_filter_algo[i].algo;
      }
      cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
      CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
               filter_desc_,
               in_desc_,
               forward_conv_desc_,  // this backward algorithm used for inference
               out_desc_,
               kMaxAlgos,
               &nalgo,
               bwd_data_algo));
      // Same filtering rule as above, applied to the backward-data candidates.
      i = 0;
      while (i < nalgo
             && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
                 || (param_.cudnn_tune.value() == deconv::kLimited
                     && bwd_data_algo[i].memory > workspace_byte))) ++i;
      if (i == nalgo) {
        // Message formatted like the two above (previously rendered as
        // "...algorithm.(for use in ..." with no separating space).
        LOG(FATAL) << "Failed to find a backward data convolution algorithm " <<
                   "(for use in deconvolution operator forward inference).";
      } else {
        this->back_algo_ = bwd_data_algo[i].algo;
      }
      // Record the tuned choices under `key` so the Find() at the top of this
      // function can return them for a later call that builds the same key.
      CuDNNAlgoReg::Get()->Register(key, this->algo_, this->back_algo_,
                                    this->back_algo_w_);
    }
  }, ctx, {}, {var});
  // Block until the selection task has run, then release the temporary variable.
  Engine::Get()->WaitForVar(var);
  Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
}