in rlcc/utils.cc [34:103]
rela::TensorDict observe(
    const hle::HanabiState& state,
    int playerIdx,
    bool shuffleColor,
    const std::vector<int>& colorPermute,
    const std::vector<int>& invColorPermute,
    bool hideAction,
    bool trinary,
    bool sad) {
  const auto& game = *(state.ParentGame());
  auto obs = hle::HanabiObservation(state, playerIdx, true);
  auto encoder = hle::CanonicalObservationEncoder(&game);
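
  // Encode the flat observation vector; colors are permuted through
  // colorPermute/invColorPermute when shuffleColor is set, and the previous
  // action can be masked out via hideAction.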
  std::vector<float> vS = encoder.Encode(
      obs,
      true,  // regardless of this flag, splitPrivatePublic/convertSad will mask out the field
      std::vector<int>(),  // shuffle card
      shuffleColor,
      colorPermute,
      invColorPermute,
      hideAction);

  rela::TensorDict feat;
  if (!sad) {
    feat = splitPrivatePublic(vS, game);
  } else {
    // sad mode: additionally encode the last action; only used for evaluation
    auto vA =
        encoder.EncodeLastAction(obs, std::vector<int>(), shuffleColor, colorPermute);
    feat = convertSad(vS, vA, game);
  }
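
  // Ground-truth encodings of the observing player's own hand.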
  if (trinary) {
    auto vOwnHand = encoder.EncodeOwnHandTrinary(obs);
    feat["own_hand"] = torch::tensor(vOwnHand);
  } else {
    auto vOwnHand = encoder.EncodeOwnHand(obs, shuffleColor, colorPermute);
    // Auto-regressive input: the own-hand encoding shifted right by one card
    // slot (NumColors * NumRanks entries), so the block for card i holds the
    // ground truth of card i-1; card 0's block stays zero and the last card's
    // block is dropped.
    std::vector<float> vOwnHandARIn(vOwnHand.size(), 0);
    int end = (game.HandSize() - 1) * game.NumColors() * game.NumRanks();
    std::copy(
        vOwnHand.begin(),
        vOwnHand.begin() + end,
        vOwnHandARIn.begin() + game.NumColors() * game.NumRanks());
    feat["own_hand"] = torch::tensor(vOwnHand);
    feat["own_hand_ar_in"] = torch::tensor(vOwnHandARIn);

    auto privARV0 =
        encoder.EncodeARV0Belief(obs, std::vector<int>(), shuffleColor, colorPermute);
    feat["priv_ar_v0"] = torch::tensor(privARV0);
  }

  // Legal moves, encoded as a multi-hot vector over move uids; the extra last
  // slot (index game.MaxMoves()) marks "no legal move", typically because it is
  // not this player's turn.
  const auto& legalMove = state.LegalMoves(playerIdx);
  std::vector<float> vLegalMove(game.MaxMoves() + 1);
  for (auto move : legalMove) {
    // Map reveal-color hints through the same color permutation used for the
    // observation so that move uids stay consistent under color shuffling.
    if (shuffleColor && move.MoveType() == hle::HanabiMove::Type::kRevealColor) {
      int permColor = colorPermute[move.Color()];
      move.SetColor(permColor);
    }
    auto uid = game.GetMoveUid(move);
    vLegalMove[uid] = 1;
  }
  if (legalMove.size() == 0) {
    vLegalMove[game.MaxMoves()] = 1;
  }

  feat["legal_move"] = torch::tensor(vLegalMove);
  return feat;
}
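
A minimal, hypothetical usage sketch (not part of utils.cc): it assumes an existing in-progress hle::HanabiState and calls observe for player 0 with color shuffling disabled; exampleObserve and identityPerm are illustrative names only.

#include <numeric>  // std::iota, assumed available alongside the file's existing includes

rela::TensorDict exampleObserve(const hle::HanabiState& state) {
  const auto& game = *(state.ParentGame());
  // Identity permutation {0, 1, ..., NumColors-1}: colors are left unshuffled.
  std::vector<int> identityPerm(game.NumColors());
  std::iota(identityPerm.begin(), identityPerm.end(), 0);
  return observe(
      state,
      /*playerIdx=*/0,
      /*shuffleColor=*/false,
      /*colorPermute=*/identityPerm,
      /*invColorPermute=*/identityPerm,
      /*hideAction=*/false,
      /*trinary=*/true,
      /*sad=*/false);
}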