# minihack/agent/rllib/train.py

import os
from numbers import Number
from typing import Iterable

import numpy as np
import ray
from omegaconf import DictConfig, OmegaConf
from ray import tune
from ray.rllib.models import MODEL_DEFAULTS
from ray.tune.integration.wandb import (
    WandbLoggerCallback,
    _VALID_ITERABLE_TYPES,
    _VALID_TYPES,
)
from ray.tune.registry import register_env
from ray.tune.utils import merge_dicts

# Project-local helpers. The module paths below are assumptions inferred from
# the names used in this file, not confirmed import locations.
from minihack import tasks
from minihack.agent.rllib.envs import RLLibNLEEnv
from minihack.agent.rllib.util import (
    NAME_TO_TRAINER,
    get_env_shortcut,
    get_full_config,
    is_env_registered,
)


def train(cfg: DictConfig) -> None:
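    """Build an RLlib trainer config from the Hydra config and launch
    training with ray.tune."""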
    ray.init(num_gpus=cfg.num_gpus, num_cpus=cfg.num_cpus + 1)
    cfg = get_full_config(cfg)
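
    # Register the NLE-backed environment with RLlib under a fixed ID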
    register_env("RLlibNLE-v0", RLLibNLEEnv)

    try:
        config, trainer = NAME_TO_TRAINER[cfg.algo]
    except KeyError as error:
        raise ValueError(
            f"The algorithm you specified isn't currently supported: {cfg.algo}"
        ) from error
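
    # Flatten the OmegaConf config into a plain dict so it can be merged
    # with RLlib's default trainer config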
    args_config = OmegaConf.to_container(cfg)
    # Algo-specific config. Requires hydra config keys to match rllib exactly
    algo_config = args_config.pop(cfg.algo)
    # Drop the config blocks of the algorithms we are not using
    for algo in NAME_TO_TRAINER:
        if algo != cfg.algo:
            args_config.pop(algo, None)
    # Merge config from hydra (it will contain some rogue keys, but that's OK)
    config = merge_dicts(config, args_config)

    # Resolve the environment name, accepting both full names and shortcuts
    if cfg.env not in tasks.ENVS:
        if is_env_registered(cfg.env):
            cfg.env = get_env_shortcut(cfg.env)
        else:
            raise KeyError(f"Could not find an environment named {cfg.env}.")

    # Update configuration with parsed arguments in specific ways
    config = merge_dicts(
        config,
        {
            "framework": "torch",
            "num_gpus": cfg.num_gpus,
            "seed": cfg.seed,
            "env": "RLlibNLE-v0",
            "env_config": {
                "flags": cfg,
                "observation_keys": cfg.obs_keys.split(","),
                "name": cfg.env,
            },
            "train_batch_size": cfg.train_batch_size,
"model": merge_dicts(
MODEL_DEFAULTS,
{
"custom_model": "rllib_nle_model",
"custom_model_config": {"flags": cfg, "algo": cfg.algo},
"use_lstm": cfg.use_lstm,
"lstm_use_prev_reward": True,
"lstm_use_prev_action": True,
"lstm_cell_size": cfg.hidden_dim,
},
),
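            # One rollout worker per CPU, with actors split evenly across
            # workers (assumes num_actors is divisible by num_cpus)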
"num_workers": cfg.num_cpus,
"num_envs_per_worker": int(cfg.num_actors / cfg.num_cpus),
"evaluation_interval": 100,
"evaluation_num_episodes": 50,
"evaluation_config": {"explore": False},
"rollout_fragment_length": cfg.unroll_length,
},
)

    # Merge algo-specific config at the top level
    config = merge_dicts(config, algo_config)
    # Rebuild the trainer class around the merged config so that RLlib's
    # config validation accepts the custom keys set above
    trainer_class = trainer.with_updates(default_config=config)

    callbacks = []
    if cfg.wandb:
        callbacks.append(
            WandbLoggerCallback(
                project=cfg.project,
                api_key_file="~/.wandb_api_key",
                entity=cfg.entity,
                group=cfg.group,
                tags=cfg.tags.split(","),
            )
        )
        # Only log to wandb, disabling tune's default loggers
        os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"

        # Hacky monkey-patching so OmegaConf configs pass wandb's type check
        def _is_allowed_type(obj):
            """Return True if the object's type can be logged to wandb."""
            if isinstance(obj, DictConfig):
                return True
            if isinstance(obj, np.ndarray) and obj.size == 1:
                return isinstance(obj.item(), Number)
            if isinstance(obj, Iterable) and len(obj) > 0:
                return isinstance(obj[0], _VALID_ITERABLE_TYPES)
            return isinstance(obj, _VALID_TYPES)

        ray.tune.integration.wandb._is_allowed_type = _is_allowed_type
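
    # Launch training; stop once the total number of sampled environment
    # steps reaches cfg.total_steps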
    tune.run(
        trainer_class,
        stop={"timesteps_total": cfg.total_steps},
        config=config,
        name=cfg.name,
        callbacks=callbacks,
    )
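

# Example invocation (hypothetical): this module is typically driven by a
# Hydra entry point that composes cfg and calls train(), e.g.
#   python -m minihack.agent.rllib.train algo=ppo env=MiniHack-Room-5x5-v0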