in src/operator/cudnn_rnn-inl.h [172:282]
virtual void Backward(const OpContext &ctx,
                      const std::vector<TBlob> &out_grad,
                      const std::vector<TBlob> &in_data,
                      const std::vector<TBlob> &out_data,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &in_grad,
                      const std::vector<TBlob> &aux_args) {
using namespace mshadow;
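// LSTM carries an extra cell-state blob on both the input and output side;
// when state outputs are not requested, only the main output y is expected.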
size_t in_expected = param_.lstm_q_ ? 4 : 3;
size_t out_expected = param_.lstm_q_ ? 3 : 2;
if (!param_.state_outputs)
  out_expected = 1;
CHECK_EQ(in_data.size(), in_expected);
CHECK_EQ(out_data.size(), out_expected);
CHECK_EQ(in_grad.size(), in_expected);
CHECK_EQ(out_grad.size(), out_expected);
CHECK_EQ(req.size(), in_expected);
CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
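// cudnnRNNBackwardData overwrites dx and dhx rather than accumulating into
// them, so kAddTo cannot be honored for these gradients.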
Stream<gpu> *s = ctx.get_stream<gpu>();
// get input + output tensors
Tensor<gpu, 3, DType> x = in_data[rnn_enum::kData].get<gpu, 3, DType>(s);
Tensor<gpu, 3, DType> dx = in_grad[rnn_enum::kData].get<gpu, 3, DType>(s);
Tensor<gpu, 1, DType> w = in_data[rnn_enum::kParams].get<gpu, 1, DType>(s);
Tensor<gpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<gpu, 1, DType>(s);
Tensor<gpu, 3, DType> hx = in_data[rnn_enum::kState].get<gpu, 3, DType>(s);
Tensor<gpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<gpu, 3, DType>(s);
Tensor<gpu, 3, DType> y = out_data[rnn_enum::kOut].get<gpu, 3, DType>(s);
Tensor<gpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<gpu, 3, DType>(s);
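// cudnnRNNBackwardWeights accumulates into dw, so clear it first unless the
// request explicitly asks for accumulation (kAddTo).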
if (req[rnn_enum::kParams] != kAddTo) {
  dw = mshadow::expr::ScalarExp<DType>(0.0f);
}
// only need the kStateOut gradient when state_outputs is true
void * dhy_ptr = NULL;
if (param_.state_outputs)
  dhy_ptr = out_grad[rnn_enum::kStateOut].get<gpu, 3, DType>(s).dptr_;
// Deal with the LSTM cell state
void * dcx_ptr = NULL;
void * dcy_ptr = NULL;
void * cx_ptr = NULL;
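// These pointers stay NULL for non-LSTM modes, and dcy stays NULL when no
// cell-state output gradient is supplied; cuDNN treats a NULL gradient as zero.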
if (param_.mode == rnn_enum::kLstm) {
  CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
  cx_ptr = (in_data[rnn_enum::kStateCell].get<gpu, 3, DType>(s)).dptr_;
  dcx_ptr = (in_grad[rnn_enum::kStateCell].get<gpu, 3, DType>(s)).dptr_;
}
if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
  dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<gpu, 3, DType>(s)).dptr_;
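// The buffers handed to cuDNN below are assumed densely packed, hence the
// contiguity checks.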
CHECK_EQ(x.CheckContiguous(), true);
CHECK_EQ(w.CheckContiguous(), true);
CHECK_EQ(dw.CheckContiguous(), true);
CHECK_EQ(hx.CheckContiguous(), true);
CHECK_EQ(dhx.CheckContiguous(), true);
CHECK_EQ(y.CheckContiguous(), true);
CHECK_EQ(dy.CheckContiguous(), true);
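// Descriptors are normally created by Forward; set them up here as a fallback
// if Backward is the first call to reach this operator.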
if (!init_cudnn_) {
  Init(s, in_data, out_data);
}
// Get temp space (the cuDNN workspace shared by both backward calls)
int temp_size = workspace_size_;
Tensor<gpu, 1, DType> temp_space =
    ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
        mshadow::Shape1(temp_size), s);
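// Propagate gradients through the RNN: computes dx, dhx (and dcx for LSTM)
// from dy, dhy and dcy, reusing the reserve space written by the forward pass.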
CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
                                rnn_desc_,
                                param_.seq_length_,
                                y_desc_vec_.data(),
                                y.dptr_,
                                dy_desc_vec_.data(),
                                dy.dptr_,
                                dhy_desc_,
                                dhy_ptr,
                                dcy_desc_,
                                dcy_ptr,
                                w_desc_,
                                w.dptr_,
                                hx_desc_,
                                hx.dptr_,
                                cx_desc_,
                                cx_ptr,
                                dx_desc_vec_.data(),
                                dx.dptr_,
                                dhx_desc_,
                                dhx.dptr_,
                                dcx_desc_,
                                dcx_ptr,
                                temp_space.dptr_,
                                workspace_byte_,
                                reserve_space_.dptr,
                                reserve_space_byte_));
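// Accumulate the weight gradients into dw using the same workspace and reserve
// space; this is why dw was zeroed above unless kAddTo was requested.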
CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
                                   rnn_desc_,
                                   param_.seq_length_,
                                   x_desc_vec_.data(),
                                   x.dptr_,
                                   hx_desc_,
                                   hx.dptr_,
                                   y_desc_vec_.data(),
                                   y.dptr_,
                                   temp_space.dptr_,
                                   workspace_byte_,
                                   dw_desc_,
                                   dw.dptr_,
                                   reserve_space_.dptr,
                                   reserve_space_byte_));
}