rela::TensorDict HanabiEnv::reset()

in cpp/hanabi_env.cc [9:47]


// Starts a new episode and returns the features / legal moves for the first
// decision point.
//
// Steps:
//   1. Replaces state_ with a fresh hle::HanabiState built from game_, then
//      applies random chance outcomes until a non-chance player is to act.
//   2. Resets the per-episode step counter numStep_.
//   3. For every player, draws an epsilon from epsList_ into playerEps_.
//   4. If shuffleColor_ is set, draws a color permutation per player (one
//      randomly chosen player keeps the identity permutation) and stores its
//      inverse in invColorPermutes_, so that
//      invColorPermutes_[p][colorPermutes_[p][c]] == c for every color c.
//
// Every random draw in this function comes from game_.rng().
// Precondition (checked in debug builds): the previous episode has terminated.
rela::TensorDict HanabiEnv::reset() {
  assert(terminated());
  state_ = std::make_unique<hle::HanabiState>(&game_);
  // Resolve chance nodes (presumably the initial deal) so that the returned
  // observation always belongs to a real player.
  while (state_->CurPlayer() == hle::kChancePlayerId) {
    state_->ApplyRandomChance();
  }
  numStep_ = 0;

  // Pick each player's epsilon from epsList_. The modulo is undefined for an
  // empty list, so make that precondition explicit.
  assert(!epsList_.empty());
  for (int pid = 0; pid < game_.NumPlayers(); ++pid) {
    playerEps_[pid] = epsList_[game_.rng()->operator()() % epsList_.size()];
  }

  if (shuffleColor_) {
    const int numColors = game_.NumColors();
    // One randomly chosen player keeps the identity permutation; every other
    // player gets an independently shuffled one.
    int fixColorPlayer = game_.rng()->operator()() % game_.NumPlayers();
    for (int pid = 0; pid < game_.NumPlayers(); ++pid) {
      auto& colorPermute = colorPermutes_[pid];
      auto& invColorPermute = invColorPermutes_[pid];

      // Start from the identity permutation; clear() keeps the capacity.
      colorPermute.clear();
      for (int i = 0; i < numColors; ++i) {
        colorPermute.push_back(i);
      }
      if (pid != fixColorPlayer) {
        std::shuffle(colorPermute.begin(), colorPermute.end(), *game_.rng());
      }

      // Build the inverse directly: invColorPermute[colorPermute[i]] == i.
      // This is O(numColors), correct by construction, and draws nothing
      // from the RNG, so the random stream is consumed exactly as before.
      invColorPermute.assign(colorPermute.size(), 0);
      for (int i = 0; i < static_cast<int>(colorPermute.size()); ++i) {
        invColorPermute[colorPermute[i]] = i;
      }
    }
  }

  return computeFeatureAndLegalMove(state_);
}