in pyhanabi/train_belief.py [0:0]
def create_rl_context(args):
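    # settings forced onto the reloaded policy agent: run it per-player (no VDN),
    # keep its weights on the training device, and use uniform replay priorities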
    agent_overwrite = {
        "vdn": False,
        "device": args.train_device,  # batch runner will create copy on act device
        "uniform_priority": True,
    }
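
    # a clone bot comes from supervised training and has no saved RL config,
    # so load it directly and fall back to a fixed default config; otherwise
    # reload the RL policy together with the config it was trained with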
    if args.clone_bot:
        agent = utils.load_supervised_agent(args.policy, args.train_device)
        cfgs = {
            "act_base_eps": 0.1,
            "act_eps_alpha": 7,
            "num_game_per_thread": 80,
            "num_player": 2,
            "train_bomb": 0,
            "max_len": 80,
            "sad": 0,
            "shuffle_color": 0,
            "hide_action": 0,
            "multi_step": 1,
            "gamma": 0.999,
        }
    else:
        agent, cfgs = utils.load_agent(args.policy, agent_overwrite)
        assert cfgs["shuffle_color"] == False
        assert args.explore
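
    # replay buffer that stores full RNN trajectories; since the agent assigns
    # uniform priorities (see agent_overwrite), sampling is effectively uniform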
    replay_buffer = rela.RNNPrioritizedReplay(
        args.replay_buffer_size,
        args.seed,
        1.0,  # priority exponent
        0.0,  # priority weight
        args.prefetch,
    )
    if args.rand:
        explore_eps = [1]
    elif args.explore:
        # use the same exploration config as policy learning
        explore_eps = utils.generate_explore_eps(
            cfgs["act_base_eps"], cfgs["act_eps_alpha"], cfgs["num_game_per_thread"]
        )
    else:
        explore_eps = [0]

    expected_eps = np.mean(explore_eps)
    print("explore eps:", explore_eps)
    print("avg explore eps:", np.mean(explore_eps))
    if args.clone_bot or not agent.boltzmann:
        print("no boltzmann act")
        boltzmann_t = []
    else:
        boltzmann_beta = utils.generate_log_uniform(
            1 / cfgs["max_t"], 1 / cfgs["min_t"], cfgs["num_t"]
        )
        boltzmann_t = [1 / b for b in boltzmann_beta]
        print("boltzmann beta:", ", ".join(["%.2f" % b for b in boltzmann_beta]))
        print("avg boltzmann beta:", np.mean(boltzmann_beta))
    games = create_envs(
        args.num_thread * args.num_game_per_thread,
        args.seed,
        cfgs["num_player"],
        cfgs["train_bomb"],
        cfgs["max_len"],
    )
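
    # actor group that runs the fixed policy in every game and writes
    # trajectories with full hand information into the replay buffer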
    act_group = ActGroup(
        args.act_device,
        agent,
        args.seed,
        args.num_thread,
        args.num_game_per_thread,
        cfgs["num_player"],
        explore_eps,
        boltzmann_t,
        "iql",
        cfgs["sad"],
        cfgs["shuffle_color"] if not args.rand else False,
        cfgs["hide_action"],
        False,  # not trinary, need full hand for prediction
        replay_buffer,
        cfgs["multi_step"],  # not used
        cfgs["max_len"],
        cfgs["gamma"],  # not used
        False,  # turn off off-belief regardless of how the policy was trained
        None,  # belief_model
    )
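
    # group the games and actors into worker threads driven by a single context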
    context, threads = create_threads(
        args.num_thread,
        args.num_game_per_thread,
        act_group.actors,
        games,
    )
    return agent, cfgs, replay_buffer, games, act_group, context, threads