TreeStrategyStats compute_stategy_stats()

in csrc/liars_dice/subgame_solving.cc [823:899]


// Computes diagnostic statistics for `strategy` over the fully unrolled tree
// of `game`.
//
// For every node of the tree this fills in:
//   * reach_probabilities[p][node][hand]: reach of each of player p's hands,
//     starting from the first entry of get_initial_beliefs(game) and following
//     `strategy` (see compute_reach_probabilities).
//   * node_reach[node]: product over both players of their total reach mass at
//     the node.
//   * values[p][node][hand]: value for player p holding `hand` at the node when
//     both players follow `strategy`. Terminal values come from
//     compute_expected_terminal_values using the opponent's normalized
//     (kReachSmoothingEps-smoothed) reach as beliefs.
//   * node_values[p][node]: values[p][node] averaged over player p's own
//     normalized reach at the node.
//
// Values are propagated bottom-up by visiting nodes in decreasing id order.
// NOTE(review): this relies on every child having a larger id than its parent
// in the unroll_tree output; confirm against unroll_tree.
//
// @param game      Game whose tree is unrolled and evaluated.
// @param strategy  Per-node, per-hand action probabilities, indexed as
//                  strategy[node_id][hand][action].
// @return          The populated stats, including the unrolled tree itself.
TreeStrategyStats compute_stategy_stats(const Game& game,
                                        const TreeStrategy& strategy) {
  const auto uniform_beliefs = get_initial_beliefs(game).at(0);
  const auto num_hands = game.num_hands();

  TreeStrategyStats stats;
  // Unroll straight into the result and work through a reference, rather than
  // unrolling into a local and then copying the whole tree into `stats`.
  stats.tree = unroll_tree(game);
  const auto& tree = stats.tree;

  auto& reach_probabilities = stats.reach_probabilities;
  init_nd(tree.size(), num_hands, 0.0, &reach_probabilities[0]);
  init_nd(tree.size(), num_hands, 0.0, &reach_probabilities[1]);
  auto& tree_values = stats.values;
  init_nd(tree.size(), num_hands, 0.0, &tree_values[0]);
  init_nd(tree.size(), num_hands, 0.0, &tree_values[1]);
  stats.node_reach.resize(tree.size());
  stats.node_values[0].resize(tree.size());
  stats.node_values[1].resize(tree.size());

  for (int player : {0, 1}) {
    compute_reach_probabilities(tree, strategy, uniform_beliefs, player,
                                &reach_probabilities[player]);
  }

  for (size_t node_id = tree.size(); node_id-- > 0;) {
    stats.node_reach[node_id] = vector_sum(reach_probabilities[0][node_id]) *
                                vector_sum(reach_probabilities[1][node_id]);
  }

  for (int player : {0, 1}) {
    const int opponent = 1 - player;
    for (size_t node_id = tree.size(); node_id-- > 0;) {
      const auto& node = tree[node_id];
      const auto& state = node.state;
      std::vector<double>& node_values = tree_values[player][node_id];
      // Bind by reference: the original copied the opponent's reach vector
      // for every node and every player.
      const auto& op_reach_probabilities =
          reach_probabilities[opponent][node_id];
      const std::vector<double> op_beliefs = normalize_probabilities_safe(
          op_reach_probabilities, kReachSmoothingEps);

      if (game.is_terminal(state)) {
        // Terminal payoff is evaluated against the parent's last bid.
        const auto last_bid = tree[node.parent].state.last_bid;
        node_values = compute_expected_terminal_values(
            game, last_bid, /*inverse=*/state.player_id != player, op_beliefs);
      } else {
        assert(node.num_children() > 0);
      }

      if (state.player_id == player) {
        // Traverser acts: weight each child by the traverser's own action
        // probability for this particular hand.
        for (int hand = 0; hand < num_hands; ++hand) {
          for (auto [child_node_id, action] : ChildrenActionIt(node, game)) {
            node_values[hand] += strategy[node_id][hand][action] *
                                 tree_values[player][child_node_id][hand];
          }
        }
      } else {
        // Opponent acts: weight each child by the opponent's action
        // probability marginalized over the opponent's normalized reach
        // (op_beliefs) at this node.
        for (auto [child_node_id, action] : ChildrenActionIt(node, game)) {
          double action_prob = 0;
          for (int op_hand = 0; op_hand < num_hands; ++op_hand) {
            action_prob +=
                strategy[node_id][op_hand][action] * op_beliefs[op_hand];
          }
          const std::vector<double>& child_values =
              tree_values[player][child_node_id];
          for (int hand = 0; hand < num_hands; ++hand) {
            node_values[hand] += action_prob * child_values[hand];
          }
        }
      }
    }
  }

  for (int player : {0, 1}) {
    for (size_t node_id = tree.size(); node_id-- > 0;) {
      // NOTE(review): this smoothing epsilon (1e-6) differs from the
      // kReachSmoothingEps used for opponent beliefs above; confirm the
      // difference is intentional.
      const auto beliefs = normalize_probabilities_safe(
          reach_probabilities[player][node_id], 1e-6);
      const std::vector<double>& hand_values = tree_values[player][node_id];
      for (int hand = 0; hand < num_hands; ++hand) {
        stats.node_values[player][node_id] +=
            beliefs[hand] * hand_values[hand];
      }
    }
  }

  return stats;
}