in src/model/layer/cudnn_rnn.cc [258:365]
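// Runs the cudnn RNN forward pass. `inputs` holds one tensor per time step
// (x_0 ... x_{num_x-1}) followed by the initial hidden state hx and, for LSTM
// (has_cell_ == true), the initial cell state cx. The returned vector holds
// one output tensor per time step followed by hy and, for LSTM, cy.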
const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
DataType dtype = inputs.at(0).data_type();
auto dev = inputs.at(0).device();
  // copy the per-step input tensors into one contiguous block of memory;
  // hx (and, for LSTM, cx) come at the end of inputs
CHECK_GT(inputs.size(), 1u + has_cell_);
size_t num_x = inputs.size() - has_cell_ - 1;
Tensor input = MergeInputs(num_x, inputs);
// LOG(INFO) << "input size " << input.Size() << " value " << input.L1();
if (rnn_desc_ != nullptr)
CHECK_EQ(dtype_, GetCudnnDataType(dtype))
<< "Cannot change cudnn data type during training from " << dtype_
<< " to " << GetCudnnDataType(dtype);
else
dtype_ = GetCudnnDataType(dtype);
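  // refresh per-batch state (seq_length_, batch_size_, io/state descriptors,
  // workspace and reserve space) to match this mini-batch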
UpdateStates(num_x, inputs);
// CheckFowardShapes();
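  // y carries hidden_size_ * num_directions_ features per step, so its total
  // size is the input size scaled by hidden_size_ / input_size_ * num_directions_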
Shape outshape{input.Size() * hidden_size_ / input_size_ * num_directions_};
Tensor output(outshape, dev, dtype);
// LOG(INFO) << "output size " << output.Size();
Tensor hx = inputs.at(num_x);
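  // cudnn expects hidden/cell state tensors shaped
  // (num_layers * num_directions, batch, hidden_size)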
Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_};
Tensor hy(state_shape, dev, dtype);
Tensor cy, cx;
if (has_cell_) {
cx = inputs.at(num_x + 1);
cy.ResetLike(hy);
}
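  // every block handed to cudnn below must live on the same CUDA device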
int did = input.device()->id();
CHECK_EQ(did, output.device()->id());
if (hx.Size()) {
CHECK_EQ(did, hx.device()->id());
CHECK_EQ(hx.device()->lang(), kCuda);
}
if (cx.Size()) {
CHECK_EQ(did, cx.device()->id());
CHECK_EQ(cx.device()->lang(), kCuda);
}
CHECK_EQ(did, weight_.device()->id());
CHECK_EQ(did, workspace_.device()->id());
CHECK_EQ(input.device()->lang(), kCuda);
CHECK_EQ(output.device()->lang(), kCuda);
CHECK_EQ(weight_.device()->lang(), kCuda);
CHECK_EQ(workspace_.device()->lang(), kCuda);
// LOG(INFO) << "hidden size " << hy.Size();
// LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1();
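  // capture raw Block pointers for the async Exec lambdas; the read/write
  // block lists passed to Exec let the device track the dependencies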
Block *inb = input.block(), *outb = output.block(),
*wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
*hyb = hy.block(), *cyb = cy.block(),
*wspace = this->workspace_.block(),
*rspace = this->reserve_space_.block();
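  // training uses cudnnRNNForwardTraining, which also fills reserve_space_
  // with the intermediate results that Backward needs; inference uses
  // cudnnRNNForwardInference and needs no reserve space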
if (flag & kTrain) {
CHECK_EQ(reserve_space_.device()->lang(), kCuda);
CHECK_EQ(did, reserve_space_.device()->id());
dev->Exec(
[inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context * ctx) {
// clang-format off
          CUDNN_CHECK(cudnnRNNForwardTraining(
              ctx->cudnn_handle,
              this->rnn_desc_,
              this->seq_length_,
              this->x_descs_, inb->data(),
              this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
              this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
              this->weight_desc_, wb->data(),
              this->y_descs_, outb->mutable_data(),
              this->hy_desc_, hyb->mutable_data(),
              this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
              wspace->mutable_data(),
              this->workspace_.Size(), rspace->mutable_data(),
              this->reserve_space_.Size()));
// clang-format on
},
{inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
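    // stash the tensors that Backward will pop from buf_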
buf_.push(input);
buf_.push(output);
buf_.push(hx);
buf_.push(cx);
} else {
dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context * ctx) {
// clang-format off
      CUDNN_CHECK(cudnnRNNForwardInference(
          ctx->cudnn_handle,
          this->rnn_desc_,
          this->seq_length_,
          this->x_descs_, inb->data(),
          this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
          this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
          this->weight_desc_, wb->data(),
          this->y_descs_, outb->mutable_data(),
          this->hy_desc_, hyb->mutable_data(),
          this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
          wspace->mutable_data(), this->workspace_.Size()));
// clang-format on
}, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
}
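  // split the flat y tensor back into one output tensor per time step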
auto outputs =
SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
outputs.push_back(hy);
if (has_cell_) outputs.push_back(cy);
return outputs;
}
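
// A minimal usage sketch (not part of this file), assuming a CudnnRNN layer
// `rnn` that has been Setup() with matching input_size/hidden_size, per-step
// inputs x[0..T-1], and initial states hx (and cx for LSTM) already on the
// same CUDA device; the names rnn, x, hx, cx and T are placeholders:
//
//   std::vector<Tensor> in(x.begin(), x.end());  // one tensor per time step
//   in.push_back(hx);                            // initial hidden state
//   in.push_back(cx);                            // initial cell state (LSTM only)
//   auto out = rnn.Forward(kTrain, in);
//   // out = {y_0, ..., y_{T-1}, hy, cy}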