in src/model/operation/rnn.cc [174:241]
/// Runs cuDNN RNN forward inference (cudnnRNNForwardInference) on x.
///
/// @param x   input tensor laid out as {seq, bs, ...}; x.shape(2) must equal
///            h.feature_size (checked below).
/// @param hx  initial hidden state, passed to cuDNN with a descriptor built
///            by init_hc_Desc.
/// @param cx  initial cell state, passed to cuDNN with a descriptor built
///            by init_hc_Desc.
/// @param W   weights, described to cuDNN by h.wDesc.
/// @param h   RNN handle. Side effects: h.batch_size and h.seq_length are
///            overwritten from x's shape, and h.workspace is zeroed and then
///            used as the cuDNN workspace.
/// @return {y, hy, cy}, where
///         y  has shape {h.seq_length, h.batch_size,
///                       h.hidden_size * num_directions}, and
///         hy, cy have shape {h.num_layers * num_directions, h.batch_size,
///                            h.hidden_size};
///         num_directions is 2 when h.bidirectional is set, otherwise 1.
vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
                                      const Tensor &cx, const Tensor &W,
                                      CudnnRNNHandle &h) {
  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
  // in:  x in shape {seq, bs, ..}
  // out: y in shape {seq, bs, ..}
  // Refresh the handle from x so that batch size / sequence length may
  // change between calls (to accommodate a different bs).
  h.batch_size = x.shape(1);
  h.seq_length = x.shape(0);

  const auto num_directions = h.bidirectional ? 2 : 1;
  Tensor y(Shape{h.seq_length, h.batch_size, h.hidden_size * num_directions},
           x.device());
  Tensor hy(Shape{h.num_layers * num_directions, h.batch_size, h.hidden_size},
            x.device());
  Tensor cy(Shape{h.num_layers * num_directions, h.batch_size, h.hidden_size},
            x.device());
  y.SetValue(0.0f);
  hy.SetValue(0.0f);
  cy.SetValue(0.0f);
  h.workspace.SetValue(0.0f);

  // All tensors, W included, are captured by value. The closure therefore
  // does not depend on the lifetime of the caller's W, and it reads the same
  // block that is declared as a read dependency below, even if the device
  // runs the closure later. h is captured by reference because its
  // descriptors and workspace belong to the caller-owned handle.
  y.device()->Exec(
      [y, hy, cy, x, hx, cx, W, &h](Context *ctx) {
        // cuDNN needs one descriptor per time step for x and y; std::vector
        // owns the arrays so no manual delete[] is required.
        std::vector<cudnnTensorDescriptor_t> xDesc(
            static_cast<size_t>(h.seq_length));
        std::vector<cudnnTensorDescriptor_t> yDesc(
            static_cast<size_t>(h.seq_length));
        init_xDesc(xDesc.data(), h);
        init_yDesc(yDesc.data(), h);
        cudnnTensorDescriptor_t hxDesc;
        cudnnTensorDescriptor_t cxDesc;
        cudnnTensorDescriptor_t hyDesc;
        cudnnTensorDescriptor_t cyDesc;
        init_hc_Desc(hxDesc, h);
        init_hc_Desc(cxDesc, h);
        init_hc_Desc(hyDesc, h);
        init_hc_Desc(cyDesc, h);
        // NOTE(review): the init_* helpers are not visible here. If they
        // create descriptors via cudnnCreateTensorDescriptor, nothing in this
        // function destroys them; confirm and add
        // cudnnDestroyTensorDescriptor calls if so.

        // cuDNN expects a contiguous input buffer.
        auto x_con = Contiguous(x);
        auto xptr = x_con.block()->data();
        auto hxptr = hx.block()->data();
        auto cxptr = cx.block()->data();
        auto Wptr = W.block()->data();
        auto yptr = y.block()->mutable_data();
        auto hyptr = hy.block()->mutable_data();
        auto cyptr = cy.block()->mutable_data();
        auto wsptr = h.workspace.block()->mutable_data();
        CUDNN_CHECK(cudnnRNNForwardInference(
            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc.data(), xptr,
            hxDesc, hxptr, cxDesc, cxptr, h.wDesc, Wptr, yDesc.data(), yptr,
            hyDesc, hyptr, cyDesc, cyptr, wsptr, h.workspace_size_bytes));
      },
      {x.block(), hx.block(), cx.block(), W.block()},
      {y.block(), hy.block(), cy.block(), h.workspace.block()},
      // NOTE(review): op name looks like a typo for "...Inference"; kept
      // as-is because it is the runtime name of this operation.
      "cudnnRNNForwardInterface");
  return {y, hy, cy};
}