void push()

in rela/r2d2_actor.h [29:87]


  /// Appends one time step to the in-progress sequence of every batch entry.
  ///
  /// For each batch index `b`, slice `b` of `transition` is written into the
  /// next free slot of `batchSeqTransition_[b]` and the matching priority into
  /// `batchSeqPriority_[b]` (offset by `burnin_`). When a sequence is still
  /// empty, its first `burnin_` slots are filled with `padLike()` padding
  /// first. Once the step is terminal, or the sequence holds
  /// `burnin_ + seqLen_ + multiStep_` entries, the remaining slots are padded
  /// with zero priority, `batchLen_[b]` records the un-padded length and
  /// `canPop_` is raised.
  ///
  /// @param transition batched transition; `index(b)` selects entry `b`.
  /// @param priority   1-D float tensor whose first dimension must equal
  ///                   `batchsize_` (asserted).
  /// @param hid        tensor dict that is narrowed along dim 1 at index `b`
  ///                   (length 1) to obtain the per-entry hidden state.
  void push(const FFTransition& transition,
            const torch::Tensor& priority,
            const TensorDict& hid) {
    assert(priority.size(0) == batchsize_);

    // Total number of slots per sequence; constant for the whole call.
    const auto fullLen = burnin_ + seqLen_ + multiStep_;
    auto prio = priority.accessor<float, 1>();

    for (int b = 0; b < batchsize_; ++b) {
      auto step = transition.index(b);
      auto& cursor = batchNextIdx_[b];

      if (cursor != 0) {
        // A sequence that already has entries must not have ended yet: a
        // terminal step is expected to be handled in the call that pushed it.
        int filled = cursor;
        if (batchSeqTransition_[b][filled - 1].terminal.item<bool>()) {
          std::cout << filled << std::endl;
          assert(false);
        }
        assert(batchLen_[b] == 0);
      } else {
        // Fresh sequence. Per the original author, this initial hidden state
        // does not matter and should be reset after burn-in; it is asserted
        // to sum to zero.
        batchH0_[b] = utils::tensorDictNarrow(hid, 1, b, 1, true);
        for (auto& entry : batchH0_[b]) {
          assert(entry.second.sum().item<float>() == 0);
        }

        // Fill the burn-in prefix with padding shaped like the current step.
        for (; cursor < burnin_; ++cursor) {
          batchSeqTransition_[b][cursor] = step.padLike();
        }
      }

      int slot = cursor;
      assert(slot < fullLen && slot >= burnin_);

      // Original note: burnin_ + seqLen_ - burnin_ = seqLen_.
      // The hidden state seen at this slot is kept as the stored hidden
      // state for the next trajectory.
      if (slot == seqLen_) {
        batchNextH0_[b] = utils::tensorDictNarrow(hid, 1, b, 1, true);
      }

      // Priorities are indexed relative to the end of the burn-in prefix.
      batchSeqTransition_[b][slot] = step;
      batchSeqPriority_[b][slot - burnin_] = prio[b];
      ++cursor;

      // The sequence is complete when the step was terminal or it is full.
      const bool finished = step.terminal.item<bool>() || !(cursor < fullLen);
      if (!finished) {
        continue;
      }

      // Record the real length, then pad whatever is left of the sequence.
      batchLen_[b] = cursor;
      for (; cursor < fullLen; ++cursor) {
        batchSeqTransition_[b][cursor] = step.padLike();
        batchSeqPriority_[b][cursor - burnin_] = 0;
      }
      canPop_ = true;
    }
  }