def _loss()

in tf_agents/agents/categorical_dqn/categorical_dqn_agent.py [0:0]


  def _loss(self,
            experience,
            td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None,
            training=False):
    """Computes critic loss for CategoricalDQN training.

    See Algorithm 1 and the discussion immediately preceding it in page 6 of
    "A Distributional Perspective on Reinforcement Learning"
      Bellemare et al., 2017
      https://arxiv.org/abs/1707.06887

    The sample Bellman update of the value support is computed. It is one-step,
    or n-step when `self._n_step_update > 1`. That update is projected back
    onto `self._support`. The loss is the softmax cross-entropy between this
    projected target distribution and the predicted distribution of the action
    that was taken.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.required_experience_time_steps` if that
        property is not `None`.
      td_errors_loss_fn: Accepted for signature compatibility but not used by
        this method; the loss is always the cross-entropy described above.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional per-example weights (e.g. importance-sampling weights).
        Forwarded to `common.aggregate_losses` as `sample_weight`. Presumably
        shaped `[batch_size]` to match the per-example loss; TODO confirm for
        stateful networks whose inputs carry a time dimension.
      training: Whether the loss is being used for training.

    Returns:
      A `tf_agent.LossInfo` built from the total loss and a
      `dqn_agent.DqnLossInfo`. The total loss is the weighted critic loss plus
      the Q-network's regularization losses. The `DqnLossInfo` has empty
      `td_loss` and `td_error`.

    Raises:
      ValueError: presumably raised by `nest_utils.assert_same_structure` when
        the actions or time steps extracted from `experience` do not match the
        agent's specs (TODO confirm the exact exception type).
    """
    del td_errors_loss_fn  # Unused: categorical DQN uses cross-entropy.

    # Only squeeze out the time dimension when the network is stateless.
    squeeze_time_dim = not self._q_network.state_spec
    if self._n_step_update == 1:
      time_steps, policy_steps, next_time_steps = (
          trajectory.experience_to_transitions(experience, squeeze_time_dim))
      actions = policy_steps.action
    else:
      # To compute n-step returns, we need the first time steps, the first
      # actions, and the last time steps. Therefore we extract the first and
      # last transitions from our Trajectory.
      first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
      last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
      time_steps, policy_steps, _ = (
          trajectory.experience_to_transitions(
              first_two_steps, squeeze_time_dim))
      actions = policy_steps.action
      _, _, next_time_steps = (
          trajectory.experience_to_transitions(
              last_two_steps, squeeze_time_dim))

    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      rank = nest_utils.get_outer_rank(time_steps.observation,
                                       self._time_step_spec.observation)

      # If inputs have a time dimension and the q_network is stateful,
      # combine the batch and time dimension.
      batch_squash = (None
                      if rank <= 1 or self._q_network.state_spec in ((), None)
                      else network_utils.BatchSquash(rank))

      network_observation = time_steps.observation

      if self._observation_and_action_constraint_splitter is not None:
        # Only the observation part is fed to the network; the action
        # constraint (mask) is discarded for loss computation.
        network_observation, _ = (
            self._observation_and_action_constraint_splitter(
                network_observation))

      # q_logits contains the Q-value logits for all actions.
      q_logits, _ = self._q_network(network_observation,
                                    step_type=time_steps.step_type,
                                    training=training)

      if batch_squash is not None:
        # Squash the outer dimensions into a single dimension to make the
        # loss computation below easier. This is required for supporting
        # temporal inputs, for example.
        q_logits = batch_squash.flatten(q_logits)
        actions = batch_squash.flatten(actions)
        next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                                next_time_steps)

      next_q_distribution = self._next_q_distribution(next_time_steps)

      if actions.shape.rank > 1:
        # Drop any trailing singleton dimensions so actions is a flat vector
        # of action indices.
        actions = tf.squeeze(actions, list(range(1, actions.shape.rank)))

      # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
      # support of Z_{\theta} (see Figure 1 in paper).
      # Prefer the static batch size when known; fall back to the dynamic one.
      batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
      tiled_support = tf.tile(self._support, [batch_size])
      tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

      if self._n_step_update == 1:
        discount = next_time_steps.discount
        if discount.shape.rank == 1:
          # We expect discount to have a shape of [batch_size], while
          # tiled_support will have a shape of [batch_size, num_atoms]. To
          # multiply these, we add a second dimension of 1 to the discount.
          discount = tf.expand_dims(discount, -1)
        next_value_term = tf.multiply(discount,
                                      tiled_support,
                                      name='next_value_term')

        reward = next_time_steps.reward
        if reward.shape.rank == 1:
          # Same broadcasting fix as for the discount above.
          reward = tf.expand_dims(reward, -1)
        reward_term = tf.multiply(reward_scale_factor,
                                  reward,
                                  name='reward_term')

        target_support = tf.add(reward_term, gamma * next_value_term,
                                name='target_support')
      else:
        # When computing discounted return, we need to throw out the last time
        # index of both reward and discount, which are filled with dummy values
        # to match the dimensions of the observation.
        rewards = reward_scale_factor * experience.reward[:, :-1]
        discounts = gamma * experience.discount[:, :-1]

        # TODO(b/134618876): Properly handle Trajectories that include episode
        # boundaries with nonzero discount.

        discounted_returns = value_ops.discounted_return(
            rewards=rewards,
            discounts=discounts,
            final_value=tf.zeros([batch_size], dtype=discounts.dtype),
            time_major=False,
            provide_all_returns=False)

        # Convert discounted_returns from [batch_size] to [batch_size, 1]
        discounted_returns = tf.expand_dims(discounted_returns, -1)

        # The bootstrapped support is discounted by the product of all
        # per-step discounts along the n-step window.
        final_value_discount = tf.reduce_prod(discounts, axis=1)
        final_value_discount = tf.expand_dims(final_value_discount, -1)

        # Save the values of discounted_returns and final_value_discount in
        # order to check them in unit tests.
        self._discounted_returns = discounted_returns
        self._final_value_discount = final_value_discount

        target_support = tf.add(discounted_returns,
                                final_value_discount * tiled_support,
                                name='target_support')

      # The target is a fixed label: no gradient flows through it.
      target_distribution = tf.stop_gradient(project_distribution(
          target_support, next_q_distribution, self._support))

      # Obtain the current Q-value logits for the selected actions.
      indices = tf.range(batch_size)
      indices = tf.cast(indices, actions.dtype)
      reshaped_actions = tf.stack([indices, actions], axis=-1)
      chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

      # Compute the cross-entropy loss between the logits. If inputs have
      # a time dimension, compute the sum over the time dimension before
      # computing the mean over the batch dimension.
      if batch_squash is not None:
        target_distribution = batch_squash.unflatten(target_distribution)
        chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
        critic_loss = tf.reduce_sum(
            tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                labels=target_distribution,
                logits=chosen_action_logits),
            axis=1)
      else:
        critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
            labels=target_distribution,
            logits=chosen_action_logits)

      # Forward `weights` as `sample_weight` so that importance-sampling
      # weights actually scale the per-example loss instead of being dropped.
      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=self._q_network.losses)
      total_loss = agg_loss.total_loss

      dict_losses = {'critic_loss': agg_loss.weighted,
                     'reg_loss': agg_loss.regularization,
                     'total_loss': total_loss}

      common.summarize_scalar_dict(dict_losses,
                                   step=self.train_step_counter,
                                   name_scope='Losses/')

      if self._debug_summaries:
        # `target_distribution` holds probabilities (it is the `labels` of the
        # cross-entropy above), so compare it against the predicted
        # probabilities rather than against raw logits.
        chosen_action_distribution = tf.nn.softmax(chosen_action_logits)
        distribution_errors = target_distribution - chosen_action_distribution
        with tf.name_scope('distribution_errors'):
          common.generate_tensor_summaries(
              'distribution_errors', distribution_errors,
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              'mean', tf.reduce_mean(distribution_errors),
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              'mean_abs', tf.reduce_mean(tf.abs(distribution_errors)),
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              'max', tf.reduce_max(distribution_errors),
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              'min', tf.reduce_min(distribution_errors),
              step=self.train_step_counter)
        with tf.name_scope('target_distribution'):
          common.generate_tensor_summaries(
              'target_distribution', target_distribution,
              step=self.train_step_counter)

      # TODO(b/127318640): Give appropriate values for td_loss and td_error for
      # prioritized replay.
      return tf_agent.LossInfo(total_loss, dqn_agent.DqnLossInfo(td_loss=(),
                                                                 td_error=()))