virtual void postAct()

in rela/r2d2_actor.h [103:172]


  // Post-action bookkeeping for one step of the batched environments.
  //
  // Steps performed, in order:
  //   1. Push the reward/terminal tensors into multiStepBuffer_.
  //   2. For every index i whose terminal flag is set, overwrite the
  //      corresponding slice of each tensor in hidden_ with a fresh initial
  //      hidden state obtained from getH0(1, numPlayer_).
  //   3. If multiStepBuffer_ can pop, pop one transition, pair it with the
  //      oldest and newest entries of historyHidden_, ask runner_ to run
  //      "compute_priority" on them, and push the result into r2d2Buffer_.
  //   4. If r2d2Buffer_ can pop, pop a sequence batch, aggregate its
  //      priorities with eta_, and add it to replayBuffer_.
  //
  // r: reward tensor, forwarded unchanged to pushRewardAndTerminal.
  // t: terminal flags, read through a 1-D bool accessor; entry i governs the
  //    numPlayer_ columns of hidden_ starting at i * numPlayer_ along dim 1.
  //
  // Returns immediately, without touching any state, when no replay buffer is
  // attached.
  // NOTE(review): because of that early return the hidden-state reset in
  // step 2 is also skipped when replayBuffer_ is null -- confirm that this is
  // intended for actors running without a replay buffer.
  virtual void postAct(const torch::Tensor& r, const torch::Tensor& t) {
    // Single guard: everything below assumes a replay buffer exists.
    if (replayBuffer_ == nullptr) {
      return;
    }

    multiStepBuffer_->pushRewardAndTerminal(r, t);

    // If the ith state is terminal, reset its hidden states.
    // h0: [num_layers * num_directions, batch, hidden_size]
    TensorDict h0 = getH0(1, numPlayer_);
    auto terminal = t.accessor<bool, 1>();
    for (int i = 0; i < terminal.size(0); ++i) {
      if (!terminal[i]) {
        continue;
      }
      for (auto& kv : hidden_) {
        // [numLayer, numEnvs, hidDim]
        // [numLayer, numEnvs, numPlayer (>1), hidDim]
        // Assigning to the temporary returned by narrow() is expected to copy
        // h0 into that view of the stored tensor (ATen's rvalue-qualified
        // Tensor::operator=), not rebind a local handle.
        kv.second.narrow(1, i * numPlayer_, numPlayer_) = h0.at(kv.first);
      }
    }

    // One hidden-state snapshot is expected per buffered step.
    assert(multiStepBuffer_->size() == historyHidden_.size());

    if (!multiStepBuffer_->canPop()) {
      // No new transition was pushed into r2d2Buffer_ this call, so it is not
      // expected to have a sequence ready either.
      assert(!r2d2Buffer_->canPop());
      return;
    }

    {
      FFTransition transition = multiStepBuffer_->popTransition();
      // Copy both snapshots before pop_front(): front and back may be the
      // same element, and the front element is destroyed by the pop.
      TensorDict hid = historyHidden_.front();
      TensorDict nextHid = historyHidden_.back();
      historyHidden_.pop_front();

      // Build the model input: the transition's own fields plus the hidden
      // states with dims 0 and 1 swapped. Key collisions are treated as
      // programmer errors.
      auto input = transition.toDict();
      for (auto& kv : hid) {
        auto ret = input.emplace(kv.first, kv.second.transpose(0, 1));
        assert(ret.second);
      }
      for (auto& kv : nextHid) {
        auto ret = input.emplace("next_" + kv.first, kv.second.transpose(0, 1));
        assert(ret.second);
      }

      // runner_->call fills in `slot`; the reply is then fetched with it.
      int slot = -1;
      auto futureReply = runner_->call("compute_priority", input, &slot);
      auto priority = futureReply->get(slot)["priority"];

      r2d2Buffer_->push(transition, priority, hid);
    }

    if (!r2d2Buffer_->canPop()) {
      return;
    }

    // A full sequence batch is available: move it into the replay buffer.
    std::vector<RNNTransition> batch;
    torch::Tensor seqBatchPriority;
    torch::Tensor batchLen;

    std::tie(batch, seqBatchPriority, batchLen) = r2d2Buffer_->popTransition();
    auto priority = aggregatePriority(seqBatchPriority, batchLen, eta_);
    replayBuffer_->add(batch, priority);
  }