in src/model/operation/rnn.cc [699:771]
/// Computes the data gradients (dx, dhx, dcx) of a cuDNN RNN with the
/// cudnnRNNBackwardDataEx API, which accepts per-sample sequence lengths.
///
/// Only the non-batch_first layout is implemented; calling this with
/// h.batch_first set aborts through LOG(FATAL).
///
/// @param y           output of the forward training pass
/// @param dy          gradient of the loss w.r.t. y
/// @param dhy         gradient w.r.t. the final hidden state
/// @param dcy         gradient w.r.t. the final cell state
/// @param W           RNN weights (passed to cuDNN together with h.wDesc)
/// @param hx          initial hidden state used in the forward pass
/// @param cx          initial cell state used in the forward pass
/// @param seq_lengths per-sample sequence lengths, used to build the RNN data
///                    descriptors
/// @param h           handle providing the RNN/weight descriptors, the
///                    workspace and the reserve space consumed by cuDNN
/// @return {dx, dhx, dcx}; dx has shape {seq_length, batch_size, feature_size}
///         and dhx/dcx have shape
///         {num_layers * (bidirectional ? 2 : 1), batch_size, hidden_size}.
vector<Tensor> GpuRNNBackwardxEx(const Tensor &y, const Tensor &dy,
                                 const Tensor &dhy, const Tensor &dcy,
                                 const Tensor &W, const Tensor &hx,
                                 const Tensor &cx, const Tensor &seq_lengths,
                                 CudnnRNNHandle &h) {
  // NOTE(review): y and dy are presumably
  // {seq_length, batch_size, hidden_size * (bidirectional ? 2 : 1)}, matching
  // the yDesc/dyDesc data descriptors built below -- confirm against caller.
  Shape xshape, states_shape;
  if (h.batch_first) {
    LOG(FATAL) << "batch_first not implemented for GpuRNNBackwardxEx";
  } else {
    // batch_first == false: the sequence axis leads and the batch axis
    // follows, so dx must not be laid out batch-first.
    xshape = Shape{h.seq_length, h.batch_size, h.feature_size};
    states_shape = Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
                         h.hidden_size};
  }
  Tensor dx(xshape, y.device());
  Tensor dhx(states_shape, y.device());
  Tensor dcx(states_shape, y.device());
  dx.SetValue(0.0f);
  dhx.SetValue(0.0f);
  dcx.SetValue(0.0f);
  h.workspace.SetValue(0.0f);
  dx.device()->Exec(
      // W is captured by value like every other tensor, so the closure never
      // holds a reference to the caller's argument (relevant if Exec defers
      // the call). h is captured by reference: its descriptors and buffers
      // are used in place.
      [dx, dhx, dcx, y, dy, dhy, dcy, W, hx, cx, seq_lengths,
       &h](Context *ctx) {
        // Feature size of y/dy: hidden_size per direction.
        const auto y_size =
            h.bidirectional ? h.hidden_size * 2 : h.hidden_size;

        /* RNN data descriptors for y, dy and dx */
        cudnnRNNDataDescriptor_t yDesc, dyDesc, dxDesc;
        init_data_desc(yDesc, y_size, seq_lengths, h);
        init_data_desc(dyDesc, y_size, seq_lengths, h);
        init_data_desc(dxDesc, h.feature_size, seq_lengths, h);

        /* hidden/cell state descriptors */
        cudnnTensorDescriptor_t hxDesc, cxDesc, dhxDesc, dcxDesc, dhyDesc,
            dcyDesc;
        init_hc_Desc(hxDesc, h);
        init_hc_Desc(cxDesc, h);
        init_hc_Desc(dhxDesc, h);
        init_hc_Desc(dcxDesc, h);
        init_hc_Desc(dhyDesc, h);
        init_hc_Desc(dcyDesc, h);

        auto dxptr = dx.block()->mutable_data();
        auto hxptr = hx.block()->data();
        auto dhxptr = dhx.block()->mutable_data();
        auto cxptr = cx.block()->data();
        auto dcxptr = dcx.block()->mutable_data();
        auto Wptr = W.block()->data();
        auto yptr = y.block()->data();
        auto dyptr = dy.block()->data();
        auto dhyptr = dhy.block()->data();
        auto dcyptr = dcy.block()->data();
        auto wsptr = h.workspace.block()->mutable_data();
        auto rsptr = h.reserve_space.block()->mutable_data();

        // The nullptr pairs are cuDNN's reserved dcDesc/dcAttn and
        // dkDesc/dkeys arguments, which must be null.
        CUDNN_CHECK(cudnnRNNBackwardDataEx(
            ctx->cudnn_handle, h.rnnDesc, yDesc, yptr, dyDesc, dyptr, nullptr,
            nullptr, dhyDesc, dhyptr, dcyDesc, dcyptr, h.wDesc, Wptr, hxDesc,
            hxptr, cxDesc, cxptr, dxDesc, dxptr, dhxDesc, dhxptr, dcxDesc,
            dcxptr, nullptr, nullptr, wsptr, h.workspace_size_bytes, rsptr,
            h.reserve_size_bytes));

        // Each descriptor above is a fresh, uninitialized local handed to an
        // init_* helper, which therefore has to create it. Destroy them once
        // cuDNN has consumed them so repeated backward passes do not leak
        // descriptor objects.
        CUDNN_CHECK(cudnnDestroyRNNDataDescriptor(yDesc));
        CUDNN_CHECK(cudnnDestroyRNNDataDescriptor(dyDesc));
        CUDNN_CHECK(cudnnDestroyRNNDataDescriptor(dxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(hxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(cxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcxDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhyDesc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcyDesc));
      },
      // blocks read by the closure
      {y.block(), dy.block(), dhy.block(), dcy.block(), hx.block(), cx.block(),
       W.block()},
      // blocks written by the closure (cuDNN's workspace and reserve space
      // are in/out buffers for the backward-data call)
      {dx.block(), dhx.block(), dcx.block(), h.workspace.block(),
       h.reserve_space.block()});
  return {dx, dhx, dcx};
}