rela::TensorDict observe()

in rlcc/utils.cc [34:103]


// Assembles the feature dictionary that player `playerIdx` observes in `state`.
//
// The canonical encoding of the observation is turned into the base features
// by convertSad() when `sad` is set and by splitPrivatePublic() otherwise.
// On top of those this function writes:
//   "own_hand"        encoding of the player's own hand (trinary variant when
//                     `trinary` is set),
//   "own_hand_ar_in"  the own-hand encoding shifted right by one card slot,
//                     with the first slot zeroed (non-trinary only),
//   "priv_ar_v0"      output of EncodeARV0Belief (non-trinary only),
//   "legal_move"      0/1 mask of size MaxMoves() + 1; the extra last entry is
//                     set only when the player has no legal move.
//
// `shuffleColor` / `colorPermute` are applied both inside the encoder calls and
// to the color of reveal-color moves before their uid is looked up.
// `invColorPermute` and `hideAction` are only forwarded to Encode().
rela::TensorDict observe(
    const hle::HanabiState& state,
    int playerIdx,
    bool shuffleColor,
    const std::vector<int>& colorPermute,
    const std::vector<int>& invColorPermute,
    bool hideAction,
    bool trinary,
    bool sad) {
  const auto& game = *(state.ParentGame());
  hle::HanabiObservation obs(state, playerIdx, true);
  hle::CanonicalObservationEncoder encoder(&game);

  // The own-cards flag is passed as true unconditionally: whatever its value,
  // splitPrivatePublic/convertSad mask that field out afterwards.
  const bool showOwnCards = true;
  std::vector<float> encoded = encoder.Encode(
      obs,
      showOwnCards,
      std::vector<int>(),  // shuffle card
      shuffleColor,
      colorPermute,
      invColorPermute,
      hideAction);

  rela::TensorDict feat;
  if (sad) {
    // only for evaluation
    auto lastAction =
        encoder.EncodeLastAction(obs, std::vector<int>(), shuffleColor, colorPermute);
    feat = convertSad(encoded, lastAction, game);
  } else {
    feat = splitPrivatePublic(encoded, game);
  }

  if (!trinary) {
    auto ownHand = encoder.EncodeOwnHand(obs, shuffleColor, colorPermute);

    // Shifted copy of the hand: slot i receives card i - 1, slot 0 stays zero
    // and the last card is dropped.
    const int bitsPerCard = game.NumColors() * game.NumRanks();
    const int shiftedLen = (game.HandSize() - 1) * bitsPerCard;
    std::vector<float> ownHandShifted(ownHand.size(), 0);
    std::copy(
        ownHand.begin(),
        ownHand.begin() + shiftedLen,
        ownHandShifted.begin() + bitsPerCard);

    feat["own_hand"] = torch::tensor(ownHand);
    feat["own_hand_ar_in"] = torch::tensor(ownHandShifted);

    auto arV0Belief =
        encoder.EncodeARV0Belief(obs, std::vector<int>(), shuffleColor, colorPermute);
    feat["priv_ar_v0"] = torch::tensor(arV0Belief);
  } else {
    auto ownHandTrinary = encoder.EncodeOwnHandTrinary(obs);
    feat["own_hand"] = torch::tensor(ownHandTrinary);
  }

  // legal moves
  const auto& legalMoves = state.LegalMoves(playerIdx);
  std::vector<float> legalMask(game.MaxMoves() + 1);
  if (legalMoves.size() == 0) {
    // Nothing is legal: flag the extra trailing entry instead.
    legalMask[game.MaxMoves()] = 1;
  } else {
    // Iterate by value on purpose: the color of a hint may be rewritten below.
    for (auto move : legalMoves) {
      if (shuffleColor && move.MoveType() == hle::HanabiMove::Type::kRevealColor) {
        move.SetColor(colorPermute[move.Color()]);
      }
      legalMask[game.GetMoveUid(move)] = 1;
    }
  }
  feat["legal_move"] = torch::tensor(legalMask);

  return feat;
}