void Backward()

in src/operator/rnn-inl.h [1117:1363]


  /*!
   * \brief Run the backward pass of the RNN operator.
   *
   * Validates the argument counts and write requests, binds the input/output
   * tensors, optionally clears the parameter gradient, and then dispatches to
   * the cuDNN backward calls (when compiled with MXNET_USE_CUDNN == 1 under
   * __CUDACC__) and/or to RNNBackward<DType> (when ctx_.dev_type == kCPU).
   *
   * \param ctx operator context; supplies the stream and the requested temp space
   * \param out_grad gradients flowing in for each output (kOut, plus kStateOut and,
   *        for LSTM, kStateCellOut when param_.state_outputs is set)
   * \param in_data forward inputs (kData, kParams, kState, plus kStateCell for LSTM)
   * \param out_data forward outputs; only kOut is read here
   * \param req write request per input; kAddTo is rejected for data, state and
   *        state cell, and is the only request that keeps the existing contents
   *        of the parameter gradient
   * \param in_grad destinations for the gradients of each input
   */
  void Backward(const OpContext& ctx,
                const std::vector<TBlob>& out_grad,
                const std::vector<TBlob>& in_data,
                const std::vector<TBlob>& out_data,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& in_grad) {
    using namespace mshadow;
    using namespace mshadow::expr;
    // Dropout probability must lie in [0, 1).
    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
        << "unsupported dropout value, should be 0 <= dropout < 1";

    size_t num_inputs = GetRnnNumInputs(param_);

    //  kOut
    size_t num_outputs = 1;
    if (param_.state_outputs) {
      // kOut, kStateOut, kStateCellOut
      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
    }

    // Every input needs a matching gradient slot and write request, and every
    // output needs a matching incoming gradient.
    CHECK_EQ(in_data.size(), num_inputs);
    CHECK_EQ(out_data.size(), num_outputs);
    CHECK_EQ(in_grad.size(), num_inputs);
    CHECK_EQ(out_grad.size(), num_outputs);
    CHECK_EQ(req.size(), num_inputs);
    CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
    CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
    Stream<xpu>* s = ctx.get_stream<xpu>();
    // get input + output tensors
    // Naming: a leading "d" marks the gradient counterpart of a tensor
    // (x/dx data, w/dw flattened parameters, hx/dhx state, y/dy output).
    Tensor<xpu, 3, DType> x   = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> dx  = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
    Tensor<xpu, 1, DType> w   = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
    Tensor<xpu, 1, DType> dw  = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
    Tensor<xpu, 3, DType> hx  = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> y   = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> dy  = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);

    // Raw dptr_ pointers are handed to the backends below, so every tensor
    // must be contiguous.
    CHECK_EQ(x.CheckContiguous(), true);
    CHECK_EQ(w.CheckContiguous(), true);
    CHECK_EQ(dw.CheckContiguous(), true);
    CHECK_EQ(hx.CheckContiguous(), true);
    CHECK_EQ(dhx.CheckContiguous(), true);
    CHECK_EQ(y.CheckContiguous(), true);
    CHECK_EQ(dy.CheckContiguous(), true);
    CHECK_EQ(dx.CheckContiguous(), true);

    // Clear the parameter gradient unless the caller asked to accumulate into
    // it. Note that this condition is also true for kNullOp.
    if (req[rnn_enum::kParams] != kAddTo) {
      dw = mshadow::expr::ScalarExp<DType>(0.0f);
    }

    // Record the three dimensions of the data tensor on param_.
    param_.seq_length_ = x.shape_[0];
    param_.batch_size_ = x.shape_[1];
    param_.input_size_ = x.shape_[2];

    const int direction = param_.bidirectional ? 2 : 1;
    const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);

    // Points at the last bsize elements of dw (the length is taken from w).
    // NOTE(review): presumably the bias gradients occupy the tail of the flat
    // parameter gradient -- confirm against the parameter layout.
    DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;

    // Gradient of the output state; stays null when state outputs are disabled.
    DType* dhy_ptr = nullptr;
    if (param_.state_outputs) {
      dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
    }

    // Cell-state pointers are only populated for LSTM; they stay null otherwise.
    DType* dcx_ptr = nullptr;
    DType* dcy_ptr = nullptr;
    DType* cx_ptr  = nullptr;

    if (param_.mode == rnn_enum::kLstm) {
      CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
      cx_ptr  = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
    }
    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) {
      dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
    }

#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
    // Run Init here if it has not been run yet (init_cudnn_ is still false).
    if (!init_cudnn_) {
      Init(ctx, s, in_data, out_data);
    }

    // Get temp space
    // workspace_size_ elements of DType are requested from the kTempSpace
    // resource; the same buffer is passed to both backward calls below.
    int temp_size = workspace_size_;
    Tensor<gpu, 1, DType> temp_space =
        ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
            mshadow::Shape1(temp_size), s);

