vector GpuRNNForwardInferenceEx()

in src/model/operation/rnn.cc [568:631]


/// Runs cuDNN RNN inference with the extended (Ex) API, which accepts
/// padded/unpacked sequence layout via per-sample `seq_lengths`.
///
/// @param x  input, shape (seq_length, batch_size, feature_size) when
///           `h.batch_first` is false (the only supported layout here).
/// @param hx initial hidden state, shape (num_layers * dirs, batch, hidden).
/// @param cx initial cell state, same shape as hx.
/// @param W  packed weight tensor matching `h.wDesc`.
/// @param seq_lengths per-sample valid lengths for the padded batch.
/// @param h  cuDNN RNN handle; seq_length/batch_size are updated from x.
/// @return {y, hy, cy}: output sequence and final hidden/cell states.
vector<Tensor> GpuRNNForwardInferenceEx(const Tensor &x, const Tensor &hx,
                                        const Tensor &cx, const Tensor &W,
                                        const Tensor &seq_lengths,
                                        CudnnRNNHandle &h) {
  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";

  Tensor y, hy, cy;
  Shape yshape, states_shape;

  if (h.batch_first) {
    // Fixed message: it previously named GpuRNNForwardTrainingEx (copy-paste
    // from the training variant), which would mislead users of this function.
    LOG(FATAL) << "batch_first not implemented for GpuRNNForwardInferenceEx";
  } else {
    // Layout is (seq, batch, feature); refresh handle dims from the input.
    h.seq_length = x.shape(0);
    h.batch_size = x.shape(1);
    // Bidirectional RNNs concatenate both directions' outputs.
    yshape = Shape{h.seq_length, h.batch_size,
                   h.hidden_size * (h.bidirectional ? 2 : 1)};
    states_shape = Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
                         h.hidden_size};
  }

  y = Tensor(yshape, x.device());
  hy = Tensor(states_shape, x.device());
  cy = Tensor(states_shape, x.device());

  y.device()->Exec(
      [y, hy, cy, x, seq_lengths, hx, cx, &W, &h](Context *ctx) {
        // Data descriptors for the padded-sequence (Ex) API.
        cudnnRNNDataDescriptor_t xDesc, yDesc;
        init_data_desc(xDesc, h.feature_size, seq_lengths, h);
        init_data_desc(yDesc,
                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
                       seq_lengths, h);

        // Hidden/cell state descriptors (identical shapes).
        cudnnTensorDescriptor_t hxDesc, cxDesc, hyDesc, cyDesc;
        init_hc_Desc(hxDesc, h);
        init_hc_Desc(cxDesc, h);
        init_hc_Desc(hyDesc, h);
        init_hc_Desc(cyDesc, h);

        auto xptr = x.block()->data();
        auto hxptr = hx.block()->data();
        auto cxptr = cx.block()->data();
        auto Wptr = W.block()->data();
        auto yptr = y.block()->mutable_data();
        auto hyptr = hy.block()->mutable_data();
        auto cyptr = cy.block()->mutable_data();
        auto wsptr = h.workspace.block()->mutable_data();

        /* This routine is the inference counterpart of
        cudnnRNNForwardTrainingEx(). The Ex variants allow the user to use
        unpacked (padded) layout for input x and output y. Inference needs no
        reserve space, hence the trailing NULL arguments.
        */
        CUDNN_CHECK(cudnnRNNForwardInferenceEx(
            ctx->cudnn_handle, h.rnnDesc, xDesc, xptr, hxDesc, hxptr, cxDesc,
            cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr, cyDesc, cyptr,
            NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, wsptr,
            h.workspace_size_bytes));

        // Destroy the descriptors created above; the previous version leaked
        // them on every call.
        CUDNN_CHECK(cudnnDestroyRNNDataDescriptor(xDesc));
        CUDNN_CHECK(cudnnDestroyRNNDataDescriptor(yDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(hxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(cxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(hyDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(cyDesc));
      },
      {x.block(), hx.block(), cx.block(), W.block()},
      // reserve_space is NOT written here (inference passes NULL for it);
      // listing it previously created a spurious scheduling dependency.
      {y.block(), hy.block(), cy.block(), h.workspace.block()});
  return {y, hy, cy};
}