in train.py [0:0]
def run(self):
    assert not self.done
    assert self.episode_reward == 0.0
    assert self.episode_step == 0
    self.agent.reset()
    obs = self.env.reset()
    start_time = time.time()
    while self.step < self.cfg.num_train_steps:
        if self.done:
            # log statistics for the episode that just finished
            if self.step > 0:
                self.logger.log(
                    'train/episode_reward', self.episode_reward, self.step)
                self.logger.log('train/duration',
                                time.time() - start_time, self.step)
                self.logger.log('train/episode', self.episode, self.step)
                start_time = time.time()
                self.logger.dump(
                    self.step, save=(self.step > self.cfg.num_seed_steps))
            # evaluate periodically and checkpoint the best-performing agent
            if self.steps_since_eval >= self.cfg.eval_freq:
                self.logger.log('eval/episode', self.episode, self.step)
                eval_rew = self.evaluate()
                self.steps_since_eval = 0

                if self.best_eval_rew is None or eval_rew > self.best_eval_rew:
                    self.save(tag='best')
                    self.best_eval_rew = eval_rew

                self.replay_buffer.save_data(self.replay_dir)
                self.save(tag='latest')
            # save a numbered snapshot every save_freq steps
            if self.step > 0 and self.cfg.save_freq and \
                    self.steps_since_save >= self.cfg.save_freq:
                tag = str(self.step).zfill(self.cfg.save_zfill)
                self.save(tag=tag)
                self.steps_since_save = 0
            # reset the environment (optionally cycling through a fixed set of
            # initial states) and start a new episode
            if self.cfg.num_initial_states is not None:
                self.env.set_seed(self.episode % self.cfg.num_initial_states)
            obs = self.env.reset()
            self.agent.reset()
            self.done = False
            self.episode_reward = 0
            self.episode_step = 0
            self.episode += 1

        # sample action for data collection
        if self.step < self.cfg.num_seed_steps:
            action = self.env.action_space.sample()
        else:
            with utils.eval_mode(self.agent):
                if self.cfg.normalize_obs:
                    mu, sigma = self.replay_buffer.get_obs_stats()
                    obs_norm = (obs - mu) / sigma
                    action = self.agent.act(obs_norm, sample=True)
                else:
                    action = self.agent.act(obs, sample=True)

        # run training update
        if self.step >= self.cfg.num_seed_steps - 1:
            self.agent.update(self.replay_buffer, self.logger, self.step)

        next_obs, reward, self.done, _ = self.env.step(action)

        # allow infinite bootstrap: mask the done signal when the episode was
        # cut off only by the environment's time limit
        done_float = float(self.done)
        done_no_max = done_float if self.episode_step + 1 < self.env._max_episode_steps \
            else 0.
        self.episode_reward += reward

        self.replay_buffer.add(obs, action, reward, next_obs, done_float,
                               done_no_max)

        obs = next_obs
        self.episode_step += 1
        self.step += 1
        self.steps_since_eval += 1
        self.steps_since_save += 1

    # final evaluation and optional replay buffer cleanup after training
    if self.steps_since_eval > 1:
        self.logger.log('eval/episode', self.episode, self.step)
        self.evaluate()

    if self.cfg.delete_replay_at_end:
        shutil.rmtree(self.replay_dir)