in hanabi-learning-environment/hanabi_lib/canonical_encoders.cc [806:873]
// Counts how many copies of each (color, rank) card are still unaccounted
// for, as seen from the given observation.
//
// The count starts from the full deck composition
// (game.NumberCardInstances) and then removes:
//   * every card in the discard pile, and
//   * every card already played onto the fireworks.
// If `publ` is true, this "public" count is returned. Otherwise the cards in
// obs.Hands()[1..] are also removed, giving a "private" count. (Presumably
// hand 0 is the observer's own, unseen hand under relative hand indexing;
// TODO confirm against HanabiObservation.)
//
// Args:
//   game:          supplies the number of colors/ranks and per-card counts.
//   obs:           supplies the discard pile, fireworks and hands.
//   shuffle_color: forwarded to CardIndex; together with color_permute it
//                  decides which output slot a (color, rank) maps to
//                  (presumably a color permutation; TODO confirm).
//   color_permute: forwarded to CardIndex.
//   publ:          if true, do not subtract the cards visible in hands.
//
// Returns a vector of num_colors * num_ranks counts, laid out by CardIndex.
std::vector<int> ComputeCardCount(
    const HanabiGame& game,
    const HanabiObservation& obs,
    bool shuffle_color,
    const std::vector<int>& color_permute,
    bool publ) {
  const int num_colors = game.NumColors();
  const int num_ranks = game.NumRanks();

  // Maps a (color, rank) pair to its slot in card_count.
  auto index_of = [&](int color, int rank) {
    return CardIndex(color, rank, num_ranks, shuffle_color, color_permute);
  };

  std::vector<int> card_count(num_colors * num_ranks, 0);

  // Start from the full deck composition.
  for (int color = 0; color < num_colors; ++color) {
    for (int rank = 0; rank < num_ranks; ++rank) {
      card_count[index_of(color, rank)] = game.NumberCardInstances(color, rank);
    }
  }

  // Discarded cards are no longer available.
  for (const HanabiCard& card : obs.DiscardPile()) {
    --card_count[index_of(card.Color(), card.Rank())];
  }

  // fireworks[color] is the number of successfully played cards of that
  // color, i.e. one copy each of ranks 0 .. fireworks[color] - 1. Remove them.
  const std::vector<int>& fireworks = obs.Fireworks();
  for (int color = 0; color < num_colors; ++color) {
    for (int rank = 0; rank < fireworks[color]; ++rank) {
      --card_count[index_of(color, rank)];
    }
  }

  if (publ) {
    return card_count;
  }

  // Expected (unchecked) invariant here: the sum of card_count equals
  // obs.DeckSize() plus the number of cards held in all hands.

  // Private count: additionally remove the cards visible in hands 1..N-1.
  const auto& hands = obs.Hands();
  for (std::size_t i = 1; i < hands.size(); ++i) {
    for (const auto& card : hands[i].Cards()) {
      const int index = index_of(card.Color(), card.Rank());
      --card_count[index];
      assert(card_count[index] >= 0);
    }
  }
  return card_count;
}