in fairdiplomacy/agents/searchbot_agent.py [0:0]
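# Context (not shown in this excerpt): this is a method of the searchbot agent class.
# It assumes logging, collections.defaultdict, the Dict/List/Tuple type hints, the
# Power/Action aliases, and the sample_orders_from_policy / make_set_orders_dicts
# helpers are imported elsewhere in searchbot_agent.py, and that self.model_rollouts
# is set up elsewhere in the class.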
def compute_nash_conv(self, cfr_data, label, game, strat_f):
"""For each power, compute EV of each action assuming opponent ave policies"""
# get policy probs for all powers
power_action_ps: Dict[Power, List[float]] = {
pwr: strat_f(pwr) for (pwr, actions) in cfr_data.power_plausible_orders.items()
}
logging.info("Policies: {}".format(power_action_ps))
    total_action_utilities: Dict[Tuple[Power, Action], float] = defaultdict(float)
    temp_action_utilities: Dict[Tuple[Power, Action], float] = defaultdict(float)
    total_state_utility: Dict[Power, float] = defaultdict(float)
    max_state_utility: Dict[Power, float] = defaultdict(float)
    for pwr, actions in cfr_data.power_plausible_orders.items():
        total_state_utility[pwr] = 0
        max_state_utility[pwr] = 0
    # total_state_utility = [0 for u in idxs]
    nash_conv = 0
    br_iters = 100
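    # Monte Carlo estimate over br_iters samples: each iteration draws one joint
    # action from the policies above and evaluates every plausible action of every
    # power against that sample via rollouts.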
    for _ in range(br_iters):
        # sample policy for all powers
        idxs, power_sampled_orders = sample_orders_from_policy(
            cfr_data.power_plausible_orders, power_action_ps
        )
        # for each power: compare all actions against sampled opponent action
        set_orders_dicts = make_set_orders_dicts(
            cfr_data.power_plausible_orders, power_sampled_orders
        )
        all_rollout_results = self.model_rollouts.do_rollouts(game, set_orders_dicts)
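        # all_rollout_results holds one (orders_dict, values_dict) pair per
        # (power, action) combination, grouped by power in the iteration order of
        # power_plausible_orders; each power's slice is popped off the front below.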
        for pwr, actions in cfr_data.power_plausible_orders.items():
            if len(actions) == 0:
                continue
            # pop this power's results
            results, all_rollout_results = (
                all_rollout_results[: len(actions)],
                all_rollout_results[len(actions) :],
            )
            for r in results:
                action = r[0][pwr]
                val = r[1][pwr]
                temp_action_utilities[(pwr, action)] = val
                total_action_utilities[(pwr, action)] += val
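            # total_action_utilities accumulates across the br_iters samples and is
            # averaged after the sampling loop; temp_action_utilities is only
            # referenced by the commented-out debug logging below.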
# logging.info("results for power={}".format(pwr))
# for i in range(len(cfr_data.power_plausible_orders[pwr])):
# action = cfr_data.power_plausible_orders[pwr][i]
# util = action_utilities[i]
# logging.info("{} {} = {}".format(pwr,action,util))
# for action in cfr_data.power_plausible_orders[pwr]:
# logging.info("{} {} = {}".format(pwr,action,action_utilities))
# logging.info("action utilities={}".format(action_utilities))
# logging.info("Results={}".format(results))
# state_utility = np.dot(power_action_ps[pwr], action_utilities)
# action_regrets = [(u - state_utility) for u in action_utilities]
# logging.info("Action utilities={}".format(temp_action_utilities))
# for action in actions:
# total_action_utilities[(pwr,action)] += temp_action_utilities[(pwr,action)]
# logging.info("Total action utilities={}".format(total_action_utilities))
# total_state_utility[pwr] += state_utility
# total_state_utility[:] = [x / 100 for x in total_state_utility]
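    # Average per-action utilities over the samples, then compute each power's value
    # under its own policy (total_state_utility) and its best-response value
    # (max_state_utility).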
    for pwr, actions in cfr_data.power_plausible_orders.items():
        # ps = self.avg_strategy(pwr, cfr_data.power_plausible_orders[pwr])
        for i in range(len(actions)):
            action = actions[i]
            total_action_utilities[(pwr, action)] /= br_iters
            if total_action_utilities[(pwr, action)] > max_state_utility[pwr]:
                max_state_utility[pwr] = total_action_utilities[(pwr, action)]
            total_state_utility[pwr] += (
                total_action_utilities[(pwr, action)] * power_action_ps[pwr][i]
            )
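    # NashConv = sum over powers of (best-response value - value of own policy);
    # a value of 0 would mean no power can gain by deviating to another plausible action.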
    for pwr, actions in cfr_data.power_plausible_orders.items():
        logging.info(
            "results for power={} value={} diff={}".format(
                pwr,
                total_state_utility[pwr],
                (max_state_utility[pwr] - total_state_utility[pwr]),
            )
        )
        nash_conv += max_state_utility[pwr] - total_state_utility[pwr]
        for i in range(len(actions)):
            action = actions[i]
            logging.info(
                "{} {} = {} (prob {})".format(
                    pwr, action, total_action_utilities[(pwr, action)], power_action_ps[pwr][i]
                )
            )
    logging.info(f"Nash conv for {label} = {nash_conv}")