void R2D2Actor::act(HanabiEnv& env, const int curPlayer)

in rlcc/r2d2_actor.cc [264:368]


// Performs this actor's part of the current turn.
//
// Retrieves the pending model reply from futReply_ and stores its hidden
// state. When a replay buffer is attached, it records the reply as this step's
// action. With off-belief enabled, it also consumes the belief reply (if a
// belief runner exists). It then tries to install the sampled cards as this
// player's hand in the fictitious state, updating the fictitious-sample
// counters.
//
// Non-VDN, not our turn: the environment is NOT stepped. The chosen action
// must be the no-op. Under off-belief, the partner's "act" call on the
// fictitious state is issued and its result is stored in fictReply_.
// Otherwise: the chosen action is decoded into a move. Colors are mapped back
// through the inverse permutation when color shuffling is on, and the move is
// applied with env.step().
//
// @param env       environment being played.
// @param curPlayer index of the player whose turn it is.
void R2D2Actor::act(HanabiEnv& env, const int curPlayer) {
  torch::NoGradGuard ng;

  auto reply = futReply_.get();
  moveHid(reply, hidden_);

  if (replayBuffer_ != nullptr) {
    // NOTE(review): guarded on replayBuffer_ but writes through r2d2Buffer_;
    // assumes r2d2Buffer_ is always valid whenever replayBuffer_ is set.
    r2d2Buffer_->pushAction(reply);
  }

  rela::TensorDict beliefReply;
  if (offBelief_ && beliefRunner_ != nullptr) {
    // If it is not our turn, this is all we need from the belief model.
    beliefReply = futBelief_.get();
    moveHid(beliefReply, beliefHidden_);
  }

  // In VDN mode the reply is indexed by curPlayer (presumably one action per
  // seat) and each seat has its own color permutation; otherwise slot 0 is
  // this actor's.
  const auto& actions = reply.at("a");
  const int action = static_cast<int>(
      vdn_ ? actions[curPlayer].item<int64_t>() : actions.item<int64_t>());
  const std::vector<int>& invColorPermute =
      invColorPermutes_[vdn_ ? curPlayer : 0];

  if (offBelief_) {
    const auto& hand = fictState_->Hands()[playerIdx_];
    bool success = true;
    if (beliefRunner_ != nullptr) {
      auto sample = beliefReply.at("sample");
      std::tie(sampledCards_, success) = filterSample(
          sample,
          privCardCount_,
          invColorPermute,
          env.getHleGame(),  // *fictGame_,
          hand);
    }
    if (success) {
      // Swap this player's fictitious hand for the sampled cards: return the
      // current cards to the deck first, then deal the sampled ones out of it.
      auto& deck = fictState_->Deck();
      deck.PutCardsBack(hand.Cards());
      deck.DealCards(sampledCards_);
      fictState_->Hands()[playerIdx_].SetCards(sampledCards_);
      ++successFict_;
    }
    validFict_ = success;
    ++totalFict_;
  }

  if (!vdn_ && curPlayer != playerIdx_) {
    if (offBelief_) {
      // Bind by reference: avoids a ref-count bump if partners_ stores
      // shared_ptrs.
      const auto& partner = partners_[curPlayer];
      assert(partner != nullptr);
      // It is not my turn, so re-evaluate my partner on the fictitious
      // transition, using the partner's own observation settings.
      auto partnerInput = observe(
          *fictState_,
          partner->playerIdx_,
          partner->shuffleColor_,
          partner->colorPermutes_[0],
          partner->invColorPermutes_[0],
          partner->hideAction_,
          partner->trinary_,
          partner->sad_);
      // Add features such as eps and temperature.
      partnerInput["eps"] = torch::tensor(partner->playerEps_);
      if (partner->playerTemp_.size() > 0) {
        partnerInput["temperature"] = torch::tensor(partner->playerTemp_);
      }
      addHid(partnerInput, partner->prevHidden_);
      // The previous fictitious reply must already have been consumed.
      assert(fictReply_.isNull());
      fictReply_ = partner->runner_->call("act", partnerInput);
    }

    assert(action == env.noOpUid());
    return;
  }

  const auto& state = env.getHleState();
  auto move = state.ParentGame()->GetMove(action);
  if (shuffleColor_ && move.MoveType() == hle::HanabiMove::Type::kRevealColor) {
    // The network chose the color in permuted space; map it back to the real
    // color before applying the move.
    move.SetColor(invColorPermute[move.Color()]);
  }

  // Knowledge statistics for played cards are only gathered when no replay
  // buffer is attached.
  if (replayBuffer_ == nullptr &&
      move.MoveType() == hle::HanabiMove::Type::kPlay) {
    auto cardBelief = perCardPrivV0_[move.CardIndex()];
    const auto [colorKnown, rankKnown] = analyzeCardBelief(cardBelief);

    if (colorKnown && rankKnown) {
      ++bothKnown_;
    } else if (colorKnown) {
      ++colorKnown_;
    } else if (rankKnown) {
      ++rankKnown_;
    } else {
      ++noneKnown_;
    }
  }

  env.step(move);
}