in src/operator/nn/cudnn/cudnn_convolution-inl.h [411:612]
// Initializes every cuDNN descriptor used by this operator from the inferred
// shapes: the input/output tensor descriptors, the filter descriptor, the
// three convolution descriptors (forward, backward-data, backward-filter),
// and — when a bias is present — the bias tensor descriptor.  Also handles:
//  * conversion of the user layout to cuDNN-native NCHW/NCDHW ordering,
//  * promotion of 1-D convolution to 2-D with a dummy unit dimension,
//  * per-cuDNN-version API differences via preprocessor shims, and
//  * pre-cuDNN-7 grouped convolution, which must be emulated by slicing the
//    tensors per group (the *_offset_ members hold the per-group strides).
//
// @param in_shape   shapes of [data, weight] (plus [bias] unless param_.no_bias)
// @param out_shape  single-element vector holding the output shape
// @param cudnn_forward_compute_type   precision for forward convolution math
// @param cudnn_backward_compute_type  precision for backward convolution math
void InitDescriptors(const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
cudnnDataType_t cudnn_forward_compute_type,
cudnnDataType_t cudnn_backward_compute_type) {
using namespace mshadow;
size_t expected = param_.no_bias ? 2 : 3;
CHECK_EQ(in_shape.size(), expected);
CHECK_EQ(out_shape.size(), 1U);
TShape dshape = in_shape[conv::kData];
TShape wshape = in_shape[conv::kWeight];
TShape oshape = out_shape[conv::kOut];
TShape dstride, ostride;
#if CUDNN_MAJOR <= 6
// cuDNN v6 and earlier have no native grouped convolution; each group is run
// as an independent convolution, so the filter descriptor must describe only
// a single group's slice of the output channels.
wshape[0] /= param_.num_group;
#endif
#if CUDNN_MAJOR <= 5
// As of cuDNN_v6, the unsuffixed version of cudnnSetConvolution2dDescriptor()
// takes an additional 'computeType' parameter to set the precision of the
// convolution calculation. Supply this method signature for cuDNN versions < 6.
// NOTE: this macro shadows the library function for the remainder of this
// header, silently dropping the compute-type argument on pre-v6 builds.
#define cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m, ct) \
cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m)
#endif
if (param_.kernel.ndim() == 1 || param_.kernel.ndim() == 2) {
// 1d or 2d conv
// 1-D convolution is expressed as a 2-D convolution with a dummy leading
// spatial dim: pad 0, stride 1, dilation 1 on the inserted axis.
auto pad = param_.kernel.ndim() == 2 ? param_.pad : TShape({0, param_.pad[0]});
auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]});
auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]});
// Forward and backward convolution descriptors are identical except for
// their compute type, which may differ (e.g. mixed-precision training).
CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_,
pad[0],
pad[1],
stride[0],
stride[1],
dilate[0],
dilate[1],
CUDNN_CROSS_CORRELATION,
cudnn_forward_compute_type));
CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_,
pad[0],
pad[1],
stride[0],
stride[1],
dilate[0],
dilate[1],
CUDNN_CROSS_CORRELATION,
cudnn_backward_compute_type));
CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_,
pad[0],
pad[1],
stride[0],
stride[1],
dilate[0],
dilate[1],
CUDNN_CROSS_CORRELATION,
cudnn_backward_compute_type));
#if CUDNN_MAJOR < 5
// As of cuDNN_v5, cudnnSetFilter4dDescriptor() takes a format parameter.
// Supply this method signature for cuDNN versions < 5.
#define cudnnSetFilter4dDescriptor(fdesc, dt, f, w0, w1, w2, w3) \
cudnnSetFilter4dDescriptor(fdesc, dt, w0, w1, w2, w3)
CHECK_EQ(format_, CUDNN_TENSOR_NCHW) << "CuDNN V4 and earlier only supports NCHW layout";
#endif
if (param_.kernel.ndim() == 2) {
// Strides are computed from the original (user-layout) shape, so each
// Strides<> call must precede the in-place shape conversion that follows.
wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW);
dstride = ConvertLayout(Strides<4>(dshape), param_.layout.value(), kNCHW);
dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW);
ostride = ConvertLayout(Strides<4>(oshape), param_.layout.value(), kNCHW);
oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW);
} else {
// 1-D case: convert to NCW, then widen to 4-D by inserting a unit H dim.
// The inserted dim reuses stride[1]; since its extent is 1 the stride
// value is never used to advance, so any value is valid there.
wshape = ConvertLayout(wshape.get<3>(), param_.layout.value(), kNCW);
wshape = TShape({wshape[0], wshape[1], 1, wshape[2]});
dstride = ConvertLayout(Strides<3>(dshape), param_.layout.value(), kNCW);
dstride = TShape({dstride[0], dstride[1], dstride[1], dstride[2]});
dshape = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW);
dshape = TShape({dshape[0], dshape[1], 1, dshape[2]});
ostride = ConvertLayout(Strides<3>(oshape), param_.layout.value(), kNCW);
ostride = TShape({ostride[0], ostride[1], ostride[1], ostride[2]});
oshape = ConvertLayout(oshape.get<3>(), param_.layout.value(), kNCW);
oshape = TShape({oshape[0], oshape[1], 1, oshape[2]});
}
CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_,
dtype_,
format_,
wshape[0],
wshape[1],
wshape[2],
wshape[3]));
} else if (param_.kernel.ndim() == 3) {
// 3d conv
#if CUDNN_MAJOR >= 5
CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout";
std::vector<int> wshape_buffer(wshape.ndim());
CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_,
dtype_,
CUDNN_TENSOR_NCHW,
static_cast<int>(wshape.ndim()),
CastTShapeToIntPtr(wshape, &wshape_buffer)));
#else
LOG(FATAL) << "Only support CUDNN V5 for 3D convolution";
#endif
// param_pad_/param_stride_/param_dilate_ are presumably int copies of the
// corresponding param_ fields prepared elsewhere in this class — confirm
// against the member declarations (outside this chunk).
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_,
3,
param_pad_.data(),
param_stride_.data(),
param_dilate_.data(),
CUDNN_CROSS_CORRELATION,
cudnn_forward_compute_type));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_,
3,
param_pad_.data(),
param_stride_.data(),
param_dilate_.data(),
CUDNN_CROSS_CORRELATION,
cudnn_backward_compute_type));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_w_,
3,
param_pad_.data(),
param_stride_.data(),
param_dilate_.data(),
CUDNN_CROSS_CORRELATION,
cudnn_backward_compute_type));
// As above: strides must be taken from the shape before it is converted.
dstride = ConvertLayout(Strides<5>(dshape), param_.layout.value(), kNCDHW);
dshape = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW);
ostride = ConvertLayout(Strides<5>(oshape), param_.layout.value(), kNCDHW);
oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW);
}
// Set "allow tensor core" flag in convolution descriptors, if available.
#if CUDNN_MAJOR >= 7
cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH
: CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION >= 7200
// cuDNN 7.2+ can use Tensor Cores even for non-fp16 data by converting
// internally; only opt in when both env toggles allow it and the data type
// isn't already fp16 (which CUDNN_TENSOR_OP_MATH above already covers).
if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
(DataType<DType>::kFlag != kFloat16))
math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
#endif
CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type));
CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type));
CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type));
// cuDNN 7+ handles grouping natively via the group count on the descriptor.
CUDNN_CALL(cudnnSetConvolutionGroupCount(forward_conv_desc_, param_.num_group));
CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_, param_.num_group));
CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_w_, param_.num_group));
#endif
#if CUDNN_MAJOR <= 6
// Pre-v7 grouping emulation: describe only one group's channel slice of the
// data/output tensors; the caller iterates groups using the offsets below.
dshape[1] /= param_.num_group;
oshape[1] /= param_.num_group;
#endif
// Element offsets between consecutive groups' slices within each buffer.
// For data/output this is (channel stride) * (channels per descriptor);
// on cuDNN 7+ the tensors are whole, so these presumably act as full-tensor
// sizes rather than per-group steps — confirm against the launch code.
weight_offset_ = wshape.Size();
data_offset_ = dstride[1] * dshape[1];
out_offset_ = ostride[1] * oshape[1];
// Tensor descriptors are set with explicit strides (the Nd API), so a
// non-packed view — e.g. a per-group channel slice of an NCHW buffer, or an
// NHWC buffer described in NCHW order — is representable.
std::vector<int> dshape_buffer(dshape.ndim());
nnvm::ShapeTypeCast(dshape.begin(), dshape.end(), dshape_buffer.data());
std::vector<int> dstride_buffer(dstride.ndim());
nnvm::ShapeTypeCast(dstride.begin(), dstride.end(), dstride_buffer.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_,
dtype_,
static_cast<int>(dshape.ndim()),
dshape_buffer.data(),
dstride_buffer.data()));
std::vector<int> oshape_buffer(oshape.ndim());
nnvm::ShapeTypeCast(oshape.begin(), oshape.end(), oshape_buffer.data());
std::vector<int> ostride_buffer(ostride.ndim());
nnvm::ShapeTypeCast(ostride.begin(), ostride.end(), ostride_buffer.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_,
dtype_,
static_cast<int>(oshape.ndim()),
oshape_buffer.data(),
ostride_buffer.data()));
if (!param_.no_bias) {
// Bias is a 1-D tensor of output-channel scale; describe it as a packed
// 4-D (or 5-D for 3-D conv) tensor broadcastable over N/H/W(/D).
TShape bias = in_shape[conv::kBias];
#if CUDNN_MAJOR >= 7
bias_offset_ = bias[0];
std::vector<int> bias_shape = {1,
static_cast<int>(bias[0]),
1, 1};
#else
// Pre-v7 grouping emulation again: one group's slice of the bias.
bias_offset_ = bias[0] / param_.num_group;
std::vector<int> bias_shape = {1,
static_cast<int>(bias[0] / param_.num_group),
1, 1};
#endif
std::vector<int> bias_stride = {static_cast<int>(bias_offset_), 1, 1, 1};
if (param_.kernel.ndim() == 3) {
// Extra unit dim so the bias descriptor rank matches the 5-D output.
bias_shape.push_back(1);
bias_stride.push_back(1);
}
CUDNN_CALL(cudnnSetTensorNdDescriptor(bias_desc_,
dtype_,
static_cast<int>(bias_shape.size()),
&bias_shape[0],
&bias_stride[0]));
}
}