in rela/r2d2_actor.h [29:87]
// Append one batched timestep into the per-actor sequence buffers.
//
// `transition` holds one step for all `batchsize_` actors; `priority` is a
// 1-D float tensor of per-actor priorities (size checked below); `hid` is a
// TensorDict of recurrent hidden states whose dimension 1 indexes the actor
// (it is narrowed along dim 1 per actor below — assumed layout, confirm
// against the caller).
//
// Each actor's buffer has layout [burnin_ | seqLen_ | multiStep_]. A fresh
// sequence is front-padded with `burnin_` zero transitions; a sequence that
// terminates early is back-padded with zero transitions / zero priority.
// When at least one actor's buffer fills up, `canPop_` is set so the
// completed sequence(s) can be consumed.
void push(const FFTransition& transition,
const torch::Tensor& priority,
const TensorDict& hid) {
assert(priority.size(0) == batchsize_);
auto priorityAccessor = priority.accessor<float, 1>();
for (int i = 0; i < batchsize_; ++i) {
// Slice out actor i's view of this timestep.
auto t = transition.index(i);
if (batchNextIdx_[i] == 0) {
// Fresh sequence: remember the initial hidden state for this actor.
// it does not matter here, should be reset after burnin
batchH0_[i] = utils::tensorDictNarrow(hid, 1, i, 1, true);
// Sanity-check that the stored initial hidden state is all zeros
// (i.e. the env/agent reset it before starting this sequence).
for (auto& kv : batchH0_[i]) {
assert(kv.second.sum().item<float>() == 0);
}
// Front-pad the first `burnin_` slots with zero-like transitions so
// the real data starts at index burnin_.
while (batchNextIdx_[i] < burnin_) {
batchSeqTransition_[i][batchNextIdx_[i]] = t.padLike();
++batchNextIdx_[i];
}
} else {
// Mid-sequence push.
// should not append after terminal
// terminal should be processed when it is pushed
int nextIdx = batchNextIdx_[i];
if (batchSeqTransition_[i][nextIdx - 1].terminal.item<bool>()) {
std::cout << nextIdx << std::endl;
assert(false);
}
// batchLen_ is only set once the sequence is finalized below, so it
// must still be unset while we are appending.
assert(batchLen_[i] == 0);
}
int nextIdx = batchNextIdx_[i];
// std::cout << "next idx: " << nextIdx << std::endl;
// Writing position must be past the burn-in region and inside the
// total buffer [burnin_, burnin_ + seqLen_ + multiStep_).
assert(nextIdx < burnin_ + seqLen_ + multiStep_ && nextIdx >= burnin_);
// burnin_ + seqLen_ - burnin_ = seqLen_
// NOTE(review): the comment above suggests this boundary is intentional,
// but with burnin_ > 0 one might expect burnin_ + seqLen_ here — confirm
// against the consumer of batchNextH0_.
if (nextIdx == seqLen_) {
// will become stored hidden for next trajectory
batchNextH0_[i] = utils::tensorDictNarrow(hid, 1, i, 1, true);
}
batchSeqTransition_[i][nextIdx] = t;
// The priority buffer has no burn-in section, hence the -burnin_ offset.
batchSeqPriority_[i][nextIdx - burnin_] = priorityAccessor[i];
++batchNextIdx_[i];
// Keep accumulating unless this step is terminal or the buffer is full.
if (!t.terminal.item<bool>() && batchNextIdx_[i] < burnin_ + seqLen_ + multiStep_) {
continue;
}
// Sequence finished (terminal or full): record its true length, then
// pad the rest of the seq in case of terminal
batchLen_[i] = batchNextIdx_[i];
while (batchNextIdx_[i] < burnin_ + seqLen_ + multiStep_) {
batchSeqTransition_[i][batchNextIdx_[i]] = t.padLike();
batchSeqPriority_[i][batchNextIdx_[i] - burnin_] = 0;
++batchNextIdx_[i];
}
// At least one actor's sequence is complete and ready to be popped.
canPop_ = true;
}
}