in MTRF/algorithms/softlearning/algorithms/rl_algorithm.py [0:0]
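# NOTE: this method relies on the file's module-level imports, which are not
# shown in this excerpt; presumably something along the lines of:
#   from collections import OrderedDict
#   from itertools import count
#   import gtimer as gt
#   import numpy as np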
def _train(self):
"""Return a generator that performs RL training.
Args:
env (`SoftlearningEnv`): Environment used for training.
policy (`Policy`): Policy used for training
initial_exploration_policy ('Policy'): Policy used for exploration
If None, then all exploration is done using policy
pool (`PoolBase`): Sample pool to add samples to
"""
training_environment = self._training_environment
evaluation_environment = self._evaluation_environment
policy = self._policy
pool = self._pool
if not self._training_started:
self._init_training()
self._initial_exploration_hook(
training_environment, self._initial_exploration_policy, pool)
self.sampler.initialize(training_environment, policy, pool)
gt.reset_root()
gt.rename_root('RLAlgorithm')
gt.set_def_unique(False)
self._training_before_hook()
import time  # local import; only needed for the sample-timing instrumentation below
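# Outer loop over epochs: gt.timed_for wraps the range so that each
# gt.stamp(...) below is recorded per iteration; starting the range at
# self._epoch lets training resume from a checkpointed epoch counter.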
for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
self._epoch_before_hook()
gt.stamp('epoch_before_hook')
start_samples = self.sampler._total_samples
sample_times = []
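# Inner loop over environment steps: keep sampling (and training, once the
# pool has enough data to be ready_to_train) until at least
# self._epoch_length new samples have been collected this epoch.
# sample_times records the wall-clock duration of each sampling call.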
for i in count():
samples_now = self.sampler._total_samples
self._timestep = samples_now - start_samples
if (samples_now >= start_samples + self._epoch_length
and self.ready_to_train):
break
self._timestep_before_hook()
gt.stamp('timestep_before_hook')
t0 = time.time()
self._do_sampling(timestep=self._total_timestep)
gt.stamp('sample')
sample_times.append(time.time() - t0)
if self.ready_to_train:
self._do_training_repeats(timestep=self._total_timestep)
gt.stamp('train')
self._timestep_after_hook()
gt.stamp('timestep_after_hook')
print("Average Sample Time: ", np.mean(np.array(sample_times)))
training_paths = self._training_paths()
# self.sampler.get_last_n_paths(
# math.ceil(self._epoch_length / self.sampler._max_path_length))
gt.stamp('training_paths')
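# Evaluation rollouts are collected with the current policy in the separate
# evaluation environment; training metrics are computed from the rollouts
# gathered during this epoch's sampling.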
evaluation_paths = self._evaluation_paths(
policy, evaluation_environment)
gt.stamp('evaluation_paths')
training_metrics = self._evaluate_rollouts(
training_paths, training_environment)
gt.stamp('training_metrics')
#should_save_path = (
# self._path_save_frequency > 0
# and self._epoch % self._path_save_frequency == 0)
#if should_save_path:
# import pickle
# for i, path in enumerate(training_paths):
# #path.pop('images')
# path_file_name = f'training_path_{self._epoch}_{i}.pkl'
# path_file_path = os.path.join(
# os.getcwd(), 'paths', path_file_name)
# if not os.path.exists(os.path.dirname(path_file_path)):
# os.makedirs(os.path.dirname(path_file_path))
# with open(path_file_path, 'wb' ) as f:
# pickle.dump(path, f)
if evaluation_paths:
evaluation_metrics = self._evaluate_rollouts(
evaluation_paths, evaluation_environment)
gt.stamp('evaluation_metrics')
else:
evaluation_metrics = {}
self._epoch_after_hook(training_paths)
gt.stamp('epoch_after_hook')
sampler_diagnostics = self.sampler.get_diagnostics()
diagnostics = self.get_diagnostics(
iteration=self._total_timestep,
batch=self._evaluation_batch(),
training_paths=training_paths,
evaluation_paths=evaluation_paths)
time_diagnostics = gt.get_times().stamps.itrs
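# Assemble the per-epoch diagnostics: algorithm-specific values from
# get_diagnostics(), plus evaluation/training rollout metrics, per-stamp
# timings from gtimer, sampler diagnostics, and the epoch/step counters,
# each namespaced with an 'evaluation/', 'training/', 'times/', or
# 'sampler/' prefix.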
diagnostics.update(OrderedDict((
*(
(f'evaluation/{key}', evaluation_metrics[key])
for key in sorted(evaluation_metrics.keys())
),
*(
(f'training/{key}', training_metrics[key])
for key in sorted(training_metrics.keys())
),
*(
(f'times/{key}', time_diagnostics[key][-1])
for key in sorted(time_diagnostics.keys())
),
*(
(f'sampler/{key}', sampler_diagnostics[key])
for key in sorted(sampler_diagnostics.keys())
),
('epoch', self._epoch),
('timestep', self._timestep),
('timesteps_total', self._total_timestep),
('train-steps', self._num_train_steps),
)))
if self._eval_render_kwargs and hasattr(
evaluation_environment, 'render_rollouts'):
# TODO(hartikainen): Make this consistent such that there's no
# need for the hasattr check.
evaluation_environment.render_rollouts(evaluation_paths)
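# Yield this epoch's diagnostics to the caller; _train is a generator, so
# the surrounding experiment runner drives training one epoch at a time.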
yield diagnostics
self.sampler.terminate()
self._training_after_hook()
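# Tear-down: drop the last epoch's evaluation rollouts and yield a final
# diagnostics dict with 'done': True to signal completion.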
del evaluation_paths
yield {'done': True, **diagnostics}
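# --- Illustrative only (not part of the file) ---
# A minimal sketch of how a runner might drive this generator, assuming a
# hypothetical `algorithm` object whose class defines _train() as above and
# a hypothetical `log` helper:
#
#   for diagnostics in algorithm._train():
#       log(diagnostics)                 # hypothetical logging helper
#       if diagnostics.get('done'):      # final yield signals completion
#           break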