in MTRF/algorithms/softlearning/algorithms/phased_sac.py [0:0]
def _train(self):
    """Return a generator that performs RL training.

    Training iterates over `self._n_epochs` epochs. The training and
    evaluation environments, the per-goal policies and samplers, and the
    initial exploration policy are read from instance attributes; a
    diagnostics dict is yielded once per epoch.
    """
    import gtimer as gt
    import time
    from itertools import count
    # NumPy and OrderedDict may already be available at module level; they
    # are imported here as well so this method reads self-contained.
    from collections import OrderedDict
    import numpy as np
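    # A single training/evaluation environment pair is shared across phases;
    # samplers, exploration, and metrics are tracked per goal (phase).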
    training_environment = self._training_environment
    evaluation_environment = self._evaluation_environment
    training_metrics = [0 for _ in range(self._num_goals)]
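    # Run the initial exploration policy once per goal before regular
    # training begins.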
    if not self._training_started:
        self._init_training()
        for i in range(self._num_goals):
            self._initial_exploration_hook(
                training_environment, self._initial_exploration_policy, i)

    self._initialize_samplers()
    self._sample_count = 0

    gt.reset_root()
    gt.rename_root('RLAlgorithm')
    gt.set_def_unique(False)
print("starting_training")
self._training_before_hook()
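    # One iteration per epoch; the range starts at self._epoch so training
    # can resume from a partially trained state, and gt.timed_for records
    # per-epoch timings.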
    for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
        self._epoch_before_hook()
        gt.stamp('epoch_before_hook')

        start_samples = sum(
            self._samplers[i]._total_samples
            for i in range(self._num_goals))
        sample_times = []
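        # Alternate sampling and training until the combined sample count
        # across all goal samplers has advanced by one epoch_length.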
        for _ in count():
            samples_now = sum(
                self._samplers[i]._total_samples
                for i in range(self._num_goals))
            self._timestep = samples_now - start_samples

            # Stopping condition
            if (samples_now >= start_samples + self._epoch_length
                    and self.ready_to_train):
                break

            t0 = time.time()
            self._timestep_before_hook()
            gt.stamp('timestep_before_hook')

            self._do_sampling(timestep=self._total_timestep)
            gt.stamp('sample')
            sample_times.append(time.time() - t0)
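            # Take gradient steps only once enough data has been collected.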
            t0 = time.time()
            if self.ready_to_train:
                self._do_training_repeats(timestep=self._total_timestep)
            gt.stamp('train')
            # print("Train time: ", time.time() - t0)

            self._timestep_after_hook()
            gt.stamp('timestep_after_hook')
print("Average Sample Time: ", np.mean(np.array(sample_times)))
print("Step count", self._sample_count)
        training_paths_per_policy = self._training_paths()
        # Roughly the last `epoch_length / max_path_length` paths from each
        # sampler.
        gt.stamp('training_paths')
        evaluation_paths_per_policy = self._evaluation_paths()
        if self._eval_n_episodes < 1:
            # If we don't choose to do evaluations, set the eval paths as
            # the fake ones generated on the first iteration.
            # NOTE: If you do this, however, all logged evaluation metrics
            # are garbage.
            evaluation_paths_per_policy = self.fake_eval_paths
        gt.stamp('evaluation_paths')
        # Backfill policies that collected no training paths this epoch with
        # their (possibly fake) evaluation trajectories so metrics can still
        # be computed.
        empty_policies = []
        for tpn, tp in enumerate(training_paths_per_policy):
            if len(tp) == 0:
                empty_policies.append(tpn)
                training_paths_per_policy[tpn] = evaluation_paths_per_policy[tpn]
        training_metrics_per_policy = self._evaluate_rollouts(
            training_paths_per_policy, training_environment)
        gt.stamp('training_metrics')

        if evaluation_paths_per_policy:
            evaluation_metrics_per_policy = self._evaluate_rollouts(
                evaluation_paths_per_policy, evaluation_environment)
            gt.stamp('evaluation_metrics')
        else:
            evaluation_metrics_per_policy = [
                {} for _ in range(self._num_goals)]
        self._epoch_after_hook(training_paths_per_policy)
        gt.stamp('epoch_after_hook')

        t0 = time.time()
        sampler_diagnostics_per_policy = [
            self._samplers[i].get_diagnostics()
            for i in range(self._num_goals)]
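        # Assemble the epoch's diagnostics: algorithm diagnostics, per-stamp
        # timings from gtimer, and bookkeeping counters.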
        diagnostics = self.get_diagnostics(
            iteration=self._total_timestep,
            batches=self._evaluation_batches(),
            training_paths_per_policy=training_paths_per_policy,
            evaluation_paths_per_policy=evaluation_paths_per_policy)
        time_diagnostics = gt.get_times().stamps.itrs

        print("Basic diagnostics: ", time.time() - t0)
        print("Sample count: ", self._sample_count)
        diagnostics.update(OrderedDict((
            *(
                (f'times/{key}', time_diagnostics[key][-1])
                for key in sorted(time_diagnostics.keys())
            ),
            ('epoch', self._epoch),
            ('timestep', self._timestep),
            ('timesteps_total', self._total_timestep),
            ('train-steps', self._num_train_steps),
        )))
        diagnostics.update({
            f"phase_{i}/episode_count_this_epoch": self._num_paths_per_phase[i]
            for i in range(self._num_goals)
        })
        print("Other basic diagnostics: ", time.time() - t0)
        for i, (evaluation_metrics, training_metrics, sampler_diagnostics) in (
                enumerate(zip(evaluation_metrics_per_policy,
                              training_metrics_per_policy,
                              sampler_diagnostics_per_policy))):
            if i not in empty_policies:
                if self._eval_n_episodes >= 1:
                    # Only log evaluation metrics if they are actually
                    # meaningful.
                    diagnostics.update({
                        f'evaluation_{i}/{key}': evaluation_metrics[key]
                        for key in sorted(evaluation_metrics.keys())
                    })
                diagnostics.update(OrderedDict((
                    *(
                        (f'training_{i}/{key}', training_metrics[key])
                        for key in sorted(training_metrics.keys())
                    ),
                    *(
                        (f'sampler_{i}/{key}', sampler_diagnostics[key])
                        for key in sorted(sampler_diagnostics.keys())
                    ),
                )))
            else:
                if self._eval_n_episodes >= 1:
                    diagnostics.update({
                        f'evaluation_{i}/{key}': evaluation_metrics[key]
                        for key in sorted(evaluation_metrics.keys())
                    })
                # This policy collected no training paths this epoch; log
                # sentinel values so the keys still appear in the output.
                diagnostics.update(OrderedDict((
                    *(
                        (f'training_{i}/{key}', -1000)
                        for key in sorted(training_metrics.keys())
                    ),
                    *(
                        (f'sampler_{i}/{key}', -1000)
                        for key in sorted(sampler_diagnostics.keys())
                    ),
                )))
        # TODO(hartikainen): Optionally render evaluation rollouts here once
        # `render_rollouts` support is consistent across environments (no
        # hasattr check needed).
        yield diagnostics
        print("Diagnostic time: ", time.time() - t0)
    for i in range(self._num_goals):
        self._samplers[i].terminate()

    self._training_after_hook()
    del evaluation_paths_per_policy

    yield {'done': True, **diagnostics}