#if MXNET_USE_CUDNN_GE_7200
    // Backward-data pass: fills dx, dhx and (via dcx_ptr) the cell-state gradient.
    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
                                      rnn_desc_,
                                      y_data_desc_,
                                      y.dptr_,
                                      dy_data_desc_,
                                      dy.dptr_,
                                      nullptr,
                                      nullptr,
                                      dhy_desc_,
                                      dhy_ptr,
                                      dcy_desc_,
                                      dcy_ptr,
                                      w_desc_,
                                      w.dptr_,
                                      hx_desc_,
                                      hx.dptr_,
                                      cx_desc_,
                                      cx_ptr,
                                      dx_data_desc_,
                                      dx.dptr_,
                                      dhx_desc_,
                                      dhx.dptr_,
                                      dcx_desc_,
                                      dcx_ptr,
                                      nullptr,
                                      nullptr,
                                      temp_space.dptr_,
                                      workspace_byte_,
                                      reserve_space_.dptr,
                                      reserve_space_byte_));
    // NOTE(review): SyncDgrad is defined elsewhere; presumably it orders the
    // data-gradient work before the weight-gradient call -- confirm.
    SyncDgrad();
    // The weight-gradient pass is skipped when no parameter gradient is requested.
    if (req[rnn_enum::kParams] != kNullOp) {
      CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
                                           rnn_desc_,
                                           x_data_desc_,
                                           x.dptr_,
                                           hx_desc_,
                                           hx.dptr_,
                                           y_data_desc_,
                                           y.dptr_,
                                           temp_space.dptr_,
                                           workspace_byte_,
                                           dw_desc_,
                                           dw.dptr_,
                                           reserve_space_.dptr,
                                           reserve_space_byte_));
    }
#else
    // Non-Ex cuDNN API: takes the sequence length and per-step descriptor
    // arrays instead of the *_data_desc_ descriptors used above.
    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
                                    rnn_desc_,
                                    param_.seq_length_,
                                    y_desc_vec_.data(),
                                    y.dptr_,
                                    dy_desc_vec_.data(),
                                    dy.dptr_,
                                    dhy_desc_,
                                    dhy_ptr,
                                    dcy_desc_,
                                    dcy_ptr,
                                    w_desc_,
                                    w.dptr_,
                                    hx_desc_,
                                    hx.dptr_,
                                    cx_desc_,
                                    cx_ptr,
                                    dx_desc_vec_.data(),
                                    dx.dptr_,
                                    dhx_desc_,
                                    dhx.dptr_,
                                    dcx_desc_,
                                    dcx_ptr,
                                    temp_space.dptr_,
                                    workspace_byte_,
                                    reserve_space_.dptr,
                                    reserve_space_byte_));
    // NOTE(review): SyncDgrad is defined elsewhere; presumably it orders the
    // data-gradient work before the weight-gradient call -- confirm.
    SyncDgrad();
    // The weight-gradient pass is skipped when no parameter gradient is requested.
    if (req[rnn_enum::kParams] != kNullOp) {
      CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
                                         rnn_desc_,
                                         param_.seq_length_,
                                         x_desc_vec_.data(),
                                         x.dptr_,
                                         hx_desc_,
                                         hx.dptr_,
                                         y_desc_vec_.data(),
                                         y.dptr_,
                                         temp_space.dptr_,
                                         workspace_byte_,
                                         dw_desc_,
                                         dw.dptr_,
                                         reserve_space_.dptr,
                                         reserve_space_byte_));
    }
#endif  // MXNET_USE_CUDNN_GE_7200
#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

    // Native CPU path, selected by the device type stored in the ctx_ member.
    if (ctx_.dev_type == kCPU) {
      // Stays 0: a configured projection size aborts below, so the workspace
      // size is always computed without projection.
      int projection_size = 0;
      if (param_.projection_size.has_value()) {
        // TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
        // projection_size = param_.projection_size.value();
        LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
      }

      // allocate temp space
      // Nothing is allocated here: the expected workspace size is recomputed
      // and must match the buffer already recorded in temp_cpu_space_size_.
      // NOTE(review): presumably that buffer is set up by the forward pass -- confirm.
      const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_,
                                                             param_.batch_size_,
                                                             param_.state_size,
                                                             projection_size,
                                                             direction,
                                                             param_.mode);
      if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
        LOG(FATAL) << "Check temp init error";
      }
      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
      size_t r_size         = GetRNNReserveSpaceSize(param_.num_layers,
                                             direction,
                                             param_.seq_length_,
                                             param_.batch_size_,
                                             param_.state_size,
                                             param_.mode);

      // Likewise, the reserve space must already be initialized with the size
      // computed above; otherwise abort.
      if (!init_space_ || reserve_cpu_space_size_ != r_size) {
        LOG(FATAL) << "Check forward init error";
      }

      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
      RNNBackward<DType>(work_cpu_space,
                         reserve_space_ptr,
                         param_.num_layers,
                         direction,
                         param_.seq_length_,
                         param_.batch_size_,
                         param_.input_size_,
                         param_.state_size,
                         x.dptr_,
                         hx.dptr_,
                         cx_ptr,
                         w.dptr_,
                         y.dptr_,
                         dy.dptr_,
                         dhy_ptr,
                         dcy_ptr,
                         dx.dptr_,
                         dhx.dptr_,
                         dcx_ptr,
                         dw.dptr_,
                         db_ptr,
                         req[rnn_enum::kData],
                         req[rnn_enum::kParams],
                         req[rnn_enum::kState],
                         // State cell should be present for LSTMs, but is absent for other RNNs.
                         param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
                         param_.p,
                         param_.mode);
    }
  }