def create_rl_context(args)

in pyhanabi/train_belief.py
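
The function relies on several module-level imports from train_belief.py. A minimal sketch of the ones it touches is below; the exact module paths are assumptions based on the pyhanabi package layout, not verbatim from the file:

    # assumed module-level imports for this excerpt
    import numpy as np
    import rela    # C++ bindings: replay buffer, thread-loop context
    import utils
    from create import create_envs, create_threads
    from act_group import ActGroup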


def create_rl_context(args):
    agent_overwrite = {
        "vdn": False,
        "device": args.train_device,  # batch runner will create copy on act device
        "uniform_priority": True,
    }

    if args.clone_bot:
        agent = utils.load_supervised_agent(args.policy, args.train_device)
        cfgs = {
            "act_base_eps": 0.1,
            "act_eps_alpha": 7,
            "num_game_per_thread": 80,
            "num_player": 2,
            "train_bomb": 0,
            "max_len": 80,
            "sad": 0,
            "shuffle_color": 0,
            "hide_action": 0,
            "multi_step": 1,
            "gamma": 0.999,
        }
    else:
        agent, cfgs = utils.load_agent(args.policy, agent_overwrite)

    assert not cfgs["shuffle_color"]
    assert args.explore

    replay_buffer = rela.RNNPrioritizedReplay(
        args.replay_buffer_size,
        args.seed,
        1.0,  # priority exponent
        0.0,  # priority weight
        args.prefetch,
    )

    if args.rand:
        # eps = 1: every action is chosen uniformly at random
        explore_eps = [1]
    elif args.explore:
        # use the same exploration config as policy learning
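        # (assumption) generate_explore_eps produces an Ape-X style per-thread schedule,
        # roughly eps_i = act_base_eps ** (1 + i / (n - 1) * act_eps_alpha) for n threads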
        explore_eps = utils.generate_explore_eps(
            cfgs["act_base_eps"], cfgs["act_eps_alpha"], cfgs["num_game_per_thread"]
        )
    else:
        explore_eps = [0]

    expected_eps = np.mean(explore_eps)
    print("explore eps:", explore_eps)
    print("avg explore eps:", expected_eps)
    if args.clone_bot or not agent.boltzmann:
        print("no boltzmann act")
        boltzmann_t = []
    else:
        boltzmann_beta = utils.generate_log_uniform(
            1 / cfgs["max_t"], 1 / cfgs["min_t"], cfgs["num_t"]
        )
        boltzmann_t = [1 / b for b in boltzmann_beta]
        print("boltzmann beta:", ", ".join(["%.2f" % b for b in boltzmann_beta]))
        print("avg boltzmann beta:", np.mean(boltzmann_beta))

    games = create_envs(
        args.num_thread * args.num_game_per_thread,
        args.seed,
        cfgs["num_player"],
        cfgs["train_bomb"],
        cfgs["max_len"],
    )

    act_group = ActGroup(
        args.act_device,
        agent,
        args.seed,
        args.num_thread,
        args.num_game_per_thread,
        cfgs["num_player"],
        explore_eps,
        boltzmann_t,
        "iql",
        cfgs["sad"],
        cfgs["shuffle_color"] if not args.rand else False,
        cfgs["hide_action"],
        False,  # not trinary, need full hand for prediction
        replay_buffer,
        cfgs["multi_step"],  # not used
        cfgs["max_len"],
        cfgs["gamma"],  # not used
        False,  # turn off off-belief regardless of how the agent was trained
        None,  # belief_model
    )

    context, threads = create_threads(
        args.num_thread,
        args.num_game_per_thread,
        act_group.actors,
        games,
    )
    return agent, cfgs, replay_buffer, games, act_group, context, threads
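
For orientation, a sketch of how the returned handles are typically consumed by the belief-training driver. The start()/size()/sample() calls and the burn_in_frames/batchsize arguments are assumptions modeled on the companion training scripts, not taken verbatim from this file:

    import time

    # hypothetical driver sketch: start the acting threads, let the replay
    # buffer warm up, then sample trajectories for belief-model training
    agent, cfgs, replay_buffer, games, act_group, context, threads = create_rl_context(args)
    act_group.start()
    context.start()

    while replay_buffer.size() < args.burn_in_frames:  # burn_in_frames is an assumed arg
        print("warming up replay buffer:", replay_buffer.size())
        time.sleep(1)

    batch, weight = replay_buffer.sample(args.batchsize, args.train_device)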