in train.py [0:0]
def train(self):
    """Run the online training loop until the training frame budget is used up.

    Each iteration of the loop performs exactly one environment step:

    1. If the previous step ended an episode: increment the episode counter,
       save the training video, log episode statistics (only once the agent
       has produced update metrics), reset the environment and optionally
       save a snapshot.
    2. Run evaluation whenever ``eval_every_step`` fires for the current step.
    3. Select an action with autograd disabled and the agent wrapped in
       ``utils.eval_mode``.
    4. Update the agent from the replay iterator once the seed phase is over,
       and log the returned metrics.
    5. Step the environment, add the transition to replay storage, record the
       frame for the training video, and advance the step counters.

    The loop predicates are built from ``self.cfg`` with ``utils.Until`` and
    ``utils.Every``. NOTE(review): they are presumably frame budgets converted
    to agent steps via ``action_repeat``; confirm in ``utils``.
    """
    # predicates
    train_until_step = utils.Until(self.cfg.num_train_frames,
                                   self.cfg.action_repeat)
    seed_until_step = utils.Until(self.cfg.num_seed_frames,
                                  self.cfg.action_repeat)
    eval_every_step = utils.Every(self.cfg.eval_every_frames,
                                  self.cfg.action_repeat)

    episode_step, episode_reward = 0, 0
    time_step = self.train_env.reset()
    self.replay_storage.add(time_step)
    self.train_video_recorder.init(time_step.observation)
    # Remains None until the first agent update; episode logging is deferred
    # until then so the logger sees the full metrics schema.
    metrics = None
    while train_until_step(self.global_step):
        if time_step.last():
            self._global_episode += 1
            self.train_video_recorder.save(f'{self.global_frame}.mp4')
            # Reset the timer at *every* episode boundary, not only when the
            # episode is logged. Then `elapsed_time` always spans just the
            # episode that finished. The first logged fps is no longer averaged
            # over every seed-phase episode and the initial evaluation.
            # NOTE(review): assumes Timer.reset() restarts the per-interval
            # clock and that total_time is measured from the timer's start
            # (unaffected by resets); confirm in utils.Timer.
            elapsed_time, total_time = self.timer.reset()
            # wait until all the metrics schema is populated
            if metrics is not None:
                # log stats
                episode_frame = episode_step * self.cfg.action_repeat
                with self.logger.log_and_dump_ctx(self.global_frame,
                                                  ty='train') as log:
                    log('fps', episode_frame / elapsed_time)
                    log('total_time', total_time)
                    log('episode_reward', episode_reward)
                    log('episode_length', episode_frame)
                    log('episode', self.global_episode)
                    log('buffer_size', len(self.replay_storage))
                    log('step', self.global_step)
            # reset env
            time_step = self.train_env.reset()
            self.replay_storage.add(time_step)
            self.train_video_recorder.init(time_step.observation)
            # try to save snapshot
            if self.cfg.save_snapshot:
                self.save_snapshot()
            episode_step = 0
            episode_reward = 0

        # try to evaluate
        # NOTE: the timer is not reset around evaluation, so time spent in
        # self.eval() counts toward the enclosing episode's elapsed_time and
        # lowers that episode's reported fps.
        if eval_every_step(self.global_step):
            self.logger.log('eval_total_time', self.timer.total_time(),
                            self.global_frame)
            self.eval()

        # sample action (no autograd graph; agent temporarily in eval mode)
        with torch.no_grad(), utils.eval_mode(self.agent):
            action = self.agent.act(time_step.observation,
                                    self.global_step,
                                    eval_mode=False)

        # try to update the agent (skipped while still in the seed phase)
        if not seed_until_step(self.global_step):
            metrics = self.agent.update(self.replay_iter, self.global_step)
            self.logger.log_metrics(metrics, self.global_frame, ty='train')

        # take env step
        time_step = self.train_env.step(action)
        episode_reward += time_step.reward
        self.replay_storage.add(time_step)
        self.train_video_recorder.record(time_step.observation)
        episode_step += 1
        self._global_step += 1