rela::TensorDict HanabiEnv::computeFeatureAndLegalMove()

in cpp/hanabi_env.cc [115:205]


// Builds the per-player feature tensors and legal-move masks for the current
// state_.
//
// For every player index i in [0, game_.NumPlayers()) this:
//   * encodes the player's observation with obsEncoder_.Encode and, when sad_
//     is set, appends obsEncoder_.EncodeLastAction computed from an
//     observation of `cloneState`;
//   * encodes the player's own hand with obsEncoder_.EncodeOwnHandTrinary;
//   * builds a float mask of length numAction() holding 1 at the uid of every
//     legal move, or at noOpUid() when the player has no legal move.
//
// cloneState: only dereferenced when sad_ is true, in which case it must be
//             non-null (checked with assert).
//
// Returns a dict with:
//   "priv_s"     - per-player observation encodings stacked along dim 0
//   "legal_move" - per-player legal-move masks stacked along dim 0
//   "eps"        - torch::tensor(playerEps_)
//   "own_hand"   - per-player own-hand encodings stacked along dim 0
rela::TensorDict HanabiEnv::computeFeatureAndLegalMove(
    const std::unique_ptr<hle::HanabiState>& cloneState) {
  const auto numPlayers = game_.NumPlayers();

  std::vector<torch::Tensor> privS;
  std::vector<torch::Tensor> legalMove;
  std::vector<torch::Tensor> ownHand;
  privS.reserve(numPlayers);
  legalMove.reserve(numPlayers);
  ownHand.reserve(numPlayers);

  for (int i = 0; i < numPlayers; ++i) {
    auto obs = hle::HanabiObservation(*state_, i, false);

    // Order in which the partner's cards are presented to the encoder; left
    // empty when shuffleObs_ is off.
    std::vector<int> shuffleOrder;
    if (shuffleObs_) {
      // hacked for 2 players
      assert(game_.NumPlayers() == 2);
      // [1] for partner's hand
      int partnerHandSize = obs.Hands()[1].Cards().size();
      shuffleOrder.reserve(partnerHandSize);
      for (int cardIdx = 0; cardIdx < partnerHandSize; ++cardIdx) {
        shuffleOrder.push_back(cardIdx);
      }
      std::shuffle(shuffleOrder.begin(), shuffleOrder.end(), *game_.rng());
    }

    std::vector<float> vS = obsEncoder_.Encode(
        obs,
        false,
        shuffleOrder,
        shuffleColor_,
        colorPermutes_[i],
        invColorPermutes_[i],
        false);

    if (sad_) {
      assert(cloneState != nullptr);
      auto extraObs = hle::HanabiObservation(*cloneState, i, false);
      std::vector<float> vGreedyAction = obsEncoder_.EncodeLastAction(
          extraObs, shuffleOrder, shuffleColor_, colorPermutes_[i]);
      vS.insert(vS.end(), vGreedyAction.begin(), vGreedyAction.end());
    }

    privS.push_back(torch::tensor(vS));

    {
      // NOTE(review): the third constructor argument is `true` here and
      // `false` above; presumably it exposes the player's own cards to the
      // observation -- confirm against hle::HanabiObservation.
      auto cheatObs = hle::HanabiObservation(*state_, i, true);
      auto vOwnHand = obsEncoder_.EncodeOwnHandTrinary(cheatObs);
      ownHand.push_back(torch::tensor(vOwnHand));
    }

    // legal moves
    auto legalMoves = state_->LegalMoves(i);
    std::vector<float> moveUids(numAction(), 0);
    // Iterate by value on purpose: the color of a reveal-color move may be
    // rewritten below and legalMoves.size() is still needed afterwards.
    for (auto move : legalMoves) {
      if (shuffleColor_ &&
          move.MoveType() == hle::HanabiMove::Type::kRevealColor) {
        int permColor = colorPermutes_[i][move.Color()];
        move.SetColor(permColor);
      }
      auto uid = game_.GetMoveUid(move);
      if (uid < 0 || uid >= noOpUid()) {
        // std::endl (flush) is deliberate: the assert below may abort.
        std::cout << "Error: legal move id should be < " << numAction() - 1 << std::endl;
        assert(false);
        // With NDEBUG the assert compiles away; skip the move instead of
        // writing through an index that failed validation.
        continue;
      }
      moveUids[uid] = 1;
    }
    if (legalMoves.size() == 0) {
      moveUids[noOpUid()] = 1;
    }

    legalMove.push_back(torch::tensor(moveUids));
  }

  rela::TensorDict dict = {
      {"priv_s", torch::stack(privS, 0)},
      {"legal_move", torch::stack(legalMove, 0)},
      {"eps", torch::tensor(playerEps_)},
      {"own_hand", torch::stack(ownHand, 0)},
  };

  return dict;
}