virtual void postStep()

in rela/r2d2_actor.h [272:302]


  virtual void postStep() override {
    assert(replayBuffer_ != nullptr);
    assert(multiStepBuffer_.size() == historyHidden_.size());

    if (!multiStepBuffer_.canPop()) {
      assert(!r2d2Buffer_.canPop());
      return;
    }

    {
      FFTransition transition = multiStepBuffer_.popTransition();
      TensorDict hid = historyHidden_.front();
      TensorDict nextHid = historyHidden_.back();
      historyHidden_.pop_front();

      torch::Tensor priority = computePriority(transition, hid, nextHid);
      r2d2Buffer_.push(transition, priority, hid);
    }

    if (!r2d2Buffer_.canPop()) {
      return;
    }

    std::vector<RNNTransition> batch;
    torch::Tensor batchSeqPriority;
    torch::Tensor batchLen;

    std::tie(batch, batchSeqPriority, batchLen) = r2d2Buffer_.popTransition();
    auto priority = aggregatePriority(batchSeqPriority, batchLen);
    replayBuffer_->add(batch, priority);
  }