in src/operator/rnn-inl.h [1117:1363]
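  // Backward pass for the RNN operator: computes dx, dw, and dhx (plus dcx for
  // LSTM) from dy and the optional state-output gradients. Delegates to cuDNN
  // on GPU and to the native implementation on CPU.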
  void Backward(const OpContext& ctx,
                const std::vector<TBlob>& out_grad,
                const std::vector<TBlob>& in_data,
                const std::vector<TBlob>& out_data,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& in_grad) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
        << "unsupported dropout value, should be 0 <= dropout < 1";
    size_t num_inputs = GetRnnNumInputs(param_);
    // kOut
    size_t num_outputs = 1;
    if (param_.state_outputs) {
      // kOut, kStateOut, kStateCellOut
      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
    }
    CHECK_EQ(in_data.size(), num_inputs);
    CHECK_EQ(out_data.size(), num_outputs);
    CHECK_EQ(in_grad.size(), num_inputs);
    CHECK_EQ(out_grad.size(), num_outputs);
    CHECK_EQ(req.size(), num_inputs);
    CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
    CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
    Stream<xpu>* s = ctx.get_stream<xpu>();
    // get input + output tensors
    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
    CHECK_EQ(x.CheckContiguous(), true);
    CHECK_EQ(w.CheckContiguous(), true);
    CHECK_EQ(dw.CheckContiguous(), true);
    CHECK_EQ(hx.CheckContiguous(), true);
    CHECK_EQ(dhx.CheckContiguous(), true);
    CHECK_EQ(y.CheckContiguous(), true);
    CHECK_EQ(dy.CheckContiguous(), true);
    CHECK_EQ(dx.CheckContiguous(), true);
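    // Start from zeroed weight gradients unless the caller asked to accumulate.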
    if (req[rnn_enum::kParams] != kAddTo) {
      dw = mshadow::expr::ScalarExp<DType>(0.0f);
    }
    param_.seq_length_ = x.shape_[0];
    param_.batch_size_ = x.shape_[1];
    param_.input_size_ = x.shape_[2];
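    // Bias gradients occupy the tail of the flat parameter blob.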
    const int direction = param_.bidirectional ? 2 : 1;
    const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
    DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
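    // Gradients w.r.t. the emitted final states exist only when state_outputs
    // is set; the cell-state pointers are meaningful only for LSTM.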
    DType* dhy_ptr = nullptr;
    if (param_.state_outputs) {
      dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
    }
    DType* dcx_ptr = nullptr;
    DType* dcy_ptr = nullptr;
    DType* cx_ptr = nullptr;
    if (param_.mode == rnn_enum::kLstm) {
      CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
    }
    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) {
      dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
    }
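    // GPU path: delegate the whole backward computation to cuDNN.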
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
    if (!init_cudnn_) {
      Init(ctx, s, in_data, out_data);
    }
    // Get temp space (size_t avoids narrowing the workspace size)
    size_t temp_size = workspace_size_;
    Tensor<gpu, 1, DType> temp_space =
        ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
            mshadow::Shape1(temp_size), s);
#if MXNET_USE_CUDNN_GE_7200
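    // cuDNN >= 7.2: the *Ex entry points take packed RNN data descriptors,
    // which also cover variable-length sequences. The nullptr pairs below are
    // the unused attention arguments.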
    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
                                      rnn_desc_,
                                      y_data_desc_,
                                      y.dptr_,
                                      dy_data_desc_,
                                      dy.dptr_,
                                      nullptr,  // dcDesc: attention, unused
                                      nullptr,  // dcAttn: attention, unused
                                      dhy_desc_,
                                      dhy_ptr,
                                      dcy_desc_,
                                      dcy_ptr,
                                      w_desc_,
                                      w.dptr_,
                                      hx_desc_,
                                      hx.dptr_,
                                      cx_desc_,
                                      cx_ptr,
                                      dx_data_desc_,
                                      dx.dptr_,
                                      dhx_desc_,
                                      dhx.dptr_,
                                      dcx_desc_,
                                      dcx_ptr,
                                      nullptr,  // dkDesc: attention keys, unused
                                      nullptr,  // dkeys: attention keys, unused
                                      temp_space.dptr_,
                                      workspace_byte_,
                                      reserve_space_.dptr,
                                      reserve_space_byte_));
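    // Ensure dgrad (BackwardData) has completed before wgrad is launched;
    // guards against a dgrad/wgrad race observed with some cuDNN versions.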
    SyncDgrad();
    if (req[rnn_enum::kParams] != kNullOp) {
      CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
                                           rnn_desc_,
                                           x_data_desc_,
                                           x.dptr_,
                                           hx_desc_,
                                           hx.dptr_,
                                           y_data_desc_,
                                           y.dptr_,
                                           temp_space.dptr_,
                                           workspace_byte_,
                                           dw_desc_,
                                           dw.dptr_,
                                           reserve_space_.dptr,
                                           reserve_space_byte_));
    }
#else
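    // Pre-7.2 cuDNN: legacy API with one tensor descriptor per time step.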
    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
                                    rnn_desc_,
                                    param_.seq_length_,
                                    y_desc_vec_.data(),
                                    y.dptr_,
                                    dy_desc_vec_.data(),
                                    dy.dptr_,
                                    dhy_desc_,
                                    dhy_ptr,
                                    dcy_desc_,
                                    dcy_ptr,
                                    w_desc_,
                                    w.dptr_,
                                    hx_desc_,
                                    hx.dptr_,
                                    cx_desc_,
                                    cx_ptr,
                                    dx_desc_vec_.data(),
                                    dx.dptr_,
                                    dhx_desc_,
                                    dhx.dptr_,
                                    dcx_desc_,
                                    dcx_ptr,
                                    temp_space.dptr_,
                                    workspace_byte_,
                                    reserve_space_.dptr,
                                    reserve_space_byte_));
    SyncDgrad();
    if (req[rnn_enum::kParams] != kNullOp) {
      CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
                                         rnn_desc_,
                                         param_.seq_length_,
                                         x_desc_vec_.data(),
                                         x.dptr_,
                                         hx_desc_,
                                         hx.dptr_,
                                         y_desc_vec_.data(),
                                         y.dptr_,
                                         temp_space.dptr_,
                                         workspace_byte_,
                                         dw_desc_,
                                         dw.dptr_,
                                         reserve_space_.dptr,
                                         reserve_space_byte_));
    }
#endif // MXNET_USE_CUDNN_GE_7200
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
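    // CPU path: use the native RNN backward implementation.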
    if (ctx_.dev_type == kCPU) {
      int projection_size = 0;
      if (param_.projection_size.has_value()) {
        // TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
        // projection_size = param_.projection_size.value();
        LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
      }
      // The workspace and reserve space must have been allocated (and sized)
      // by the forward pass; bail out loudly if they were not.
      const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_,
                                                             param_.batch_size_,
                                                             param_.state_size,
                                                             projection_size,
                                                             direction,
                                                             param_.mode);
      if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
        LOG(FATAL) << "Workspace was not initialized by the forward pass or its size has changed.";
      }
      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
      size_t r_size = GetRNNReserveSpaceSize(param_.num_layers,
                                             direction,
                                             param_.seq_length_,
                                             param_.batch_size_,
                                             param_.state_size,
                                             param_.mode);
      if (!init_space_ || reserve_cpu_space_size_ != r_size) {
        LOG(FATAL) << "Reserve space was not initialized by the forward pass or its size has changed.";
      }
      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
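      // Dispatch to the mode-specific backward kernel (vanilla RNN, GRU, LSTM).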
      RNNBackward<DType>(work_cpu_space,
                         reserve_space_ptr,
                         param_.num_layers,
                         direction,
                         param_.seq_length_,
                         param_.batch_size_,
                         param_.input_size_,
                         param_.state_size,
                         x.dptr_,
                         hx.dptr_,
                         cx_ptr,
                         w.dptr_,
                         y.dptr_,
                         dy.dptr_,
                         dhy_ptr,
                         dcy_ptr,
                         dx.dptr_,
                         dhx.dptr_,
                         dcx_ptr,
                         dw.dptr_,
                         db_ptr,
                         req[rnn_enum::kData],
                         req[rnn_enum::kParams],
                         req[rnn_enum::kState],
                         // State cell should be present for LSTMs, but is absent for other RNNs.
                         param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
                         param_.p,
                         param_.mode);
    }
  }