in pyhanabi/utils.py [0:0]
def load_agent(weight_file, overwrite):
    """Build an R2D2 agent from a saved checkpoint.

    overwrite has to contain "device"; other keys such as "vdn",
    "multi_step", "gamma", "boltzmann_act", and "uniform_priority"
    take precedence over the values saved in the training config.
    """
    cfg = get_train_config(weight_file)
    assert cfg is not None

    # configs with a nested "core" section are flattened for uniform key access
    if "core" in cfg:
        new_cfg = {}
        flatten_dict(cfg, new_cfg)
        cfg = new_cfg

    # create a single env only to query the feature and action dimensions
    game = create_envs(
        1,
        1,
        cfg["num_player"],
        cfg["train_bomb"],
        [0],  # explore_eps
        [100],  # boltzmann_t
        cfg["max_len"],
        cfg["sad"] if "sad" in cfg else cfg["greedy_extra"],
        cfg["shuffle_obs"],
        cfg["shuffle_color"],
        cfg["hide_action"],
        True,
    )[0]

    # assemble the agent config, letting `overwrite` shadow the saved values
    config = {
        "vdn": overwrite["vdn"] if "vdn" in overwrite else cfg["method"] == "vdn",
        "multi_step": overwrite.get("multi_step", cfg["multi_step"]),
        "gamma": overwrite.get("gamma", cfg["gamma"]),
        "eta": 0.9,
        "device": overwrite["device"],
        "in_dim": game.feature_size(),
        "hid_dim": cfg["hid_dim"] if "hid_dim" in cfg else cfg["rnn_hid_dim"],
        "out_dim": game.num_action(),
        "num_lstm_layer": cfg.get("num_lstm_layer", overwrite.get("num_lstm_layer", 2)),
        "boltzmann_act": overwrite.get("boltzmann_act", cfg["boltzmann_act"]),
        "uniform_priority": overwrite.get("uniform_priority", False),
    }

    # build the agent, restore the online network, and copy it into the target net
    agent = r2d2.R2D2Agent(**config).to(config["device"])
    load_weight(agent.online_net, weight_file, config["device"])
    agent.sync_target_with_online()
    return agent, cfg
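
A minimal usage sketch of the function above: the checkpoint path and CUDA device string are assumptions, but the `overwrite` dict illustrates the documented contract, where "device" is required and any other keys simply shadow the saved training config.

# Hypothetical checkpoint path; "device" is the only required overwrite key.
agent, cfg = load_agent(
    "exps/example_model.pthw",  # assumption: path to a trained checkpoint
    {"device": "cuda:0", "vdn": False, "uniform_priority": True},
)
agent.train(False)  # assumes R2D2Agent is a torch nn.Module; switch to eval mode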