in cpp/hanabi_env.cc [115:205]
/// Builds the per-player feature tensors and legal-move masks for the
/// current state (`state_`).
///
/// For each player index `i` in `[0, game_.NumPlayers())` this:
///   1. encodes an observation of `state_` through `obsEncoder_.Encode`,
///      optionally with a shuffled partner-hand order (`shuffleObs_`) and
///      the player's color permutation (`shuffleColor_`);
///   2. when `sad_` is set, appends the last-action encoding taken from an
///      observation of `cloneState`;
///   3. encodes the player's own hand via `EncodeOwnHandTrinary` from a
///      second observation constructed with the third argument set to true
///      (presumably this reveals the player's own cards -- TODO confirm);
///   4. builds a 0/1 vector of length `numAction()` marking the uids of the
///      legal moves, or only the no-op uid when there are no legal moves.
///
/// @param cloneState State used for the extra last-action encoding. It is
///        only dereferenced when `sad_` is true, in which case it must be
///        non-null (checked with `assert`).
/// @return A dict with keys "priv_s", "legal_move" and "own_hand" (each
///         stacked along a new leading dimension, one entry per player) and
///         "eps" (a tensor built from `playerEps_`).
rela::TensorDict HanabiEnv::computeFeatureAndLegalMove(
    const std::unique_ptr<hle::HanabiState>& cloneState) {
  const int numPlayers = game_.NumPlayers();

  // One entry per player; the final size is known, so allocate once.
  std::vector<torch::Tensor> privS;
  std::vector<torch::Tensor> legalMove;
  std::vector<torch::Tensor> ownHand;
  privS.reserve(numPlayers);
  legalMove.reserve(numPlayers);
  ownHand.reserve(numPlayers);

  for (int i = 0; i < numPlayers; ++i) {
    auto obs = hle::HanabiObservation(*state_, i, false);

    // Left empty unless observation shuffling is enabled; the encoder
    // receives it either way.
    std::vector<int> shuffleOrder;
    if (shuffleObs_) {
      // hacked for 2 players
      assert(game_.NumPlayers() == 2);
      // [1] for partner's hand
      const int partnerHandSize =
          static_cast<int>(obs.Hands()[1].Cards().size());
      shuffleOrder.reserve(partnerHandSize);
      // Distinct index name: the outer `i` (player) is still needed below.
      for (int cardIdx = 0; cardIdx < partnerHandSize; ++cardIdx) {
        shuffleOrder.push_back(cardIdx);
      }
      std::shuffle(shuffleOrder.begin(), shuffleOrder.end(), *game_.rng());
    }

    std::vector<float> vS = obsEncoder_.Encode(
        obs,
        false,
        shuffleOrder,
        shuffleColor_,
        colorPermutes_[i],
        invColorPermutes_[i],
        false);

    if (sad_) {
      // Append the last-action encoding observed on the cloned state.
      assert(cloneState != nullptr);
      auto extraObs = hle::HanabiObservation(*cloneState, i, false);
      std::vector<float> vGreedyAction = obsEncoder_.EncodeLastAction(
          extraObs, shuffleOrder, shuffleColor_, colorPermutes_[i]);
      vS.insert(vS.end(), vGreedyAction.begin(), vGreedyAction.end());
    }
    privS.push_back(torch::tensor(vS));

    {
      auto cheatObs = hle::HanabiObservation(*state_, i, true);
      auto vOwnHand = obsEncoder_.EncodeOwnHandTrinary(cheatObs);
      ownHand.push_back(torch::tensor(vOwnHand));
    }

    // legal moves
    auto legalMoves = state_->LegalMoves(i);
    std::vector<float> moveUids(numAction(), 0);
    // Iterate by value on purpose: the color of the local copy may be
    // rewritten before its uid is looked up.
    for (auto move : legalMoves) {
      if (shuffleColor_ &&
          move.MoveType() == hle::HanabiMove::Type::kRevealColor) {
        // Report color hints in this player's permuted color space.
        int permColor = colorPermutes_[i][move.Color()];
        move.SetColor(permColor);
      }
      auto uid = game_.GetMoveUid(move);
      if (uid >= noOpUid()) {
        std::cout << "Error: legal move id should be < " << numAction() - 1 << std::endl;
        assert(false);
      }
      moveUids[uid] = 1;
    }
    if (legalMoves.empty()) {
      // No legal move for this player: only the no-op action is allowed.
      moveUids[noOpUid()] = 1;
    }
    legalMove.push_back(torch::tensor(moveUids));
  }

  rela::TensorDict dict = {
      {"priv_s", torch::stack(privS, 0)},
      {"legal_move", torch::stack(legalMove, 0)},
      {"eps", torch::tensor(playerEps_)},
      {"own_hand", torch::stack(ownHand, 0)},
  };
  return dict;
}