vector<Tensor> GpuRNNForwardTraining()

in src/model/operation/rnn.cc [243:315]


// Runs the cuDNN RNN forward pass in training mode (cudnnRNNForwardTraining).
//
// Parameters:
//   x  : input sequence laid out as {seq, bs, feature}; x.shape(2) must equal
//        h.feature_size.
//   hx : initial hidden state, cx : initial cell state.
//   W  : weight tensor matching h.wDesc.
//   h  : cuDNN RNN handle; its batch_size / seq_length are overwritten from x,
//        and its workspace / reserve_space are zeroed and then written by cuDNN.
//
// Returns {y, hy, cy} with shapes
//   y  : {seq_length, batch_size, hidden_size * num_dirs}
//   hy : {num_layers * num_dirs, batch_size, hidden_size}
//   cy : {num_layers * num_dirs, batch_size, hidden_size}
// where num_dirs is 2 when h.bidirectional is set and 1 otherwise.
vector<Tensor> GpuRNNForwardTraining(const Tensor &x, const Tensor &hx,
                                     const Tensor &cx, const Tensor &W,
                                     CudnnRNNHandle &h) {
  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";

  // in:  x in shape {seq, bs, ..}
  // out: y in shape {seq, bs, ..}

  // Update batch size and sequence length to accommodate changes between
  // calls.
  h.batch_size = x.shape(1);
  h.seq_length = x.shape(0);

  const auto num_dirs = h.bidirectional ? 2 : 1;
  Tensor y(Shape{h.seq_length, h.batch_size, h.hidden_size * num_dirs},
           x.device());
  Tensor hy(Shape{h.num_layers * num_dirs, h.batch_size, h.hidden_size},
            x.device());
  Tensor cy(Shape{h.num_layers * num_dirs, h.batch_size, h.hidden_size},
            x.device());
  y.SetValue(0.0f);
  hy.SetValue(0.0f);
  cy.SetValue(0.0f);
  h.workspace.SetValue(0.0f);
  h.reserve_space.SetValue(0.0f);

  // All tensors, W included, are captured by value so the closure holds its
  // own handles to the underlying blocks.
  // NOTE(review): Exec may defer the closure (e.g. buffered/graph mode);
  // a by-reference capture of the caller's W could then dangle.
  // h is captured by reference on purpose, because the closure must use the
  // caller's handle (rnnDesc, wDesc, workspace, reserve space, sizes).
  y.device()->Exec(
      [y, hy, cy, x, hx, cx, W, &h](Context *ctx) {
        // require desc, [x], hx, cx, w, y, hy, cy
        // One descriptor per time step for x and y.
        std::vector<cudnnTensorDescriptor_t> xDesc(
            static_cast<size_t>(h.seq_length));
        std::vector<cudnnTensorDescriptor_t> yDesc(
            static_cast<size_t>(h.seq_length));
        init_xDesc(xDesc.data(), h);
        init_yDesc(yDesc.data(), h);
        cudnnTensorDescriptor_t hxDesc;
        cudnnTensorDescriptor_t cxDesc;
        cudnnTensorDescriptor_t hyDesc;
        cudnnTensorDescriptor_t cyDesc;
        init_hc_Desc(hxDesc, h);
        init_hc_Desc(cxDesc, h);
        init_hc_Desc(hyDesc, h);
        init_hc_Desc(cyDesc, h);

        auto x_con = Contiguous(x);

        auto xptr = x_con.block()->data();
        auto hxptr = hx.block()->data();
        auto cxptr = cx.block()->data();
        auto Wptr = W.block()->data();
        auto yptr = y.block()->mutable_data();
        auto hyptr = hy.block()->mutable_data();
        auto cyptr = cy.block()->mutable_data();
        auto wsptr = h.workspace.block()->mutable_data();
        auto rsptr = h.reserve_space.block()->mutable_data();
        CUDNN_CHECK(cudnnRNNForwardTraining(
            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc.data(), xptr,
            hxDesc, hxptr, cxDesc, cxptr, h.wDesc, Wptr, yDesc.data(), yptr,
            hyDesc, hyptr, cyDesc, cyptr, wsptr, h.workspace_size_bytes,
            rsptr, h.reserve_size_bytes));

        // The handles above start uninitialised and are made valid by the
        // init_* helpers, i.e. those helpers create the descriptors. Release
        // them here; they were previously leaked on every call.
        for (auto &d : xDesc) {
          CUDNN_CHECK(cudnnDestroyTensorDescriptor(d));
        }
        for (auto &d : yDesc) {
          CUDNN_CHECK(cudnnDestroyTensorDescriptor(d));
        }
        for (auto d : {hxDesc, cxDesc, hyDesc, cyDesc}) {
          CUDNN_CHECK(cudnnDestroyTensorDescriptor(d));
        }
      },
      {x.block(), hx.block(), cx.block(), W.block()},
      {y.block(), hy.block(), cy.block(), h.workspace.block(),
       h.reserve_space.block()},
      "cudnnRNNForwardTraining");

  return {y, hy, cy};
}