vector<Tensor> GpuRNNBackwardxEx()

in src/model/operation/rnn.cc [699:771]
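
Computes the data-gradient half of the cuDNN variable-sequence-length RNN backward pass via cudnnRNNBackwardDataEx: given the forward outputs y, their gradients dy, and the gradients of the final hidden and cell states (dhy, dcy), it returns the gradients with respect to the input sequence (dx) and the initial hidden and cell states (dhx, dcx).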


vector<Tensor> GpuRNNBackwardxEx(const Tensor &y, const Tensor &dy,
                                 const Tensor &dhy, const Tensor &dcy,
                                 const Tensor &W, const Tensor &hx,
                                 const Tensor &cx, const Tensor &seq_lengths,
                                 CudnnRNNHandle &h) {
  // y, dy shape: {bs, seq, hidden * (bidirectional ? 2 : 1)}
  // dx shape:    {bs, seq, feature}
  Shape xshape, states_shape;
  if (h.batch_first) {
    LOG(FATAL) << "batch_first not implemented for GpuRNNBackwardxEx";
  } else {
    xshape = Shape{h.batch_size, h.seq_length, h.feature_size};
    states_shape = Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
                         h.hidden_size};
  }
  Tensor dx(xshape, y.device());
  Tensor dhx(states_shape, y.device());
  Tensor dcx(states_shape, y.device());

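  // zero-initialize the gradient outputs and the cuDNN workspace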
  dx.SetValue(0.0f);
  dhx.SetValue(0.0f);
  dcx.SetValue(0.0f);
  h.workspace.SetValue(0.0f);

  dx.device()->Exec(
      [dx, dhx, dcx, y, dy, dhy, dcy, &W, hx, cx, seq_lengths,
       &h](Context *ctx) {
        cudnnRNNDataDescriptor_t yDesc, dyDesc, dxDesc;
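        // variable-length sequence data descriptors: y/dy carry
        // hidden_size (doubled if bidirectional) features per step,
        // dx carries feature_size; valid lengths come from seq_lengths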
        init_data_desc(yDesc,
                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
                       seq_lengths, h);
        init_data_desc(dyDesc,
                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
                       seq_lengths, h);
        init_data_desc(dxDesc, h.feature_size, seq_lengths, h);

        /* descriptors for the hidden and cell state tensors */
        cudnnTensorDescriptor_t hxDesc, cxDesc, dhxDesc, dcxDesc, dhyDesc,
            dcyDesc;
        init_hc_Desc(hxDesc, h);
        init_hc_Desc(cxDesc, h);
        init_hc_Desc(dhxDesc, h);
        init_hc_Desc(dcxDesc, h);
        init_hc_Desc(dhyDesc, h);
        init_hc_Desc(dcyDesc, h);

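        // raw device pointers handed to cuDNN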
        auto dxptr = dx.block()->mutable_data();
        auto hxptr = hx.block()->data();
        auto dhxptr = dhx.block()->mutable_data();
        auto cxptr = cx.block()->data();
        auto dcxptr = dcx.block()->mutable_data();
        auto Wptr = W.block()->data();
        auto yptr = y.block()->data();
        auto dyptr = dy.block()->data();
        auto dhyptr = dhy.block()->data();
        auto dcyptr = dcy.block()->data();
        auto wsptr = h.workspace.block()->mutable_data();
        auto rsptr = h.reserve_space.block()->mutable_data();

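        // the two NULL descriptor/pointer pairs are the reserved attention
        // arguments (dcDesc/dcAttn and dkDesc/dkeys) of the ...Ex API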
        CUDNN_CHECK(cudnnRNNBackwardDataEx(
            ctx->cudnn_handle, h.rnnDesc, yDesc, yptr, dyDesc, dyptr, NULL,
            NULL, dhyDesc, dhyptr, dcyDesc, dcyptr, h.wDesc, Wptr, hxDesc,
            hxptr, cxDesc, cxptr, dxDesc, dxptr, dhxDesc, dhxptr, dcxDesc,
            dcxptr, NULL, NULL, wsptr, h.workspace_size_bytes, rsptr,
            h.reserve_size_bytes));
      },
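      // blocks the lambda reads, followed by the blocks it writes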
      {y.block(), dy.block(), dhy.block(), dcy.block(), hx.block(), cx.block(),
       W.block()},
      {dx.block(), dhx.block(), dcx.block(), h.workspace.block(),
       h.reserve_space.block()});
  return {dx, dhx, dcx};
}
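
A minimal usage sketch (not from the source): it assumes the same CudnnRNNHandle h was just used for the corresponding forward training call, so h.workspace and h.reserve_space still hold the forward-pass state cuDNN needs here, and that all tensors live on the same CUDA device.

// Hypothetical call site; the variable names mirror the parameters above.
std::vector<Tensor> grads =
    GpuRNNBackwardxEx(y, dy, dhy, dcy, W, hx, cx, seq_lengths, h);
Tensor dx  = grads[0];  // gradient w.r.t. the input sequence x
Tensor dhx = grads[1];  // gradient w.r.t. the initial hidden state hx
Tensor dcx = grads[2];  // gradient w.r.t. the initial cell state cx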