def policy_gradient_loss()

in tf_agents/agents/ppo/ppo_agent.py [0:0]


  def policy_gradient_loss(
      self,
      time_steps: ts.TimeStep,
      actions: types.NestedTensor,
      sample_action_log_probs: types.Tensor,
      advantages: types.Tensor,
      current_policy_distribution: types.NestedDistribution,
      weights: types.Tensor,
      debug_summaries: bool = False) -> types.Tensor:
    """Builds the policy gradient loss from importance-weighted advantages.

    The loss is the negated per-timestep objective `ratio * advantages`, where
    `ratio = exp(current_log_prob - sample_action_log_probs)`. When
    `self._importance_ratio_clipping` is positive, the element-wise minimum of
    the objective computed with the raw ratio and with the ratio clipped to
    `[1 - clipping, 1 + clipping]` is used instead.

    All tensors should have a single batch dimension.

    Args:
      time_steps: TimeSteps with observations for each timestep. Only used here
        to assert that the structure matches `self.time_step_spec`.
      actions: Tensor of actions for timesteps, aligned on index.
      sample_action_log_probs: Tensor of log-probabilities of each action, as
        recorded when the action was sampled; it is subtracted from the
        current policy's log-probability to form the importance ratio.
      advantages: Tensor of advantage estimate for each timestep, aligned on
        index. Works better when advantage estimates are normalized.
      current_policy_distribution: The policy distribution, evaluated on all
        time_steps.
      weights: Scalar or element-wise (per-batch-entry) importance weights
        applied to the per-timestep loss. Includes a mask for invalid
        timesteps.
      debug_summaries: True if debug summaries should be created.

    Returns:
      policy_gradient_loss: A tensor that will contain policy gradient loss for
        the on-policy experience.
    """
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)

    # Log-probability of the taken actions under the current policy.
    current_log_prob = tf.cast(
        common.log_probability(current_policy_distribution, actions,
                               self._action_spec), tf.float32)
    if self._log_prob_clipping > 0.0:
      current_log_prob = tf.clip_by_value(
          current_log_prob, -self._log_prob_clipping, self._log_prob_clipping)
    if self._check_numerics:
      current_log_prob = tf.debugging.check_numerics(current_log_prob,
                                                     'action_log_prob')

    # Raw and clipped importance ratios. The clipped ratio is always built
    # because the debug histograms below report it, even when clipping is off.
    clip_epsilon = self._importance_ratio_clipping
    clipping_enabled = clip_epsilon > 0.0
    ratio = tf.exp(current_log_prob - sample_action_log_probs)
    ratio_clipped = tf.clip_by_value(ratio, 1 - clip_epsilon, 1 + clip_epsilon)

    if self._check_numerics:
      ratio = tf.debugging.check_numerics(ratio, 'importance_ratio')
      if clipping_enabled:
        ratio_clipped = tf.debugging.check_numerics(
            ratio_clipped, 'importance_ratio_clipped')

    # Pessimistic objective: element-wise minimum of the clipped and unclipped
    # variants.
    objective = ratio * advantages
    objective_clipped = ratio_clipped * advantages
    objective_min = tf.minimum(objective, objective_clipped)

    per_timestep_loss = -objective_min if clipping_enabled else -objective

    if self._aggregate_losses_across_replicas:
      loss = common.aggregate_losses(
          per_example_loss=per_timestep_loss,
          sample_weight=weights).total_loss
    else:
      loss = tf.math.reduce_mean(per_timestep_loss * weights)

    if debug_summaries:
      summary = tf.compat.v2.summary
      step = self.train_step_counter

      if clipping_enabled:
        # Fraction of entries whose raw ratio lies outside the clipping band.
        outside_band = tf.greater(tf.abs(ratio - 1.0), clip_epsilon)
        clip_fraction = tf.reduce_mean(
            input_tensor=tf.cast(outside_band, tf.float32))
        summary.scalar(name='clip_fraction', data=clip_fraction, step=step)
      summary.scalar(
          name='importance_ratio_mean',
          data=tf.reduce_mean(input_tensor=ratio),
          step=step)
      entropy = common.entropy(current_policy_distribution, self.action_spec)
      summary.scalar(
          name='policy_entropy_mean',
          data=tf.reduce_mean(input_tensor=entropy),
          step=step)

      # TODO(b/171573175): remove the condition once histograms are supported
      # on TPUs.
      if not tf.config.list_logical_devices('TPU'):
        tensor_histograms = (
            ('action_log_prob', current_log_prob),
            ('action_log_prob_sample', sample_action_log_probs),
            ('importance_ratio', ratio),
            ('importance_ratio_clipped', ratio_clipped),
            ('per_timestep_objective', objective),
            ('per_timestep_objective_clipped', objective_clipped),
            ('per_timestep_objective_min', objective_min),
            ('policy_entropy', entropy),
        )
        for tag, tensor in tensor_histograms:
          summary.histogram(name=tag, data=tensor, step=step)

        flat_action_specs = tf.nest.flatten(self.action_spec)
        flat_distributions = tf.nest.flatten(current_policy_distribution)
        for i, (single_action, single_distribution) in enumerate(
            zip(flat_action_specs, flat_distributions)):
          # Categorical distribution (used for discrete actions) doesn't have a
          # mean, so discrete actions are skipped.
          if tensor_spec.is_discrete(single_action):
            continue
          # The first distribution keeps the bare tag; later ones get '_<i>'.
          distribution_index = '_{}'.format(i) if i > 0 else ''
          summary.histogram(
              name='actions_distribution_mean' + distribution_index,
              data=single_distribution.mean(),
              step=step)
          summary.histogram(
              name='actions_distribution_stddev' + distribution_index,
              data=single_distribution.stddev(),
              step=step)

        summary.histogram(name='policy_gradient_loss', data=loss, step=step)

    if self._check_numerics:
      loss = tf.debugging.check_numerics(loss, 'policy_gradient_loss')

    return loss