in rlcc/r2d2_actor.cc [264:368]
// Consumes the pending model reply for this step and, when it is this
// actor's turn, applies the chosen move to `env`.
//
// Steps, in order:
//   1. Collect the policy reply (`futReply_`) and store its recurrent hidden
//      state; when a replay buffer is attached, push the reply into the
//      R2D2 transition buffer.
//   2. For off-belief learning with a belief model, collect the belief reply
//      and store its hidden state.
//   3. Select the action id and the inverse color permutation for the acting
//      seat (per-player slice when `vdn_` is set, slot 0 otherwise).
//   4. For off-belief learning, replace this actor's hand in `fictState_`
//      with `sampledCards_` (filtered from the belief sample when a belief
//      model is present) and record whether that succeeded.
//   5. If it is not our turn (non-VDN), optionally re-evaluate the acting
//      partner on `fictState_`, then return; our action must be the no-op.
//   6. Otherwise map the action id to a game move, undo color shuffling for
//      reveal-color moves, update play-knowledge statistics when no replay
//      buffer is attached, and step the environment.
//
// @param env        environment to step; also supplies the game/state.
// @param curPlayer  index of the player whose turn it is.
void R2D2Actor::act(HanabiEnv& env, const int curPlayer) {
  torch::NoGradGuard ng;
  auto& state = env.getHleState();

  auto reply = futReply_.get();
  moveHid(reply, hidden_);
  if (replayBuffer_ != nullptr) {
    // NOTE(review): relies on r2d2Buffer_ being allocated whenever
    // replayBuffer_ is set; make that invariant explicit in debug builds.
    assert(r2d2Buffer_ != nullptr);
    r2d2Buffer_->pushAction(reply);
  }

  rela::TensorDict beliefReply;
  if (offBelief_ && beliefRunner_ != nullptr) {
    beliefReply = futBelief_.get();
    moveHid(beliefReply, beliefHidden_);
    // if it is not our turn, then this is all we need for belief
  }

  // With vdn_ the reply's "a" entry is indexed by player and each seat has its
  // own inverse color permutation; otherwise there is a single action and we
  // use permutation slot 0.
  const int action = vdn_
      ? static_cast<int>(reply.at("a")[curPlayer].item<int64_t>())
      : static_cast<int>(reply.at("a").item<int64_t>());
  const std::vector<int>& invColorPermute =
      invColorPermutes_[vdn_ ? curPlayer : 0];

  if (offBelief_) {
    const auto& hand = fictState_->Hands()[playerIdx_];
    bool success = true;
    if (beliefRunner_ != nullptr) {
      const auto& sample = beliefReply.at("sample");
      std::tie(sampledCards_, success) = filterSample(
          sample,
          privCardCount_,
          invColorPermute,
          env.getHleGame(),  // *fictGame_,
          hand);
    }
    // NOTE(review): without a belief model, sampledCards_ is used as-is;
    // presumably it is populated elsewhere -- TODO confirm.
    if (success) {
      // Return the current fictitious hand to the deck, then deal the sampled
      // cards out of it so the fictitious deck stays consistent.
      auto& deck = fictState_->Deck();
      deck.PutCardsBack(hand.Cards());
      deck.DealCards(sampledCards_);
      fictState_->Hands()[playerIdx_].SetCards(sampledCards_);
      ++successFict_;
    }
    validFict_ = success;
    ++totalFict_;
  }

  if (!vdn_ && curPlayer != playerIdx_) {
    if (offBelief_) {
      // Bind by reference: no need to bump the smart pointer's refcount.
      const auto& partner = partners_[curPlayer];
      assert(partner != nullptr);
      // it is not my turn, I need to re-evaluate my partner on
      // the fictitious transition
      auto partnerInput = observe(
          *fictState_,
          partner->playerIdx_,
          partner->shuffleColor_,
          partner->colorPermutes_[0],
          partner->invColorPermutes_[0],
          partner->hideAction_,
          partner->trinary_,
          partner->sad_);
      // add features such as eps and temperature
      partnerInput["eps"] = torch::tensor(partner->playerEps_);
      if (!partner->playerTemp_.empty()) {
        partnerInput["temperature"] = torch::tensor(partner->playerTemp_);
      }
      addHid(partnerInput, partner->prevHidden_);
      assert(fictReply_.isNull());
      fictReply_ = partner->runner_->call("act", partnerInput);
    }
    // Off-turn, the only legal action for us is the no-op.
    assert(action == env.noOpUid());
    return;
  }

  auto move = state.ParentGame()->GetMove(action);
  if (shuffleColor_ && move.MoveType() == hle::HanabiMove::Type::kRevealColor) {
    // The model acted in the shuffled color space; map back to the real color.
    const int realColor = invColorPermute[move.Color()];
    move.SetColor(realColor);
  }

  // Knowledge statistics for played cards are only collected when no replay
  // buffer is attached.
  if (replayBuffer_ == nullptr && move.MoveType() == hle::HanabiMove::kPlay) {
    const auto& cardBelief = perCardPrivV0_[move.CardIndex()];
    const auto [colorKnown, rankKnown] = analyzeCardBelief(cardBelief);
    if (colorKnown && rankKnown) {
      ++bothKnown_;
    } else if (colorKnown) {
      ++colorKnown_;
    } else if (rankKnown) {
      ++rankKnown_;
    } else {
      ++noneKnown_;
    }
  }

  env.step(move);
}