const vector<Tensor> CudnnRNN::Forward()

in src/model/layer/cudnn_rnn.cc [258:365]

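Summary (from the code below): the forward pass merges the per-step input tensors into one contiguous block, runs cudnnRNNForwardTraining or cudnnRNNForwardInference depending on the flag, and splits the result back into per-step outputs followed by the final hidden (and, for LSTM-like cells, cell) states.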

const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
  DataType dtype = inputs.at(0).data_type();
  auto dev = inputs.at(0).device();

  // copy input data into a block of contiguous memory
  // hx (and cx) is at the end of inputs
  CHECK_GT(inputs.size(), 1u + has_cell_);
  size_t num_x = inputs.size() - has_cell_ - 1;
  Tensor input = MergeInputs(num_x, inputs);
  // LOG(INFO) << "input size " << input.Size() << " value " << input.L1();

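  // Once the RNN descriptor exists, the cuDNN data type is fixed for training;
  // otherwise record the type of the first input for later descriptor setup.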
  if (rnn_desc_ != nullptr)
    CHECK_EQ(dtype_, GetCudnnDataType(dtype))
        << "Cannot change cudnn data type during training from " << dtype_
        << " to " << GetCudnnDataType(dtype);
  else
    dtype_ = GetCudnnDataType(dtype);

  UpdateStates(num_x, inputs);
  // CheckFowardShapes();

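  // Allocate one contiguous output tensor: the merged input length scaled by
  // hidden_size_ / input_size_ and by the number of directions.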
  Shape outshape{input.Size() * hidden_size_ / input_size_ * num_directions_};
  Tensor output(outshape, dev, dtype);
  // LOG(INFO) << "output size " << output.Size();
  Tensor hx = inputs.at(num_x);
  Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_};
  Tensor hy(state_shape, dev, dtype);
  Tensor cy, cx;
  if (has_cell_) {
    cx = inputs.at(num_x + 1);
    cy.ResetLike(hy);
  }

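  // All participating tensors must live on the same CUDA device as the input.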
  int did = input.device()->id();
  CHECK_EQ(did, output.device()->id());
  if (hx.Size()) {
    CHECK_EQ(did, hx.device()->id());
    CHECK_EQ(hx.device()->lang(), kCuda);
  }
  if (cx.Size()) {
    CHECK_EQ(did, cx.device()->id());
    CHECK_EQ(cx.device()->lang(), kCuda);
  }
  CHECK_EQ(did, weight_.device()->id());
  CHECK_EQ(did, workspace_.device()->id());
  CHECK_EQ(input.device()->lang(), kCuda);
  CHECK_EQ(output.device()->lang(), kCuda);
  CHECK_EQ(weight_.device()->lang(), kCuda);
  CHECK_EQ(workspace_.device()->lang(), kCuda);

  // LOG(INFO) << "hidden size " << hy.Size();
  // LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1();
  Block *inb = input.block(), *outb = output.block(),
         *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
          *hyb = hy.block(), *cyb = cy.block(),
           *wspace = this->workspace_.block(),
            *rspace = this->reserve_space_.block();
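  // Training uses cudnnRNNForwardTraining and keeps the reserve space plus the
  // input/output/state tensors (pushed onto buf_) for the backward pass;
  // inference uses cudnnRNNForwardInference and needs neither.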
  if (flag & kTrain) {
    CHECK_EQ(reserve_space_.device()->lang(), kCuda);
    CHECK_EQ(did, reserve_space_.device()->id());
    dev->Exec(
    [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context * ctx) {
      // clang-format off
      cudnnRNNForwardTraining(
        ctx->cudnn_handle,
        this->rnn_desc_,
        this->seq_length_,
        this->x_descs_, inb->data(),
        this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
        this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
        this->weight_desc_, wb->data(),
        this->y_descs_, outb->mutable_data(),
        this->hy_desc_, hyb->mutable_data(),
        this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
        wspace->mutable_data(),
        this->workspace_.Size(), rspace->mutable_data(),
        this->reserve_space_.Size());
      // clang-format on
    },
    {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
    buf_.push(input);
    buf_.push(output);
    buf_.push(hx);
    buf_.push(cx);
  } else {
    dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context * ctx) {
      // clang-format off
      cudnnRNNForwardInference(
        ctx->cudnn_handle,
        this->rnn_desc_,
        this->seq_length_,
        this->x_descs_, inb->data(),
        this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
        this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
        this->weight_desc_, wb->data(),
        this->y_descs_, outb->mutable_data(),
        this->hy_desc_, hyb->mutable_data(),
        this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
        wspace->mutable_data(), this->workspace_.Size());
      // clang-format on
    }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
  }
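  // Split the merged output back into one tensor per step, then append the
  // final hidden state (and the cell state when the cell type has one).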
  auto outputs =
    SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
  outputs.push_back(hy);
  if (has_cell_) outputs.push_back(cy);
  return outputs;
}