in tf_agents/agents/ppo/ppo_agent.py [0:0]
def _train(self, experience, weights):
  experience = self._as_trajectory(experience)
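
  # When _compute_value_and_advantage_in_train is set, value predictions,
  # returns and advantages are computed here from the collected trajectory;
  # otherwise they are expected to have been computed beforehand (e.g. via
  # preprocess_sequence) and stored in the trajectory's policy_info.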
  if self._compute_value_and_advantage_in_train:
    processed_experience = self._preprocess(experience)
  else:
    processed_experience = experience

  # Mask trajectories that cannot be used for training.
  valid_mask = ppo_utils.make_trajectory_mask(processed_experience)
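  # Fold the validity mask into the (optional) caller-provided sample weights
  # so that invalid transitions contribute nothing to the loss.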
  if weights is None:
    masked_weights = valid_mask
  else:
    masked_weights = weights * valid_mask

  # Reconstruct per-timestep policy distribution from stored distribution
  # parameters.
  old_action_distribution_parameters = processed_experience.policy_info[
      'dist_params']

  old_actions_distribution = (
      ppo_utils.distribution_from_spec(
          self._action_distribution_spec,
          old_action_distribution_parameters,
          legacy_distribution_network=isinstance(
              self._actor_net, network.DistributionNetwork)))

  # Compute log probability of actions taken during data collection, using the
  # collect policy distribution.
  old_act_log_probs = common.log_probability(old_actions_distribution,
                                             processed_experience.action,
                                             self._action_spec)
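  # These log probabilities stay fixed across the epochs below and form the
  # denominator of the PPO importance-sampling ratio in the surrogate loss.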

  # TODO(b/171573175): remove the condition once histograms are
  # supported on TPUs.
  if self._debug_summaries and not tf.config.list_logical_devices('TPU'):
    actions_list = tf.nest.flatten(processed_experience.action)
    show_action_index = len(actions_list) != 1
    for i, single_action in enumerate(actions_list):
      action_name = ('actions_{}'.format(i)
                     if show_action_index else 'actions')
      tf.compat.v2.summary.histogram(
          name=action_name, data=single_action, step=self.train_step_counter)
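
  # Repackage the processed trajectory into the TimeStep, action, return and
  # advantage tensors consumed by get_loss below.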
  time_steps = ts.TimeStep(
      step_type=processed_experience.step_type,
      reward=processed_experience.reward,
      discount=processed_experience.discount,
      observation=processed_experience.observation)
  actions = processed_experience.action
  returns = processed_experience.policy_info['return']
  advantages = processed_experience.policy_info['advantage']
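
  # Normalize advantages to roughly zero mean and unit variance across the
  # batch; the epsilon guards against division by a near-zero variance.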
  normalized_advantages = _normalize_advantages(advantages,
                                                variance_epsilon=1e-8)

  # TODO(b/171573175): remove the condition once histograms are
  # supported on TPUs.
  if self._debug_summaries and not tf.config.list_logical_devices('TPU'):
    tf.compat.v2.summary.histogram(
        name='advantages_normalized',
        data=normalized_advantages,
        step=self.train_step_counter)
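
  # Value predictions recorded at collection time; the value loss can use
  # these for value clipping when that option is enabled.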
  old_value_predictions = processed_experience.policy_info['value_prediction']

  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]

  # Per-epoch loss tensors are collected here and aggregated for summaries
  # after the epoch loop.
  policy_gradient_losses = []
  value_estimation_losses = []
  l2_regularization_losses = []
  entropy_regularization_losses = []
  kl_penalty_losses = []

  loss_info = None  # TODO(b/123627451): Remove.

  variables_to_train = list(
      object_identity.ObjectIdentitySet(self._actor_net.trainable_weights +
                                        self._value_net.trainable_weights))
  # Sort to ensure tensors on different processes end up in same order.
  variables_to_train = sorted(variables_to_train, key=lambda x: x.name)
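
  # PPO reuses the same batch of experience for several optimization epochs;
  # the "old" distribution parameters and log probabilities computed above
  # stay frozen while the networks are updated on each pass.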
  for i_epoch in range(self._num_epochs):
    with tf.name_scope('epoch_%d' % i_epoch):
      # Only save debug summaries for first and last epochs.
      debug_summaries = (
          self._debug_summaries and
          (i_epoch == 0 or i_epoch == self._num_epochs - 1))
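
      # Run the forward pass under a GradientTape so the combined PPO loss can
      # be differentiated with respect to both actor and value variables.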
      with tf.GradientTape() as tape:
        loss_info = self.get_loss(
            time_steps,
            actions,
            old_act_log_probs,
            returns,
            normalized_advantages,
            old_action_distribution_parameters,
            masked_weights,
            self.train_step_counter,
            debug_summaries,
            old_value_predictions=old_value_predictions,
            training=True)

      grads = tape.gradient(loss_info.loss, variables_to_train)
      if self._gradient_clipping > 0:
        grads, _ = tf.clip_by_global_norm(grads, self._gradient_clipping)

      # Tuple is used for py3, where zip is a generator producing values once.
      grads_and_vars = tuple(zip(grads, variables_to_train))

      # If summarize_grads_and_vars is enabled, add summaries for both
      # gradients and variables.
      if self._summarize_grads_and_vars and debug_summaries:
        eager_utils.add_gradients_summaries(grads_and_vars,
                                            self.train_step_counter)
        eager_utils.add_variables_summaries(grads_and_vars,
                                            self.train_step_counter)

      self._optimizer.apply_gradients(grads_and_vars)
      self.train_step_counter.assign_add(1)
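
      # Record this epoch's loss components for the averaged summaries written
      # after the loop.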
      policy_gradient_losses.append(loss_info.extra.policy_gradient_loss)
      value_estimation_losses.append(loss_info.extra.value_estimation_loss)
      l2_regularization_losses.append(loss_info.extra.l2_regularization_loss)
      entropy_regularization_losses.append(
          loss_info.extra.entropy_regularization_loss)
      kl_penalty_losses.append(loss_info.extra.kl_penalty_loss)

  # TODO(b/1613650790): Move this logic to PPOKLPenaltyAgent.
  if self._initial_adaptive_kl_beta > 0:
    # After the update epochs, adapt the KL beta, then update the observation
    # and reward normalizers.
    policy_state = self._collect_policy.get_initial_state(batch_size)
    # Compute the mean KL divergence from the previous action distribution.
    kl_divergence = self._kl_divergence(
        time_steps, old_action_distribution_parameters,
        self._collect_policy.distribution(time_steps, policy_state).action)
    self.update_adaptive_kl_beta(kl_divergence)

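  # Optionally refresh the observation and reward normalizers with the
  # experience seen in this training call.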
  if self.update_normalizers_in_train:
    self.update_observation_normalizer(time_steps.observation)
    self.update_reward_normalizer(processed_experience.reward)

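  # loss_info now holds the LossInfo from the final epoch; pass its tensors
  # through tf.identity before returning.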
  loss_info = tf.nest.map_structure(tf.identity, loss_info)

  # Make summaries for total loss averaged across all epochs.
  # The *_losses lists were populated in the epoch loop above from the
  # LossInfo returned by self.get_loss; all lists have the same length.
  with tf.name_scope('Losses/'):
    num_epochs = len(policy_gradient_losses)
    total_policy_gradient_loss = tf.add_n(policy_gradient_losses) / num_epochs
    total_value_estimation_loss = tf.add_n(
        value_estimation_losses) / num_epochs
    total_l2_regularization_loss = tf.add_n(
        l2_regularization_losses) / num_epochs
    total_entropy_regularization_loss = tf.add_n(
        entropy_regularization_losses) / num_epochs
    total_kl_penalty_loss = tf.add_n(kl_penalty_losses) / num_epochs

    tf.compat.v2.summary.scalar(
        name='policy_gradient_loss',
        data=total_policy_gradient_loss,
        step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='value_estimation_loss',
        data=total_value_estimation_loss,
        step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='l2_regularization_loss',
        data=total_l2_regularization_loss,
        step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='entropy_regularization_loss',
        data=total_entropy_regularization_loss,
        step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='kl_penalty_loss',
        data=total_kl_penalty_loss,
        step=self.train_step_counter)

    total_abs_loss = (
        tf.abs(total_policy_gradient_loss) +
        tf.abs(total_value_estimation_loss) +
        tf.abs(total_entropy_regularization_loss) +
        tf.abs(total_l2_regularization_loss) + tf.abs(total_kl_penalty_loss))

    tf.compat.v2.summary.scalar(
        name='total_abs_loss',
        data=total_abs_loss,
        step=self.train_step_counter)

  with tf.name_scope('LearningRate/'):
    learning_rate = ppo_utils.get_learning_rate(self._optimizer)
    tf.compat.v2.summary.scalar(
        name='learning_rate',
        data=learning_rate,
        step=self.train_step_counter)

  # TODO(b/171573175): remove the condition once histograms are
  # supported on TPUs.
  if self._summarize_grads_and_vars and not tf.config.list_logical_devices(
      'TPU'):
    with tf.name_scope('Variables/'):
      all_vars = (
          self._actor_net.trainable_weights +
          self._value_net.trainable_weights)
      for var in all_vars:
        tf.compat.v2.summary.histogram(
            name=var.name.replace(':', '_'),
            data=var,
            step=self.train_step_counter)

  return loss_info
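
# Usage sketch (not part of ppo_agent.py): an illustrative example of how
# _train is normally reached, assuming an already-constructed actor_net,
# value_net, and a batched Trajectory `experience` that matches
# agent.collect_data_spec. Constructor arguments shown are indicative only.
#
#   agent = ppo_agent.PPOAgent(
#       time_step_spec, action_spec,
#       optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
#       actor_net=actor_net, value_net=value_net,
#       num_epochs=10)
#   agent.initialize()
#   loss_info = agent.train(experience)  # dispatches to _train internally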