// rela/transition_buffer.h [134:176]
// Append one batched transition (one step per batch slot) into the
// per-slot sequence buffers. When a slot's step is terminal, the rest of
// that slot's sequence is padded with zero-priority dummy steps and the
// buffer becomes poppable.
void push(
    const FFTransition& transition,
    const torch::Tensor& priority,
    const TensorDict& /*hid*/) {
  assert(priority.size(0) == batchsize);
  auto prio = priority.accessor<float, 1>();
  for (int slot = 0; slot < batchsize; ++slot) {
    const int idx = batchNextIdx_[slot];
    assert(idx >= 0 && idx < seqLen);
    if (idx == 0) {
      // TODO: !!! simplification for unconditional h0
      // batchH0_[slot] =
      //     utils::tensorDictNarrow(hid, 1, slot * numPlayer, numPlayer, false);
    } else {
      // sanity checks: never append after a terminal step — a terminal
      // must have been fully handled (padded) when it was pushed, and the
      // recorded length must still be unset for an in-progress sequence
      assert(!batchSeqTransition_[slot][idx - 1].terminal.item<bool>());
      assert(batchLen_[slot] == 0);
    }
    auto step = transition.index(slot);
    batchSeqTransition_[slot][idx] = step;
    batchSeqPriority_[slot][idx] = prio[slot];
    ++batchNextIdx_[slot];
    if (step.terminal.item<bool>()) {
      // episode ended for this slot: record the true length, then pad the
      // remainder of the sequence with dummy steps of zero priority
      batchLen_[slot] = batchNextIdx_[slot];
      for (int j = batchNextIdx_[slot]; j < seqLen; ++j) {
        batchSeqTransition_[slot][j] = step.padLike();
        batchSeqPriority_[slot][j] = 0;
      }
      batchNextIdx_[slot] = seqLen;
      canPop_ = true;
    }
  }
}