in src/operator/cudnn_deconvolution-inl.h [344:534]
inline void InitDescriptors(const Context& ctx,
const std::vector<TShape> &in_shape,
const std::vector<TShape> &out_shape,
cudnnDataType_t cudnn_forward_compute_type,
cudnnDataType_t cudnn_backward_compute_type) {
using namespace mshadow;
#if CUDNN_MAJOR >= 5
format_ = CUDNN_TENSOR_NCHW;
#endif
size_t expected = param_.no_bias ? 2 : 3;
CHECK_EQ(in_shape.size(), expected);
CHECK_EQ(out_shape.size(), 1U);
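// Create the cuDNN descriptors; the rest of this routine fills them in.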
CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc_));
CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc_));
CUDNN_CALL(cudnnCreateConvolutionDescriptor(&forward_conv_desc_));
CUDNN_CALL(cudnnCreateConvolutionDescriptor(&backward_conv_desc_));
TShape dshape = in_shape[deconv::kData];
TShape wshape = in_shape[deconv::kWeight];
TShape oshape = out_shape[deconv::kOut];
TShape dstride, ostride;
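// Describe only a single group's worth of filters to cuDNN; grouping is handled
// later by offsetting into the data, weight, and output buffers per group.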
wshape[0] /= param_.num_group;
if (param_.kernel.ndim() == 2) {
// 2-d deconvolution
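// InferPad fills o_pad/o_adj with the padding and output adjustment implied by
// the deconvolution parameters (and target shape, if one was given).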
index_t o_pad[2];
index_t o_adj[2];
param_.InferPad(dshape, o_pad, o_adj);
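// cuDNN 6 lets the compute type be set directly on the 2-d convolution
// descriptor; earlier versions take no compute-type argument here.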
#if CUDNN_MAJOR >= 6
CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_,
o_pad[0],
o_pad[1],
param_.stride[0],
param_.stride[1],
param_.dilate[0],
param_.dilate[1],
CUDNN_CROSS_CORRELATION,
cudnn_forward_compute_type));
CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_,
o_pad[0],
o_pad[1],
param_.stride[0],
param_.stride[1],
param_.dilate[0],
param_.dilate[1],
CUDNN_CROSS_CORRELATION,
cudnn_backward_compute_type));
#else
CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_,
o_pad[0],
o_pad[1],
param_.stride[0],
param_.stride[1],
param_.dilate[0],
param_.dilate[1],
CUDNN_CROSS_CORRELATION));
CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_,
o_pad[0],
o_pad[1],
param_.stride[0],
param_.stride[1],
param_.dilate[0],
param_.dilate[1],
CUDNN_CROSS_CORRELATION));
#endif
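// cuDNN expects the filter descriptor in NCHW order: with cuDNN 5+ the weight
// shape is converted from the user layout, while older versions require NCHW.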
#if CUDNN_MAJOR >= 5
wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW);
CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_,
dtype_,
format_,
wshape[0],
wshape[1],
wshape[2],
wshape[3]));
#else
CHECK_EQ(param_.layout.value(), kNCHW) << "CuDNN V4 only support NCHW layout";
CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_,
dtype_,
wshape[0],
wshape[1],
wshape[2],
wshape[3]));
#endif
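// Build explicit element strides for the data and output tensors and reorder
// both shape and stride into NCHW; these strides also yield the per-group
// offsets computed below.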
dstride = ConvertLayout(Shape4(dshape[1] * dshape[2] * dshape[3],
dshape[2] * dshape[3],
dshape[3],
1),
param_.layout.value(), kNCHW);
dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW);
ostride = ConvertLayout(Shape4(oshape[1] * oshape[2] * oshape[3],
oshape[2] * oshape[3],
oshape[3],
1),
param_.layout.value(), kNCHW);
oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW);
} else if (param_.kernel.ndim() == 3) {
// 3-d deconvolution
index_t o_pad[3];
index_t o_adj[3];
param_.InferPad(dshape, o_pad, o_adj);
#if CUDNN_MAJOR >= 5
CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout";
std::vector<int> wshape_buffer(wshape.ndim());
CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_,
dtype_,
CUDNN_TENSOR_NCHW,
static_cast<int>(wshape.ndim()),
CastTShapeToIntPtr(wshape, &wshape_buffer)));
#else
LOG(FATAL) << "Only support CUDNN V5 for 3D convolution";
#endif
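// The Nd convolution descriptors take int arrays for padding, stride, and
// dilation: o_pad is reinterpreted in place (this assumes index_t is int-sized),
// while param_stride_ and param_dilate_ already hold the values as ints.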
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_,
3,
reinterpret_cast<int*>(&o_pad[0]),
param_stride_.data(),
param_dilate_.data(),
CUDNN_CROSS_CORRELATION,
cudnn_forward_compute_type));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(backward_conv_desc_,
3,
reinterpret_cast<int*>(&o_pad[0]),
param_stride_.data(),
param_dilate_.data(),
CUDNN_CROSS_CORRELATION,
cudnn_backward_compute_type));
dstride = ConvertLayout(Shape5(dshape[1] * dshape[2] * dshape[3] * dshape[4],
dshape[2] * dshape[3] * dshape[4],
dshape[3] * dshape[4],
dshape[4],
1),
param_.layout.value(), kNCDHW);
dshape = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW);
ostride = ConvertLayout(Shape5(oshape[1] * oshape[2] * oshape[3] * oshape[4],
oshape[2] * oshape[3] * oshape[4],
oshape[3] * oshape[4],
oshape[4],
1),
param_.layout.value(), kNCDHW);
oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW);
}
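// The tensor descriptors describe a single group: shrink the channel dimension
// and record the per-group offsets (in elements) used when looping over groups.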
dshape[1] /= param_.num_group;
oshape[1] /= param_.num_group;
weight_offset_ = wshape.Size();
data_offset_ = dstride[1] * dshape[1];
out_offset_ = ostride[1] * oshape[1];
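// cudnnSetTensorNdDescriptor needs int buffers for both shape and stride.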
std::vector<int> dshape_buffer(dshape.ndim());
std::vector<int> dstride_buffer(dstride.ndim());
CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_,
dtype_,
static_cast<int>(dshape.ndim()),
CastTShapeToIntPtr(dshape, &dshape_buffer),
CastTShapeToIntPtr(dstride, &dstride_buffer)))
std::vector<int> oshape_buffer(oshape.ndim());
std::vector<int> ostride_buffer(ostride.ndim());
CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_,
dtype_,
static_cast<int>(oshape.ndim()),
CastTShapeToIntPtr(oshape, &oshape_buffer),
CastTShapeToIntPtr(ostride, &ostride_buffer)));
if (!param_.no_bias) {
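// The bias holds one value per output channel of a group; describe it as a
// 1 x C x 1 x 1 tensor (with an extra trailing 1 in the 3-d case).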
TShape bias = in_shape[deconv::kBias];
bias_offset_ = bias[0] / param_.num_group;
std::vector<int> bias_shape = {1,
static_cast<int>(bias[0] / param_.num_group),
1, 1};
std::vector<int> bias_stride = {static_cast<int>(bias_offset_), 1, 1, 1};
if (param_.kernel.ndim() == 3) {
bias_shape.push_back(1);
bias_stride.push_back(1);
}
CUDNN_CALL(cudnnSetTensorNdDescriptor(bias_desc_,
dtype_,
static_cast<int>(bias_shape.size()),
&bias_shape[0],
&bias_stride[0]));
}
init_cudnn_ = true;
}