in src/model/operation/rnn.cc [568:631]
vector<Tensor> GpuRNNForwardInferenceEx(const Tensor &x, const Tensor &hx,
const Tensor &cx, const Tensor &W,
const Tensor &seq_lengths,
CudnnRNNHandle &h) {
CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
Tensor y, hy, cy;
Shape yshape, states_shape;
if (h.batch_first) {
LOG(FATAL) << "batch_first not implemented for GpuRNNForwardInferenceEx";
} else {
h.seq_length = x.shape(0);
h.batch_size = x.shape(1);
yshape = Shape{h.seq_length, h.batch_size,
h.hidden_size * (h.bidirectional ? 2 : 1)};
states_shape = Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
h.hidden_size};
}
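// With batch_first disabled, x is laid out as (seq_length, batch_size,
// feature_size); y keeps that layout with the hidden dimension doubled for
// bidirectional RNNs, and the state tensors are
// (num_layers * num_directions, batch_size, hidden_size).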
y = Tensor(yshape, x.device());
hy = Tensor(states_shape, x.device());
cy = Tensor(states_shape, x.device());
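// Submit the cuDNN call to the device. The two block lists passed to Exec
// below declare the read and write dependencies the scheduler needs.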
y.device()->Exec(
// capture W by value so its Block stays alive if the op runs deferred
[y, hy, cy, x, seq_lengths, hx, cx, W, &h](Context *ctx) {
// data descriptor
cudnnRNNDataDescriptor_t xDesc, yDesc;
init_data_desc(xDesc, h.feature_size, seq_lengths, h);
init_data_desc(yDesc,
h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
seq_lengths, h);
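// init_data_desc is assumed to wrap cudnnSetRNNDataDescriptor, registering
// the per-sample lengths from seq_lengths; that is what lets the *Ex entry
// point consume the padded (unpacked) layout directly.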
// hidden cell states descriptor
cudnnTensorDescriptor_t hxDesc, cxDesc, hyDesc, cyDesc;
init_hc_Desc(hxDesc, h);
init_hc_Desc(cxDesc, h);
init_hc_Desc(hyDesc, h);
init_hc_Desc(cyDesc, h);
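// init_hc_Desc is assumed to encode the shared 3D state shape
// (num_layers * num_directions, batch_size, hidden_size) used for all of
// hx/cx/hy/cy.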
auto xptr = x.block()->data();
auto hxptr = hx.block()->data();
auto cxptr = cx.block()->data();
auto Wptr = W.block()->data();
auto yptr = y.block()->mutable_data();
auto hyptr = hy.block()->mutable_data();
auto cyptr = cy.block()->mutable_data();
auto wsptr = h.workspace.block()->mutable_data();
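// h.workspace_size_bytes is assumed to have been sized when the handle was
// configured (e.g. via cudnnGetRNNWorkspaceSize) for the current
// seq_length and batch_size.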
/* This routine is the extended version of the cudnnRNNForwardInference()
function. cudnnRNNForwardInferenceEx() allows the user to use an
unpacked (padded) layout for the input x and output y.
*/
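// The eight NULLs are the reserved attention-related arguments
// (kDesc/keys, cDesc/cAttn, iDesc/iAttn, qDesc/queries), which cuDNN
// requires to be NULL in current releases.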
CUDNN_CHECK(cudnnRNNForwardInferenceEx(
ctx->cudnn_handle, h.rnnDesc, xDesc, xptr, hxDesc, hxptr, cxDesc,
cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr, cyDesc, cyptr,
NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, wsptr,
h.workspace_size_bytes));
},
{x.block(), hx.block(), cx.block(), W.block()},
{y.block(), hy.block(), cy.block(), h.workspace.block()});
return {y, hy, cy};
}
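/* A minimal usage sketch (hedged): the shapes follow from the checks and
   Shape construction above, but the variable names and the CudnnRNNHandle
   setup are assumptions; the handle is configured elsewhere.

   Tensor x(Shape{seq_len, batch, feat}, dev);    // padded, seq-major input
   Tensor hx(Shape{layers * dirs, batch, hidden}, dev);
   Tensor cx(Shape{layers * dirs, batch, hidden}, dev);
   Tensor seq_lengths(Shape{batch}, dev);         // valid length per sample
   auto out = GpuRNNForwardInferenceEx(x, hx, cx, W, seq_lengths, h);
   // out[0] = y, out[1] = hy, out[2] = cy
*/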