def total_loss()

in tf_agents/agents/reinforce/reinforce_agent.py


  def total_loss(self,
                 experience: traj.Trajectory,
                 returns: types.Tensor,
                 weights: types.Tensor,
                 training: bool = False) -> tf_agent.LossInfo:
    # Ensure we see at least one full episode.
    time_steps = ts.TimeStep(experience.step_type,
                             tf.zeros_like(experience.reward),
                             tf.zeros_like(experience.discount),
                             experience.observation)
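    # Rewards and discounts are zeroed out here: this batch of time_steps is
    # only used below to query the policy distribution and (optionally) the
    # value network, which read the observations and step types.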
    is_last = experience.is_last()
    num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
    tf.debugging.assert_greater(
        num_episodes,
        0.0,
        message='No complete episode found. REINFORCE requires full episodes '
        'to compute losses.')
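    # Example: for a [batch=2, time=5] trajectory with is_last equal to
    # [[0, 0, 1, 0, 0], [0, 0, 0, 0, 1]], num_episodes is 2.0 -- each True
    # entry marks the final step of a completed episode.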

    # Mask out partial episodes at the end of each batch of time_steps.
    # NOTE: We use is_last rather than is_boundary because the last transition
    # is the one that carries the final valid reward; boundary transitions do
    # not have valid rewards.  Since REINFORCE computes its loss w.r.t. the
    # returns (and does not bootstrap), keeping the boundary transitions is
    # irrelevant.
    valid_mask = tf.cast(is_last, dtype=tf.float32)
    valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
    valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
    if weights is not None:
      weights *= valid_mask
    else:
      weights = valid_mask
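    # Example: is_last == [[0, 0, 0, 1, 0, 0]] yields a reverse cumsum of
    # [[1, 1, 1, 1, 0, 0]], so the two trailing steps of the unfinished
    # episode contribute nothing to the loss.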

    advantages = returns
    value_preds = None

    if self._baseline:
      value_preds, _ = self._value_network(time_steps.observation,
                                           time_steps.step_type,
                                           training=True)
      if self._debug_summaries:
        tf.compat.v2.summary.histogram(
            name='value_preds', data=value_preds, step=self.train_step_counter)

    advantages = self._advantage_fn(returns, value_preds)
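    # NOTE: with the default advantage_fn this typically reduces to
    # returns - value_preds when a baseline value network is configured, and
    # to the raw returns otherwise (value_preds is None in that case).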
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='advantages', data=advantages, step=self.train_step_counter)

    # TODO(b/126592060): replace with tensor normalizer.
    if self._normalize_returns:
      advantages = _standard_normalize(advantages, axes=(0, 1))
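      # Shift and rescale to roughly zero mean and unit variance across the
      # batch and time axes, keeping gradient magnitudes comparable across
      # training iterations.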
      if self._debug_summaries:
        tf.compat.v2.summary.histogram(
            name='normalized_%s' %
            ('advantages' if self._baseline else 'returns'),
            data=advantages,
            step=self.train_step_counter)

    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    policy_state = _get_initial_policy_state(self.collect_policy, time_steps)
    actions_distribution = self.collect_policy.distribution(
        time_steps, policy_state=policy_state).action
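    # Per-step action distribution under the collect policy; the log-probs of
    # the actions actually taken in `experience` are evaluated against it.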

    policy_gradient_loss = self.policy_gradient_loss(
        actions_distribution,
        experience.action,
        experience.is_boundary(),
        advantages,
        num_episodes,
        weights,
    )
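    # REINFORCE surrogate loss: roughly the negative log-probability of the
    # taken actions weighted by the advantages, masked by `weights`, and
    # averaged over the number of completed episodes.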

    entropy_regularization_loss = self.entropy_regularization_loss(
        actions_distribution, weights)
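    # Encourages exploration by penalizing low-entropy action distributions;
    # typically zero when no entropy_regularization coefficient is configured.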

    network_regularization_loss = tf.nn.scale_regularization_loss(
        self._actor_network.losses)
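    # Sums any layer regularization losses (e.g. kernel regularizers) attached
    # to the actor network, scaled for the number of replicas when training
    # is distributed.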

    total_loss = (policy_gradient_loss +
                  network_regularization_loss +
                  entropy_regularization_loss)

    losses_dict = {
        'policy_gradient_loss': policy_gradient_loss,
        'policy_network_regularization_loss': network_regularization_loss,
        'entropy_regularization_loss': entropy_regularization_loss,
        'value_estimation_loss': 0.0,
        'value_network_regularization_loss': 0.0,
    }

    value_estimation_loss = None
    if self._baseline:
      value_estimation_loss = self.value_estimation_loss(
          value_preds, returns, num_episodes, weights)
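      # Baseline regression term: typically a weighted squared error between
      # value_preds and the observed returns, averaged over episodes.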
      value_network_regularization_loss = tf.nn.scale_regularization_loss(
          self._value_network.losses)
      total_loss += value_estimation_loss + value_network_regularization_loss
      losses_dict['value_estimation_loss'] = value_estimation_loss
      losses_dict['value_network_regularization_loss'] = (
          value_network_regularization_loss)

    loss_info_extra = ReinforceAgentLossInfo(**losses_dict)

    losses_dict['total_loss'] = total_loss  # Total loss not in loss_info_extra.

    common.summarize_scalar_dict(losses_dict,
                                 self.train_step_counter,
                                 name_scope='Losses/')

    return tf_agent.LossInfo(total_loss, loss_info_extra)
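
The episode-masking step above can be reproduced in isolation. A minimal sketch with plain TensorFlow, using made-up tensor values for illustration (not part of the agent's API):

import tensorflow as tf

# Toy is_last flags for a [batch=1, time=6] trajectory: one episode ends at
# t=3 and is followed by a partial episode that never terminates in the batch.
is_last = tf.constant([[0., 0., 0., 1., 0., 0.]])

valid_mask = tf.math.cumsum(is_last, axis=1, reverse=True)
valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)

print(valid_mask.numpy())  # [[1. 1. 1. 1. 0. 0.]] -- trailing partial steps masked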