vector<Tensor> GpuRNNBackwardx()

in src/model/operation/rnn.cc [316:393]


// Runs cudnnRNNBackwardData for the RNN described by `h` and returns the
// gradients with respect to the input sequence and the initial hidden/cell
// states.
//
// Inputs:
//   y, dy     forward output and its gradient (made contiguous before use)
//   dhy, dcy  gradients of the final hidden / cell state
//   W         packed RNN weights (described by h.wDesc)
//   hx, cx    initial hidden / cell state
//   h         handle providing the RNN descriptor, sizes, workspace and
//             reserve space; both buffers are written by this call
// Returns {dx, dhx, dcx}; all three are allocated on y's device and
// zero-filled before the cuDNN call is submitted.
vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy,
                               const Tensor &dhy, const Tensor &dcy,
                               const Tensor &W, const Tensor &hx,
                               const Tensor &cx, CudnnRNNHandle &h) {
  // in
  // y shape {seq, bs}
  // dy shape {seq, bs}
  Tensor dx(Shape{h.seq_length, h.batch_size, h.feature_size}, y.device());
  Tensor dhx(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
                   h.hidden_size},
             y.device());
  Tensor dcx(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
                   h.hidden_size},
             y.device());
  dx.SetValue(0.0f);
  dhx.SetValue(0.0f);
  dcx.SetValue(0.0f);
  h.workspace.SetValue(0.0f);
  // W is captured by value, like every other tensor, so the closure does not
  // keep a reference to the caller's argument.
  // NOTE(review): this matters if Exec defers the closure past the return of
  // this function — confirm against Device::Exec. `h` is still captured by
  // reference and must therefore outlive the execution of the closure.
  dx.device()->Exec(
      [dx, dhx, dcx, y, dy, dhy, dcy, W, hx, cx, &h](Context *ctx) {
        // require desc:
        //      [dx], hx, dhx, cx, dcx, w,
        // [y], [dy],     dhy,     dcy
        // One descriptor per time step; vector owns the array storage so no
        // manual delete[] is needed.
        vector<cudnnTensorDescriptor_t> dxDesc(h.seq_length);
        vector<cudnnTensorDescriptor_t> yDesc(h.seq_length);
        vector<cudnnTensorDescriptor_t> dyDesc(h.seq_length);
        init_yDesc(yDesc.data(), h);
        init_xDesc(dxDesc.data(), h);
        init_yDesc(dyDesc.data(), h);
        cudnnTensorDescriptor_t hxDesc;
        cudnnTensorDescriptor_t cxDesc;
        cudnnTensorDescriptor_t dhxDesc;
        cudnnTensorDescriptor_t dcxDesc;
        cudnnTensorDescriptor_t dhyDesc;
        cudnnTensorDescriptor_t dcyDesc;
        init_hc_Desc(hxDesc, h);
        init_hc_Desc(cxDesc, h);
        init_hc_Desc(dhxDesc, h);
        init_hc_Desc(dcxDesc, h);
        init_hc_Desc(dhyDesc, h);
        init_hc_Desc(dcyDesc, h);

        // The data pointers for y and dy are taken from contiguous copies.
        auto y_con = Contiguous(y);
        auto dy_con = Contiguous(dy);

        auto dxptr = dx.block()->mutable_data();
        auto hxptr = hx.block()->data();
        auto dhxptr = dhx.block()->mutable_data();
        auto cxptr = cx.block()->data();
        auto dcxptr = dcx.block()->mutable_data();
        auto Wptr = W.block()->data();
        auto yptr = y_con.block()->data();
        auto dyptr = dy_con.block()->data();
        auto dhyptr = dhy.block()->data();
        auto dcyptr = dcy.block()->data();
        auto wsptr = h.workspace.block()->mutable_data();
        auto rsptr = h.reserve_space.block()->mutable_data();

        CUDNN_CHECK(cudnnRNNBackwardData(
            ctx->cudnn_handle, h.rnnDesc, h.seq_length, yDesc.data(), yptr,
            dyDesc.data(), dyptr, dhyDesc, dhyptr, dcyDesc, dcyptr, h.wDesc,
            Wptr, hxDesc, hxptr, cxDesc, cxptr, dxDesc.data(), dxptr, dhxDesc,
            dhxptr, dcxDesc, dcxptr, wsptr, h.workspace_size_bytes, rsptr,
            h.reserve_size_bytes));

        // Release the descriptors themselves, not just the arrays that held
        // them. The init_*Desc helpers receive uninitialised handles, so they
        // are presumed to create each descriptor (verify against the
        // helpers); without this every invocation leaks them.
        for (auto &desc : dxDesc) {
          CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
        }
        for (auto &desc : yDesc) {
          CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
        }
        for (auto &desc : dyDesc) {
          CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
        }
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(hxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(cxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhyDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcyDesc));
      },
      // Blocks read by the closure.
      {y.block(), dy.block(), dhy.block(), dcy.block(), hx.block(), cx.block(),
       W.block()},
      // Blocks written by the closure.
      {dx.block(), dhx.block(), dcx.block(), h.workspace.block(),
       h.reserve_space.block()},
      "cudnnRNNBackwardx");
  return {dx, dhx, dcx};
}