in tf_agents/agents/reinforce/reinforce_agent.py
def total_loss(self,
               experience: traj.Trajectory,
               returns: types.Tensor,
               weights: types.Tensor,
               training: bool = False) -> tf_agent.LossInfo:
  time_steps = ts.TimeStep(experience.step_type,
                           tf.zeros_like(experience.reward),
                           tf.zeros_like(experience.discount),
                           experience.observation)

  # Ensure we see at least one full episode.
  is_last = experience.is_last()
  num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
  tf.debugging.assert_greater(
      num_episodes,
      0.0,
      message='No complete episode found. REINFORCE requires full episodes '
      'to compute losses.')
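  # Example (illustrative): for a [2, 4]-shaped batch in which only the first
  # row ends an episode, is_last == [[F, F, F, T], [F, F, F, F]] gives
  # num_episodes == 1.0, so the assert passes.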
  # Mask out partial episodes at the end of each batch of time_steps.
  # NOTE: We use is_last rather than is_boundary because the last transition
  # is the one carrying the last valid reward; the boundary transitions do
  # not have valid rewards. Since REINFORCE computes its loss from the
  # returns (and does not bootstrap), keeping or dropping the boundary
  # transitions makes no difference.
  valid_mask = tf.cast(experience.is_last(), dtype=tf.float32)
  valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
  valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
  if weights is not None:
    weights *= valid_mask
  else:
    weights = valid_mask
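  # Example (illustrative), for a single row:
  #   is_last                 = [0, 0, 1, 0, 0]
  #   reversed cumulative sum = [1, 1, 1, 0, 0]
  #   valid_mask              = [1, 1, 1, 0, 0]
  # so the trailing two steps, which belong to an incomplete episode, are
  # masked out of the loss.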
  advantages = returns
  value_preds = None

  if self._baseline:
    value_preds, _ = self._value_network(time_steps.observation,
                                         time_steps.step_type,
                                         training=True)
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='value_preds', data=value_preds, step=self.train_step_counter)

  advantages = self._advantage_fn(returns, value_preds)
  if self._debug_summaries:
    tf.compat.v2.summary.histogram(
        name='advantages', data=advantages, step=self.train_step_counter)
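  # NOTE (assumption): the exact advantage function is whatever was supplied
  # at construction time; with a value baseline a common default is simply
  # `returns - value_preds`, and without one the returns are passed through
  # unchanged.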
  # TODO(b/126592060): replace with tensor normalizer.
  if self._normalize_returns:
    advantages = _standard_normalize(advantages, axes=(0, 1))
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='normalized_%s' %
          ('advantages' if self._baseline else 'returns'),
          data=advantages,
          step=self.train_step_counter)
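  # `_standard_normalize` presumably shifts and rescales the advantages to
  # roughly zero mean and unit variance over the batch and time axes
  # (axes 0 and 1), which keeps gradient magnitudes comparable across batches.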
  nest_utils.assert_same_structure(time_steps, self.time_step_spec)
  policy_state = _get_initial_policy_state(self.collect_policy, time_steps)
  actions_distribution = self.collect_policy.distribution(
      time_steps, policy_state=policy_state).action

  policy_gradient_loss = self.policy_gradient_loss(
      actions_distribution,
      experience.action,
      experience.is_boundary(),
      advantages,
      num_episodes,
      weights,
  )
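  # Sketch of the quantity being computed (see `policy_gradient_loss` for the
  # authoritative version): roughly the negative log-probability of the taken
  # actions weighted by the advantages, with boundary steps masked out and the
  # sum averaged over `num_episodes` rather than over time steps.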
  entropy_regularization_loss = self.entropy_regularization_loss(
      actions_distribution, weights)

  network_regularization_loss = tf.nn.scale_regularization_loss(
      self._actor_network.losses)

  total_loss = (policy_gradient_loss +
                network_regularization_loss +
                entropy_regularization_loss)
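  # The entropy term rewards more stochastic policies (typically the negative
  # mean entropy scaled by an entropy-regularization coefficient), which
  # discourages premature collapse to a near-deterministic policy.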
  losses_dict = {
      'policy_gradient_loss': policy_gradient_loss,
      'policy_network_regularization_loss': network_regularization_loss,
      'entropy_regularization_loss': entropy_regularization_loss,
      'value_estimation_loss': 0.0,
      'value_network_regularization_loss': 0.0,
  }

  value_estimation_loss = None
  if self._baseline:
    value_estimation_loss = self.value_estimation_loss(
        value_preds, returns, num_episodes, weights)
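    # NOTE (assumption): the value-estimation loss computed above is typically
    # a regression loss (for example mean squared error) between `value_preds`
    # and the empirical `returns`, again averaged per episode; see
    # `value_estimation_loss` for the exact form.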
    value_network_regularization_loss = tf.nn.scale_regularization_loss(
        self._value_network.losses)
    total_loss += value_estimation_loss + value_network_regularization_loss
    losses_dict['value_estimation_loss'] = value_estimation_loss
    losses_dict['value_network_regularization_loss'] = (
        value_network_regularization_loss)
  loss_info_extra = ReinforceAgentLossInfo(**losses_dict)

  losses_dict['total_loss'] = total_loss  # Total loss not in loss_info_extra.

  common.summarize_scalar_dict(losses_dict,
                               self.train_step_counter,
                               name_scope='Losses/')

  return tf_agent.LossInfo(total_loss, loss_info_extra)
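

# Illustrative sketch, not part of the original file: a minimal, standalone
# demonstration of the episode masking and per-episode averaging used by
# `total_loss` above.  The function name, the toy tensors, and the unit
# per-step loss are invented for illustration; only the `tf.math.cumsum`-based
# masking mirrors the code above.
def _illustrative_masking_example():
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  # Batch of 2 trajectories, 5 steps each.  A 1 marks the last step of a
  # complete episode (is_last); the second row ends with a partial episode.
  is_last = tf.constant([[0., 0., 1., 0., 1.],
                         [0., 0., 0., 1., 0.]])

  # The reversed cumulative sum is positive wherever the current or a later
  # step closes an episode, and zero on trailing partial-episode steps.
  valid_mask = tf.cast(
      tf.math.cumsum(is_last, axis=1, reverse=True) > 0, tf.float32)
  # valid_mask == [[1., 1., 1., 1., 1.],
  #                [1., 1., 1., 1., 0.]]

  num_episodes = tf.reduce_sum(is_last)  # == 3.0

  # A dummy per-step loss, masked and then averaged per episode rather than
  # per step, mirroring how `num_episodes` is used in `total_loss`.
  per_step_loss = tf.ones_like(is_last)
  return tf.reduce_sum(per_step_loss * valid_mask) / num_episodes  # == 3.0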