std::tuple<rela::TensorDict, float, bool> HanabiEnv::step()

in cpp/hanabi_env.cc [49:113]


// Advances the environment by one move of the current player.
//
// `action` must contain "a" (and, when sad_ is set, "greedy_a"). Each entry is
// indexed by player id, and the element for the current player is read as a
// move uid. The move goes through maybeInversePermuteColor_ before it is
// validated and applied (presumably mapping a shuffled color back to the real
// one when color shuffling is on -- TODO confirm).
//
// Returns {observation, reward, terminal}:
//   * observation comes from computeFeatureAndLegalMove(cloneState);
//   * reward is the change in score caused by this move, or -prevScore when
//     the episode is cut off by the maxLen_ step cap;
//   * terminal is true when the game ended or the step cap was reached.
//
// Aborts the process (after printing diagnostics) if the chosen move is
// illegal, in both debug and release builds.
std::tuple<rela::TensorDict, float, bool> HanabiEnv::step(
    const rela::TensorDict& action) {
  assert(!terminated());

  numStep_ += 1;

  float prevScore = state_->Score();

  // perform action for only current player
  int curPlayer = state_->CurPlayer();
  int actionUid = action.at("a")[curPlayer].item<int>();
  hle::HanabiMove move = game_.GetMove(actionUid);
  maybeInversePermuteColor_(move, curPlayer);

  if (!state_->MoveIsLegal(move)) {
    // This legality check runs in every build, so the failure path must not
    // rely on assert(): with NDEBUG an assert(false) would vanish and the
    // illegal move would then be applied below.
    std::cout << "Error: move is not legal" << std::endl;
    std::cout << "UID: " << actionUid << std::endl;
    std::cout << "numStep: " << numStep_ - 1 << std::endl;
    std::cout << "legal move:" << std::endl;

    // Report legal moves in the same (possibly color-permuted) uid space the
    // agent acts in, so they can be compared directly with `actionUid`.
    for (auto legalMove : state_->LegalMoves(curPlayer)) {
      if (shuffleColor_ &&
          legalMove.MoveType() == hle::HanabiMove::Type::kRevealColor) {
        int permColor = colorPermutes_[curPlayer][legalMove.Color()];
        legalMove.SetColor(permColor);
      }
      std::cout << "legal_move: " << game_.GetMoveUid(legalMove) << std::endl;
    }
    // std::endl has already flushed the diagnostics, so they survive abort().
    std::abort();
  }

  // When sad_ is set, build a copy of the pre-move state with the greedy
  // action applied instead; it is handed to computeFeatureAndLegalMove below
  // (presumably so the observation can encode the greedy action -- TODO
  // confirm). When sad_ is unset, cloneState stays null.
  std::unique_ptr<hle::HanabiState> cloneState = nullptr;
  if (sad_) {
    cloneState = std::make_unique<hle::HanabiState>(*state_);
    int greedyActionUid = action.at("greedy_a")[curPlayer].item<int>();
    hle::HanabiMove greedyMove = game_.GetMove(greedyActionUid);
    maybeInversePermuteColor_(greedyMove, curPlayer);

    // state_ has not been advanced yet, so it is equivalent to cloneState here.
    assert(state_->MoveIsLegal(greedyMove));
    cloneState->ApplyMove(greedyMove);
  }
  state_->ApplyMove(move);

  bool terminal = state_->IsTerminal();
  float reward = state_->Score() - prevScore;

  // Forced termination, lose all points: once the step cap is reached the
  // episode ends and the reward cancels everything scored so far. `>=` (not
  // `==`) keeps the cap effective even if step() is called past the limit.
  if (maxLen_ > 0 && numStep_ >= maxLen_) {
    terminal = true;
    reward = 0 - prevScore;
  }

  if (!terminal) {
    // Resolve chance-player turns so the next caller sees a real player.
    while (state_->CurPlayer() == hle::kChancePlayerId) {
      state_->ApplyRandomChance();
    }
  }

  auto obs = computeFeatureAndLegalMove(cloneState);
  // Move the observation dict into the result instead of copying it.
  return std::make_tuple(std::move(obs), reward, terminal);
}