in rlkit/core/batch_rl_algorithm.py [0:0]
def _train(self):
    # Optionally warm up the replay buffer with exploration data before any
    # training happens (skipped entirely in batch/offline RL mode).
    if self.min_num_steps_before_training > 0 and not self.batch_rl:
        init_expl_paths = self.expl_data_collector.collect_new_paths(
            self.max_path_length,
            self.min_num_steps_before_training,
            discard_incomplete_paths=False,
        )
        self.replay_buffer.add_paths(init_expl_paths)
        self.expl_data_collector.end_epoch(-1)

    for epoch in gt.timed_for(
            range(self._start_epoch, self.num_epochs),
            save_itrs=True,
    ):
        # Collect evaluation rollouts for this epoch. Q-learning algorithms
        # evaluate with an explicit policy function (with a separate variant
        # for discrete action spaces); otherwise the eval collector's own
        # policy is used.
        if self.q_learning_alg:
            policy_fn = self.policy_fn
            if self.trainer.discrete:
                policy_fn = self.policy_fn_discrete
            self.eval_data_collector.collect_new_paths(
                policy_fn,
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        else:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        gt.stamp('evaluation sampling')

        for _ in range(self.num_train_loops_per_epoch):
            if not self.batch_rl:
                # Sample new exploration paths only if not doing batch RL.
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)
            elif self.eval_both:
                # In batch RL mode, optionally evaluate the policy through the
                # exploration data collector as well.
                policy_fn = self.policy_fn
                if self.trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
                gt.stamp('policy fn evaluation')

            # Gradient updates: sample minibatches from the replay buffer and
            # hand them to the trainer.
            self.training_mode(True)
            for _ in range(self.num_trains_per_train_loop):
                train_data = self.replay_buffer.random_batch(self.batch_size)
                self.trainer.train(train_data)
            gt.stamp('training', unique=False)
            self.training_mode(False)

        # End-of-epoch bookkeeping; stop training early if the trainer
        # reported NaNs during this epoch.
        has_nan = self._end_epoch(epoch)
        if has_nan:
            break
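
# A minimal standalone sketch of the gtimer timing pattern used in _train
# above (`gt` is the gtimer package, imported in rlkit as `import gtimer as
# gt`). This is only an illustration with placeholder work, not rlkit code,
# and it assumes gtimer's report() helper for printing the summary:
#
#   import gtimer as gt
#
#   for epoch in gt.timed_for(range(3), save_itrs=True):
#       pass  # ... evaluation sampling ...
#       gt.stamp('evaluation sampling')
#       for _ in range(2):
#           pass  # ... gradient updates ...
#           gt.stamp('training', unique=False)
#   print(gt.report())  # per-stamp wall-clock summary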