// push() — excerpt from rela/transition_buffer.h [134:176]


  // Append one batched transition (and its per-element priority) to the
  // per-environment sequence buffers. The hidden-state argument is currently
  // ignored (see TODO below). When an element's step is terminal, the rest of
  // that element's sequence is padded out and the buffer becomes poppable.
  void push(
      const FFTransition& transition,
      const torch::Tensor& priority,
      const TensorDict& /*hid*/) {
    // One priority value per batch element.
    assert(priority.size(0) == batchsize);

    auto prioAcc = priority.accessor<float, 1>();
    for (int b = 0; b < batchsize; ++b) {
      const int writeIdx = batchNextIdx_[b];
      assert(writeIdx >= 0 && writeIdx < seqLen);
      if (writeIdx == 0) {
        // TODO: !!! simplification for unconditional h0
        // batchH0_[b] =
        //     utils::tensorDictNarrow(hid, 1, b * numPlayer, numPlayer, false);
      }

      auto step = transition.index(b);
      if (writeIdx != 0) {
        // Appending after a terminal step is illegal: terminals are handled
        // (length recorded, tail padded) at the moment they are pushed, so
        // the previous slot must be non-terminal and the length still unset.
        assert(!batchSeqTransition_[b][writeIdx - 1].terminal.item<bool>());
        assert(batchLen_[b] == 0);
      }

      batchSeqTransition_[b][writeIdx] = step;
      batchSeqPriority_[b][writeIdx] = prioAcc[b];
      ++batchNextIdx_[b];

      if (step.terminal.item<bool>()) {
        // Terminal reached: record the true sequence length, then fill the
        // remaining slots with zero-priority padding shaped like this step.
        batchLen_[b] = batchNextIdx_[b];
        while (batchNextIdx_[b] < seqLen) {
          const int padIdx = batchNextIdx_[b];
          batchSeqTransition_[b][padIdx] = step.padLike();
          batchSeqPriority_[b][padIdx] = 0;
          ++batchNextIdx_[b];
        }
        canPop_ = true;
      }
    }
  }