in train.py [0:0]
def setup_training(cfg: DictConfig) -> TrainingSetup:
    """Build every object needed for a training run from ``cfg``.

    In order, this function:

    * switches ``cfg.device`` to ``'cpu'`` (mutating ``cfg`` in place) when
      CUDA was requested but is not available,
    * seeds torch with ``cfg.seed``,
    * creates the Visdom client and the render queue wrapping it,
    * creates the training and evaluation vectorized environments (same
      name, seed and wrappers; differing process counts and extra args),
    * builds the model (recursively for dict-valued spaces), moves it to
      ``cfg.device`` and creates the optimizer and the agent,
    * tries to create a TensorBoard ``SummaryWriter`` and attach it to the
      agent as ``agent.tbw``,
    * derives the observation space used for hashing-based state counts and,
      if ``cfg.dump_state_counts`` is positive, the ``HashingCountReward``.

    Args:
        cfg: Run configuration. Fields read here: ``device``, ``seed``,
            ``visdom.{host,port,env,offline,logfile}``,
            ``env.{name,train_procs,eval_procs,train_args,eval_args}``,
            ``agent``, ``model``, ``optim`` and ``dump_state_counts``.

    Returns:
        A ``TrainingSetup`` bundling the config, agent, model, summary writer
        (or ``None``), Visdom client, render queue, both environment sets,
        the evaluation function, the state-count observation space and the
        state counter (or ``None``).
    """
    if cfg.device == 'cuda' and not th.cuda.is_available():
        log.warning('CUDA not available, falling back to CPU')
        cfg.device = 'cpu'
    # TODO doesn't work with submitit?
    # if th.backends.cudnn.is_available():
    #     th.backends.cudnn.benchmark = True

    th.manual_seed(cfg.seed)

    viz = Visdom(
        server=f'http://{cfg.visdom.host}',
        port=cfg.visdom.port,
        env=cfg.visdom.env,
        offline=cfg.visdom.offline,
        log_to_filename=cfg.visdom.logfile,
    )
    rq = hucc.RenderQueue(viz)

    # Training and evaluation environments share name, seed and wrappers.
    wrappers = hucc.make_wrappers(cfg.env)
    envs = hucc.make_vec_envs(
        cfg.env.name,
        cfg.env.train_procs,
        device=cfg.device,
        seed=cfg.seed,
        wrappers=wrappers,
        **cfg.env.train_args,
    )
    eval_envs = hucc.make_vec_envs(
        cfg.env.name,
        cfg.env.eval_procs,
        device=cfg.device,
        seed=cfg.seed,
        wrappers=wrappers,
        **cfg.env.eval_args,
    )

    observation_space = hucc.effective_observation_space(cfg.agent, envs)
    action_space = hucc.effective_action_space(cfg.agent, envs)

    def make_model_rec(mcfg, obs_space, action_space) -> nn.Module:
        """Build a model, descending into dict-valued spaces.

        When both spaces are dicts they must have identical keys; one
        sub-model is built per key (using ``mcfg[key]``) and the result is
        returned as an ``nn.ModuleDict``. Otherwise the model comes directly
        from ``hucc.make_model``.
        """
        if isinstance(obs_space, dict) and isinstance(action_space, dict):
            assert set(obs_space.keys()) == set(action_space.keys())
            models: Dict[str, nn.Module] = {}
            for k in obs_space.keys():
                models[k] = make_model_rec(
                    mcfg[k], obs_space[k], action_space[k]
                )
            return nn.ModuleDict(models)
        return hucc.make_model(mcfg, obs_space, action_space)

    model = make_model_rec(cfg.model, observation_space, action_space)
    log.info(f'Model from config:\n{model}')
    model.to(cfg.device)
    optim = hucc.make_optim(cfg.optim, model)
    agent = hucc.make_agent(cfg.agent, envs, model, optim)

    # If the current directory is different from the original one, assume we
    # have a dedicated job directory. We'll just write our summaries to 'tb/'
    # then.
    try:
        if os.getcwd() != hydra.utils.get_original_cwd():
            tbw = SummaryWriter('tb')
        else:
            tbw = SummaryWriter()
        agent.tbw = tbw
    except Exception:
        # XXX hydra.utils.get_original_cwd throws if we don't run this via
        # run_hydra.
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt/SystemExit still propagate. Note that agent.tbw
        # is not assigned on this path.
        tbw = None

    # Observation space used for state counting: drop entries whose key
    # starts with '_' as well as the 'time' entry. The dict is copied and
    # pruned in place so the remaining entries keep their original order.
    try:
        no_gs_obs = copy(envs.observation_space.spaces)
        for key in [k for k in no_gs_obs.keys() if k.startswith('_')]:
            del no_gs_obs[key]
        if 'time' in no_gs_obs:
            del no_gs_obs['time']
        no_gs_obs = gym.spaces.Dict(no_gs_obs)
    except Exception:
        # Best-effort fallback (e.g. the observation space has no `.spaces`
        # mapping): use the full observation space unchanged.
        no_gs_obs = envs.observation_space

    # State counting is only enabled for a positive dump interval.
    dump_sc = int(cfg.dump_state_counts)
    if dump_sc > 0:
        hcr = HashingCountReward(gym.spaces.flatdim(no_gs_obs)).to_(cfg.device)
    else:
        hcr = None

    return TrainingSetup(
        cfg=cfg,
        agent=agent,
        model=model,
        tbw=tbw,
        viz=viz,
        rq=rq,
        envs=envs,
        eval_envs=eval_envs,
        # NOTE(review): `eval` resolves to whatever is bound at module scope;
        # presumably this file defines its own eval() -- otherwise this would
        # be the builtin. Confirm.
        eval_fn=eval,
        hcr_space=no_gs_obs,
        hcr=hcr,
    )