in rela/r2d2_actor.h [272:302]
virtual void postStep() override {
assert(replayBuffer_ != nullptr);
assert(multiStepBuffer_.size() == historyHidden_.size());
if (!multiStepBuffer_.canPop()) {
assert(!r2d2Buffer_.canPop());
return;
}
{
FFTransition transition = multiStepBuffer_.popTransition();
TensorDict hid = historyHidden_.front();
TensorDict nextHid = historyHidden_.back();
historyHidden_.pop_front();
torch::Tensor priority = computePriority(transition, hid, nextHid);
r2d2Buffer_.push(transition, priority, hid);
}
if (!r2d2Buffer_.canPop()) {
return;
}
std::vector<RNNTransition> batch;
torch::Tensor batchSeqPriority;
torch::Tensor batchLen;
std::tie(batch, batchSeqPriority, batchLen) = r2d2Buffer_.popTransition();
auto priority = aggregatePriority(batchSeqPriority, batchLen);
replayBuffer_->add(batch, priority);
}