in src/model/operation/rnn.cc [316:393]
// Data-gradient backward pass of a cuDNN RNN (wraps cudnnRNNBackwardData).
//
// Computes the gradients w.r.t. the input sequence (dx) and the initial
// hidden / cell states (dhx, dcx) on y's device.
//
// Params:
//   y, dy    - forward output and the gradient w.r.t. it; their layouts are
//              described per time step by init_yDesc. Both are made
//              contiguous before being handed to cuDNN.
//   dhy, dcy - gradients w.r.t. the final hidden / cell states.
//   W        - packed RNN weights (described by h.wDesc).
//   hx, cx   - initial hidden / cell states used in the forward pass.
//   h        - cuDNN RNN handle providing descriptors, sizes, workspace and
//              reserve space. It is captured by reference, so it must outlive
//              the execution of the submitted operation.
//              NOTE(review): presumably h.reserve_space must hold the data
//              written by the matching forward-training call — confirm.
// Returns {dx, dhx, dcx}, where
//   dx       : {seq_length, batch_size, feature_size}
//   dhx, dcx : {num_layers * (bidirectional ? 2 : 1), batch_size, hidden_size}
vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy,
                               const Tensor &dhy, const Tensor &dcy,
                               const Tensor &W, const Tensor &hx,
                               const Tensor &cx, CudnnRNNHandle &h) {
  const int num_directions = h.bidirectional ? 2 : 1;
  Tensor dx(Shape{h.seq_length, h.batch_size, h.feature_size}, y.device());
  Tensor dhx(Shape{h.num_layers * num_directions, h.batch_size, h.hidden_size},
             y.device());
  Tensor dcx(Shape{h.num_layers * num_directions, h.batch_size, h.hidden_size},
             y.device());
  dx.SetValue(0.0f);
  dhx.SetValue(0.0f);
  dcx.SetValue(0.0f);
  h.workspace.SetValue(0.0f);

  // Tensors are captured by value (cheap: copies share the underlying block)
  // so the closure stays valid even if Exec defers running it. The handle is
  // captured by reference because the caller owns it.
  dx.device()->Exec(
      [dx, dhx, dcx, y, dy, dhy, dcy, W, hx, cx, &h](Context *ctx) {
        // Per-time-step descriptors for the sequence tensors [dx], [y], [dy].
        // std::vector frees the arrays on every exit path (no raw new/delete).
        std::vector<cudnnTensorDescriptor_t> dxDesc(h.seq_length);
        std::vector<cudnnTensorDescriptor_t> yDesc(h.seq_length);
        std::vector<cudnnTensorDescriptor_t> dyDesc(h.seq_length);
        init_yDesc(yDesc.data(), h);
        init_xDesc(dxDesc.data(), h);
        init_yDesc(dyDesc.data(), h);

        // Descriptors for the hidden/cell state tensors; all share one layout.
        cudnnTensorDescriptor_t hxDesc;
        cudnnTensorDescriptor_t cxDesc;
        cudnnTensorDescriptor_t dhxDesc;
        cudnnTensorDescriptor_t dcxDesc;
        cudnnTensorDescriptor_t dhyDesc;
        cudnnTensorDescriptor_t dcyDesc;
        init_hc_Desc(hxDesc, h);
        init_hc_Desc(cxDesc, h);
        init_hc_Desc(dhxDesc, h);
        init_hc_Desc(dcxDesc, h);
        init_hc_Desc(dhyDesc, h);
        init_hc_Desc(dcyDesc, h);

        // cuDNN reads y/dy through packed per-step descriptors, so make sure
        // they are dense in memory.
        auto y_con = Contiguous(y);
        auto dy_con = Contiguous(dy);

        auto dxptr = dx.block()->mutable_data();
        auto hxptr = hx.block()->data();
        auto dhxptr = dhx.block()->mutable_data();
        auto cxptr = cx.block()->data();
        auto dcxptr = dcx.block()->mutable_data();
        auto Wptr = W.block()->data();
        auto yptr = y_con.block()->data();
        auto dyptr = dy_con.block()->data();
        auto dhyptr = dhy.block()->data();
        auto dcyptr = dcy.block()->data();
        auto wsptr = h.workspace.block()->mutable_data();
        auto rsptr = h.reserve_space.block()->mutable_data();

        CUDNN_CHECK(cudnnRNNBackwardData(
            ctx->cudnn_handle, h.rnnDesc, h.seq_length, yDesc.data(), yptr,
            dyDesc.data(), dyptr, dhyDesc, dhyptr, dcyDesc, dcyptr, h.wDesc,
            Wptr, hxDesc, hxptr, cxDesc, cxptr, dxDesc.data(), dxptr, dhxDesc,
            dhxptr, dcxDesc, dcxptr, wsptr, h.workspace_size_bytes, rsptr,
            h.reserve_size_bytes));

        // The init_* helpers hand back live descriptor handles. Release them
        // here; previously only the arrays holding them were freed, leaking
        // every descriptor on each backward call.
        auto destroy = [](cudnnTensorDescriptor_t desc) {
          CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
        };
        for (auto desc : dxDesc) destroy(desc);
        for (auto desc : yDesc) destroy(desc);
        for (auto desc : dyDesc) destroy(desc);
        for (auto desc :
             {hxDesc, cxDesc, dhxDesc, dcxDesc, dhyDesc, dcyDesc}) {
          destroy(desc);
        }
      },
      {y.block(), dy.block(), dhy.block(), dcy.block(), hx.block(), cx.block(),
       W.block()},
      {dx.block(), dhx.block(), dcx.block(), h.workspace.block(),
       h.reserve_space.block()},
      "cudnnRNNBackwardx");
  return {dx, dhx, dcx};
}