void R2D2Actor::observeBeforeAct()

in rlcc/r2d2_actor.cc [175:262]


void R2D2Actor::observeBeforeAct(const HanabiEnv& env) {
  torch::NoGradGuard ng;
  prevHidden_ = hidden_;

  rela::TensorDict input;
  const auto& state = env.getHleState();

  if (vdn_) {
    std::vector<rela::TensorDict> vObs;
    for (int i = 0; i < numPlayer_; ++i) {
      vObs.push_back(observe(
          state,
          i,
          shuffleColor_,
          colorPermutes_[i],
          invColorPermutes_[i],
          hideAction_,
          trinary_,
          sad_));
    }
    input = rela::tensor_dict::stack(vObs, 0);
  } else {
    input = observe(
        state,
        playerIdx_,
        shuffleColor_,
        colorPermutes_[0],
        invColorPermutes_[0],
        hideAction_,
        trinary_,
        sad_);
  }

  // add features such as eps and temperature
  input["eps"] = torch::tensor(playerEps_);
  if (playerTemp_.size() > 0) {
    input["temperature"] = torch::tensor(playerTemp_);
  }

  // push before we add hidden
  if (replayBuffer_ != nullptr) {
    r2d2Buffer_->pushObs(input);
  } else {
    // eval mode, collect some stats
    const auto& game = env.getHleGame();
    auto obs = hle::HanabiObservation(state, state.CurPlayer(), true);
    auto encoder = hle::CanonicalObservationEncoder(&game);
    auto [privV0, cardCount] =
        encoder.EncodePrivateV0Belief(obs, std::vector<int>(), false, std::vector<int>());
    perCardPrivV0_ =
        extractPerCardBelief(privV0, env.getHleGame(), obs.Hands()[0].Cards().size());
  }

  addHid(input, hidden_);

  // no-blocking async call to neural network
  futReply_ = runner_->call("act", input);

  if (!offBelief_) {
    return;
  }

  // forward belief model
  assert(!vdn_);
  auto [beliefInput, privCardCount, v0] = beliefModelObserve(
      state,
      playerIdx_,
      shuffleColor_,
      colorPermutes_[0],
      invColorPermutes_[0],
      hideAction_);
  privCardCount_ = privCardCount;

  if (beliefRunner_ == nullptr) {
    sampledCards_ = sampleCards(
        v0,
        privCardCount_,
        invColorPermutes_[0],
        env.getHleGame(),
        state.Hands()[playerIdx_],
        rng_);
  } else {
    addHid(beliefInput, beliefHidden_);
    futBelief_ = beliefRunner_->call("sample", beliefInput);
  }

  fictState_ = std::make_unique<hle::HanabiState>(state);
}