in train.py [0:0]
def train_loop(setup: TrainingSetup) -> None:
    """Run the training loop until ``setup.n_samples`` reaches ``cfg.max_steps``.

    Every iteration asks the agent for an action, steps the environments,
    passes the transition to ``agent.step`` and advances ``setup.n_samples``
    by ``envs.num_envs``. Whenever the sample counter is a multiple of
    ``cfg.eval.interval`` a checkpoint is written and ``setup.eval_fn`` is
    called; the same happens once more after the loop has finished.

    If ``cfg.video`` is not ``None``, frame collection starts whenever the
    sample counter is a multiple of ``cfg.video.interval``; frames are pushed
    to ``setup.rq`` and ``rq.plot()`` is called once enough were collected.

    If ``cfg.dump_state_counts`` is positive, every observation is hashed
    into ``setup.hcr`` and a ``Train/UniqueStates`` scalar is written every
    ``cfg.dump_state_counts`` samples.

    Args:
        setup: Training state. Reads ``cfg``, ``agent``, ``rq``, ``envs`` and
            ``eval_fn`` (plus ``hcr``/``hcr_space`` when state counts are
            dumped) and updates ``n_samples`` in place.
    """
    cfg = setup.cfg
    agent = setup.agent
    rq = setup.rq
    envs = setup.envs

    # One call suffices: nothing touches the agent between here and the loop.
    agent.train()

    n_envs = envs.num_envs
    cp_path = cfg.checkpoint_path
    record_videos = cfg.video is not None
    # Annotations are enabled unless explicitly switched off; an unset
    # (None) value counts as enabled.
    annotate = record_videos and (
        cfg.video.annotations or (cfg.video.annotations is None)
    )
    vwidth = int(cfg.video.size[0]) if record_videos else 0
    vheight = int(cfg.video.size[1]) if record_videos else 0
    max_steps = int(cfg.max_steps)
    dump_sc = int(cfg.dump_state_counts)

    def checkpoint_and_eval() -> None:
        """Write a checkpoint (best effort), then evaluate the agent."""
        try:
            log.debug(
                f'Checkpointing to {cp_path} after {setup.n_samples} samples'
            )
            with open(cp_path, 'wb') as f:
                agent.save_checkpoint(f)
            if cfg.keep_all_checkpoints:
                # Keep a copy tagged with the zero-padded sample count next
                # to the rolling checkpoint file.
                p = Path(cp_path)
                cp_unique_path = str(
                    p.with_name(p.stem + f'_{setup.n_samples:08d}' + p.suffix)
                )
                shutil.copy(cp_path, cp_unique_path)
        except Exception:
            # A failed save must not abort training, but KeyboardInterrupt
            # and SystemExit are deliberately allowed to propagate.
            log.exception('Checkpoint saving failed')

        agent.eval()
        setup.eval_fn(setup, setup.n_samples)
        agent.train()

    obs = envs.reset()
    n_imgs = 0
    collect_img = False

    while setup.n_samples < max_steps:
        if setup.n_samples % cfg.eval.interval == 0:
            # Checkpoint time
            checkpoint_and_eval()

        if record_videos and setup.n_samples % cfg.video.interval == 0:
            collect_img = True

        if collect_img:
            rqin = {
                'img': envs.render_single(
                    mode='rgb_array', width=vwidth, height=vheight
                )
            }
            if annotate:
                rqin['s_left'] = [
                    f'Samples {setup.n_samples}',
                    f'Frame {n_imgs}',
                ]
                rqin['s_right'] = [
                    'Train',
                ]
            rq.push(**rqin)
            n_imgs += 1
            # NOTE(review): with '>' this pushes cfg.video.length + 1 frames
            # before plotting -- confirm whether '>=' was intended.
            if n_imgs > cfg.video.length:
                rq.plot()
                n_imgs = 0
                collect_img = False

        action, extra = agent.action(envs, obs)
        next_obs, reward, done, info = envs.step(action)
        agent.step(envs, obs, action, extra, (next_obs, reward, done, info))

        if dump_sc > 0:
            if setup.n_samples % dump_sc == 0:
                d = len(setup.hcr.bucket_sizes)
                sc = setup.hcr.tables.clamp(max=1).sum().item() / d
                agent.tbw_add_scalar(
                    'Train/UniqueStates', sc, setup.n_samples
                )
            # The observation is hashed on every step, not only on dump steps.
            setup.hcr.inc_hash(th_flatten(setup.hcr_space, obs))

        obs = envs.reset_if_done()
        setup.n_samples += n_envs

    # Final checkpoint & eval time
    checkpoint_and_eval()