# tf_agents/agents/categorical_dqn/categorical_dqn_agent.py
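# Excerpted method of `CategoricalDqnAgent`. The helper modules referenced
# below (`trajectory`, `nest_utils`, `network_utils`, `value_ops`, `common`,
# `tf_agent`, `dqn_agent`) are imported at the top of the original file, and
# `project_distribution` is defined later in this same module.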
def _loss(self,
          experience,
          td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes critic loss for CategoricalDQN training.

  See Algorithm 1 and the discussion immediately preceding it on page 6 of
  "A Distributional Perspective on Reinforcement Learning",
  Bellemare et al., 2017: https://arxiv.org/abs/1707.06887

  Args:
    experience: A batch of experience data in the form of a `Trajectory`. The
      structure of `experience` must match that of `self.policy.step_spec`.
      All tensors in `experience` must be shaped `[batch, time, ...]` where
      `time` must be equal to `self.required_experience_time_steps` if that
      property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional weights used for importance sampling.
    training: Whether the loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.

  Raises:
    ValueError: If the number of actions is greater than 1.
  """
  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(
            first_two_steps, squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(
            last_two_steps, squeeze_time_dim))
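  # The critic loss below is the cross-entropy between the projected target
  # return distribution and the predicted logits of the actions that were
  # actually taken (Algorithm 1 in Bellemare et al., 2017).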
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    rank = nest_utils.get_outer_rank(time_steps.observation,
                                     self._time_step_spec.observation)

    # If inputs have a time dimension and the q_network is stateful,
    # combine the batch and time dimension.
    batch_squash = (None
                    if rank <= 1 or self._q_network.state_spec in ((), None)
                    else network_utils.BatchSquash(rank))

    network_observation = time_steps.observation

    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = (
          self._observation_and_action_constraint_splitter(
              network_observation))
    # q_logits contains the Q-value logits for all actions.
    q_logits, _ = self._q_network(network_observation,
                                  step_type=time_steps.step_type,
                                  training=training)

    if batch_squash is not None:
      # Squash the outer (batch and time) dimensions into a single dimension
      # to simplify computing the loss below. Required for supporting
      # temporal inputs, for example.
      q_logits = batch_squash.flatten(q_logits)
      actions = batch_squash.flatten(actions)
      next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                              next_time_steps)
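    # next_q_distribution holds, per batch element, the predicted return
    # distribution for the chosen next action; `_next_q_distribution` computes
    # it with the target categorical network.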
    next_q_distribution = self._next_q_distribution(next_time_steps)

    if actions.shape.rank > 1:
      actions = tf.squeeze(actions, list(range(1, actions.shape.rank)))

    # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
    # support of Z_{\theta} (see Figure 1 in paper).
    batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])
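    # tiled_support has shape [batch_size, num_atoms]: one copy of the fixed
    # support z_0, ..., z_{num_atoms-1} per batch element, which is shifted
    # and scaled below to form the per-sample target support.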
    if self._n_step_update == 1:
      discount = next_time_steps.discount
      if discount.shape.rank == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = tf.expand_dims(discount, -1)
      next_value_term = tf.multiply(discount,
                                    tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.rank == 1:
        # See the explanation above.
        reward = tf.expand_dims(reward, -1)
      reward_term = tf.multiply(reward_scale_factor,
                                reward,
                                name='reward_term')

      target_support = tf.add(reward_term, gamma * next_value_term,
                              name='target_support')
    else:
      # When computing the discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include episode
      # boundaries with nonzero discount.

      discounted_returns = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False,
          provide_all_returns=False)

      # Convert discounted_returns from [batch_size] to [batch_size, 1].
      discounted_returns = tf.expand_dims(discounted_returns, -1)

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = tf.expand_dims(final_value_discount, -1)

      # Save the values of discounted_returns and final_value_discount so
      # they can be checked in unit tests.
      self._discounted_returns = discounted_returns
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_returns,
                              final_value_discount * tiled_support,
                              name='target_support')
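    # project_distribution redistributes the probability mass of
    # next_q_distribution, which lives on the per-sample target_support, back
    # onto the fixed support self._support (the categorical projection of
    # Algorithm 1). For example, with support {-1, 0, 1}, a target atom at 0.4
    # splits its mass between the atoms 0 and 1 in proportion 0.6/0.4, and
    # atoms falling outside the support are clipped to its edges.
    # stop_gradient keeps the target distribution out of the backward pass.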
    target_distribution = tf.stop_gradient(project_distribution(
        target_support, next_q_distribution, self._support))

    # Obtain the current Q-value logits for the selected actions.
    indices = tf.range(batch_size)
    indices = tf.cast(indices, actions.dtype)
    reshaped_actions = tf.stack([indices, actions], axis=-1)
    chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

    # Compute the cross-entropy loss between the target distribution and the
    # logits of the chosen actions. If inputs have a time dimension, sum over
    # the time dimension before taking the mean over the batch dimension.
    if batch_squash is not None:
      target_distribution = batch_squash.unflatten(target_distribution)
      chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
      critic_loss = tf.reduce_sum(
          tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
              labels=target_distribution,
              logits=chosen_action_logits),
          axis=1)
    else:
      critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
          labels=target_distribution,
          logits=chosen_action_logits)
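    # Aggregate the per-example cross-entropy and the q_network's
    # regularization losses into a single scalar total_loss.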
    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    dict_losses = {'critic_loss': agg_loss.weighted,
                   'reg_loss': agg_loss.regularization,
                   'total_loss': total_loss}

    common.summarize_scalar_dict(dict_losses,
                                 step=self.train_step_counter,
                                 name_scope='Losses/')
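    # Optional debug summaries. Note that target_distribution holds
    # probabilities while chosen_action_logits are unnormalized logits, so
    # distribution_errors is a rough diagnostic rather than an error measured
    # in probability space.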
    if self._debug_summaries:
      distribution_errors = target_distribution - chosen_action_logits
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors', distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean', tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs', tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max', tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min', tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution', target_distribution,
            step=self.train_step_counter)
    # TODO(b/127318640): Give appropriate values for td_loss and td_error for
    # prioritized replay.
    return tf_agent.LossInfo(total_loss, dqn_agent.DqnLossInfo(td_loss=(),
                                                               td_error=()))