def load_agent()

in pyhanabi/utils.py [0:0]


def load_agent(weight_file, overwrite):
    """Rebuild an R2D2 agent from a weight file and its training config.

    The training config is obtained from ``get_train_config(weight_file)``.
    If it contains a ``"core"`` key it is flattened with ``flatten_dict``
    before use. One environment is created from that config so that the
    network input/output sizes can be read from it.

    Args:
        weight_file: passed to ``get_train_config`` and ``load_weight``.
        overwrite: mapping of settings supplied by the caller.
            * ``"device"`` is required.
            * ``"vdn"``, ``"multi_step"``, ``"gamma"`` and ``"boltzmann_act"``
              take precedence over the training config when present.
            * ``"uniform_priority"`` defaults to ``False``.
            * ``"num_lstm_layer"`` is only a fallback: it is used when the
              training config has no such key (final default ``2``).

    Returns:
        A tuple ``(agent, cfg)``: the ``r2d2.R2D2Agent`` moved to the
        requested device with ``weight_file`` loaded into its online net and
        ``sync_target_with_online()`` called, and the (possibly flattened)
        training config.

    Raises:
        AssertionError: if no training config is found for ``weight_file``.
        KeyError: if ``overwrite`` lacks ``"device"``, or a setting is
            missing from both ``overwrite`` and the training config.
    """
    cfg = get_train_config(weight_file)
    assert cfg is not None, "no training config found for {}".format(weight_file)

    # Fail fast on the one mandatory override before any environment is built.
    device = overwrite["device"]

    if "core" in cfg:
        new_cfg = {}
        flatten_dict(cfg, new_cfg)
        cfg = new_cfg

    def pick(key):
        # The training config is consulted only when `overwrite` has no entry,
        # so an explicit override still works for configs that lack the key
        # (`overwrite.get(key, cfg[key])` would evaluate `cfg[key]` eagerly).
        return overwrite[key] if key in overwrite else cfg[key]

    game = create_envs(
        1,
        1,
        cfg["num_player"],
        cfg["train_bomb"],
        [0],  # explore_eps,
        [100],  # boltzmann_t,
        cfg["max_len"],
        # Configs without "sad" store the flag under "greedy_extra" instead.
        cfg["sad"] if "sad" in cfg else cfg["greedy_extra"],
        cfg["shuffle_obs"],
        cfg["shuffle_color"],
        cfg["hide_action"],
        True,
    )[0]

    config = {
        "vdn": overwrite["vdn"] if "vdn" in overwrite else (cfg["method"] == "vdn"),
        "multi_step": pick("multi_step"),
        "gamma": pick("gamma"),
        "eta": 0.9,
        "device": device,
        "in_dim": game.feature_size(),
        # Configs without "hid_dim" store the size under "rnn_hid_dim" instead.
        "hid_dim": cfg["hid_dim"] if "hid_dim" in cfg else cfg["rnn_hid_dim"],
        "out_dim": game.num_action(),
        # NOTE(review): unlike the other settings, the training config wins
        # here and `overwrite` is only a fallback -- presumably because the
        # layer count must match the saved weights; confirm before changing.
        "num_lstm_layer": cfg.get("num_lstm_layer", overwrite.get("num_lstm_layer", 2)),
        "boltzmann_act": pick("boltzmann_act"),
        "uniform_priority": overwrite.get("uniform_priority", False),
    }

    agent = r2d2.R2D2Agent(**config).to(device)
    load_weight(agent.online_net, weight_file, device)
    agent.sync_target_with_online()
    return agent, cfg