std::vector<int> ComputeCardCount()

in hanabi-learning-environment/hanabi_lib/canonical_encoders.cc [806:873]


// Counts, for every (color, rank) card type, how many instances are still
// unaccounted for from a player's point of view.
//
// Starts from the full deck composition reported by `game`, then removes every
// card in the discard pile and every card already played onto the fireworks.
// If `publ` is true the result is returned at that point: a count based only on
// the discard pile and fireworks. Otherwise the cards in obs.Hands()[1..] are
// removed as well.
//
// Parameters:
//   game          - supplies NumColors(), NumRanks() and NumberCardInstances().
//   obs           - supplies DiscardPile(), Fireworks() and Hands().
//   shuffle_color - forwarded to CardIndex to select the slot for each card.
//   color_permute - color permutation forwarded to CardIndex.
//   publ          - when true, skip subtracting the cards in obs.Hands()[1..].
// Returns a vector of size NumColors() * NumRanks(), with each card's slot
// chosen by CardIndex.
std::vector<int> ComputeCardCount(
    const HanabiGame& game,
    const HanabiObservation& obs,
    bool shuffle_color,
    const std::vector<int>& color_permute,
    bool publ) {
  const int num_colors = game.NumColors();
  const int num_ranks = game.NumRanks();

  // Every lookup below uses the same slot mapping, so bind it once.
  auto index_of = [&](int color, int rank) {
    return CardIndex(color, rank, num_ranks, shuffle_color, color_permute);
  };

  std::vector<int> card_count(num_colors * num_ranks, 0);

  // Full deck card count.
  for (int color = 0; color < num_colors; ++color) {
    for (int rank = 0; rank < num_ranks; ++rank) {
      card_count[index_of(color, rank)] = game.NumberCardInstances(color, rank);
    }
  }

  // Remove discarded cards.
  for (const HanabiCard& card : obs.DiscardPile()) {
    --card_count[index_of(card.Color(), card.Rank())];
  }

  // Remove cards already played onto the fireworks. fireworks[c] is the number
  // of successfully played cards of color c, i.e. one instance each of ranks
  // 0 .. fireworks[c]-1 (0-indexed) has left the pool. A value of 0 makes the
  // inner loop a no-op.
  const std::vector<int>& fireworks = obs.Fireworks();
  for (int color = 0; color < num_colors; ++color) {
    for (int rank = 0; rank < fireworks[color]; ++rank) {
      --card_count[index_of(color, rank)];
    }
  }

  if (publ) {
    return card_count;
  }

  // Convert to a private count: also remove the cards in every hand except
  // hand 0. NOTE(review): hand 0 is presumably the observer's own hand, whose
  // cards it cannot see; confirm against HanabiObservation.
  const auto& hands = obs.Hands();
  for (int i = 1; i < static_cast<int>(hands.size()); ++i) {
    for (const auto& card : hands[i].Cards()) {
      const int index = index_of(card.Color(), card.Rank());
      --card_count[index];
      // A visible card can never exceed the instances left after discards and
      // fireworks; a negative count would indicate inconsistent state.
      assert(card_count[index] >= 0);
    }
  }

  return card_count;
}