in fairdiplomacy/situation_check.py [0:0]
def run_situation_check(meta, agent, extra_plausible_orders_str: str = ""):
extra_plausible_orders: Optional[Dict[str, List[Tuple[str, ...]]]]
if extra_plausible_orders_str:
assert isinstance(agent, SearchBotAgent)
extra_plausible_orders = _parse_extra_plausible_orders(extra_plausible_orders_str)
else:
extra_plausible_orders = None
results = {}
for name, config in meta.items():
logging.info("=" * 80)
comment = config.get("comment", "")
logging.info(f"{name}: {comment} (phase={config.get('phase')})")
# If path is not absolute, treat as relative to code root.
game_path = heyhi.PROJ_ROOT / config["game_path"]
logging.info(f"path: {game_path}")
with open(game_path) as f:
game = Game.from_json(f.read())
if "phase" in config:
game.rolled_back_to_phase_start(config["phase"])
if hasattr(agent, "get_all_power_prob_distributions"):
if isinstance(agent, SearchBotAgent):
prob_distributions = agent.get_all_power_prob_distributions(
game, extra_plausible_orders=extra_plausible_orders
)
else:
prob_distributions = agent.get_all_power_prob_distributions(
game
) # FIXME: early exit
logging.info("CFR strategy:")
else:
# this is a supervised agent, sample N times to get a distribution
NUM_ROLLOUTS = 100
prob_distributions = {p: defaultdict(float) for p in POWERS}
for power in POWERS:
for N in range(NUM_ROLLOUTS):
orders = agent.get_orders(game, power)
prob_distributions[power][tuple(orders)] += 1 / NUM_ROLLOUTS
if hasattr(agent, "get_values"):
logging.info(
"Values: %s",
" ".join(f"{p}={v:.3f}" for p, v in zip(POWERS, agent.get_values(game))),
)
for power in POWERS:
pd = prob_distributions[power]
pdl = sorted(list(pd.items()), key=lambda x: -x[1])
logging.info(f" {power}")
for order, prob in pdl:
if prob < 0.02:
break
logging.info(f" {prob:5.2f} {order}")
for i, (test_desc, test_func_str) in enumerate(config.get("tests", {}).items()):
test_func = eval(test_func_str)
passed = test_func(prob_distributions)
results[f"{name}.{i}"] = int(passed)
res_string = "PASSED" if passed else "FAILED"
logging.info(f"Result: {res_string:8s} {name:20s} {test_desc}")
logging.info(f" {test_func_str}")
logging.info("Passed: %d/%d", sum(results.values()), len(results))
logging.info("JSON: %s", results)
return results