def critic_loss()

in tf_agents/agents/td3/td3_agent.py

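The loss regresses both critics toward the clipped double-Q target of TD3
(Fujimoto et al., 2018). In informal notation,

    td_target = reward_scale * r + gamma * discount
                * min(Q1'(s', a~), Q2'(s', a~))

where a~ is the target actor's action for s' with clipped Gaussian smoothing
noise added, and Q1', Q2' are the two target critic networks.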

  def critic_loss(self,
                  time_steps: ts.TimeStep,
                  actions: types.Tensor,
                  next_time_steps: ts.TimeStep,
                  weights: Optional[types.Tensor] = None,
                  training: bool = False) -> types.Tensor:
    """Computes the critic loss for TD3 training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
    with tf.name_scope('critic_loss'):
      target_actions, _ = self._target_actor_network(
          next_time_steps.observation, next_time_steps.step_type,
          training=training)

      # Add clipped Gaussian noise to each target action before computing
      # target Q-values (TD3's target policy smoothing).
      def add_noise_to_action(action):  # pylint: disable=missing-docstring
        dist = tfp.distributions.Normal(
            loc=tf.zeros_like(action),
            scale=self._target_policy_noise * tf.ones_like(action))
        noise = dist.sample()
        noise = tf.clip_by_value(noise, -self._target_policy_noise_clip,
                                 self._target_policy_noise_clip)
        return action + noise

      noisy_target_actions = tf.nest.map_structure(add_noise_to_action,
                                                   target_actions)

      # Target Q-values are the minimum of the two target critic networks.
      target_q_input = (next_time_steps.observation, noisy_target_actions)
      target_q_values_1, _ = self._target_critic_network_1(
          target_q_input,
          next_time_steps.step_type,
          training=False)
      target_q_values_2, _ = self._target_critic_network_2(
          target_q_input,
          next_time_steps.step_type,
          training=False)
      target_q_values = tf.minimum(target_q_values_1, target_q_values_2)

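      # One-step bootstrapped TD target; tf.stop_gradient keeps gradients
      # from flowing into the target networks.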
      td_targets = tf.stop_gradient(
          self._reward_scale_factor * next_time_steps.reward +
          self._gamma * next_time_steps.discount * target_q_values)

      pred_input = (time_steps.observation, actions)
      pred_td_targets_1, _ = self._critic_network_1(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets_2, _ = self._critic_network_2(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets_all = [pred_td_targets_1, pred_td_targets_2]

      if self._debug_summaries:
        tf.compat.v2.summary.histogram(
            name='td_targets', data=td_targets, step=self.train_step_counter)
        with tf.name_scope('td_targets'):
          tf.compat.v2.summary.scalar(
              name='mean',
              data=tf.reduce_mean(input_tensor=td_targets),
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              name='max',
              data=tf.reduce_max(input_tensor=td_targets),
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              name='min',
              data=tf.reduce_min(input_tensor=td_targets),
              step=self.train_step_counter)

        for td_target_idx in range(2):
          pred_td_targets = pred_td_targets_all[td_target_idx]
          td_errors = td_targets - pred_td_targets
          with tf.name_scope('critic_net_%d' % (td_target_idx + 1)):
            tf.compat.v2.summary.histogram(
                name='td_errors', data=td_errors, step=self.train_step_counter)
            tf.compat.v2.summary.histogram(
                name='pred_td_targets',
                data=pred_td_targets,
                step=self.train_step_counter)
            with tf.name_scope('td_errors'):
              tf.compat.v2.summary.scalar(
                  name='mean',
                  data=tf.reduce_mean(input_tensor=td_errors),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='mean_abs',
                  data=tf.reduce_mean(input_tensor=tf.abs(td_errors)),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='max',
                  data=tf.reduce_max(input_tensor=td_errors),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='min',
                  data=tf.reduce_min(input_tensor=td_errors),
                  step=self.train_step_counter)
            with tf.name_scope('pred_td_targets'):
              tf.compat.v2.summary.scalar(
                  name='mean',
                  data=tf.reduce_mean(input_tensor=pred_td_targets),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='max',
                  data=tf.reduce_max(input_tensor=pred_td_targets),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='min',
                  data=tf.reduce_min(input_tensor=pred_td_targets),
                  step=self.train_step_counter)

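      # Both critics regress toward the same TD target; summing their losses
      # lets one backward pass update both critic networks.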
      critic_loss = (self._td_errors_loss_fn(td_targets, pred_td_targets_1)
                     + self._td_errors_loss_fn(td_targets, pred_td_targets_2))
      if nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2):
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            input_tensor=critic_loss, axis=range(1, critic_loss.shape.rank))

      if weights is not None:
        critic_loss *= weights

      return tf.reduce_mean(input_tensor=critic_loss)
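
Below is a hypothetical usage sketch (not part of td3_agent.py): it builds a
minimal Td3Agent and evaluates critic_loss on a random batch of transitions.
The observation/action specs, layer sizes, and batch contents are
illustrative assumptions only.

import tensorflow as tf
from tf_agents.agents.ddpg import actor_network
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.td3 import td3_agent
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

obs_spec = tensor_spec.TensorSpec([4], tf.float32, 'observation')
action_spec = tensor_spec.BoundedTensorSpec(
    [2], tf.float32, minimum=-1.0, maximum=1.0, name='action')

agent = td3_agent.Td3Agent(
    ts.time_step_spec(obs_spec),
    action_spec,
    actor_network=actor_network.ActorNetwork(
        obs_spec, action_spec, fc_layer_params=(64,)),
    critic_network=critic_network.CriticNetwork(
        (obs_spec, action_spec), joint_fc_layer_params=(64,)),
    actor_optimizer=tf.keras.optimizers.Adam(1e-3),
    critic_optimizer=tf.keras.optimizers.Adam(1e-3))
agent.initialize()

# Random batched transitions standing in for replay-buffer samples.
batch_size = 32
time_steps = ts.transition(
    tf.random.uniform([batch_size, 4]), reward=tf.zeros([batch_size]))
actions = tf.random.uniform([batch_size, 2], minval=-1.0, maxval=1.0)
next_time_steps = ts.transition(
    tf.random.uniform([batch_size, 4]), reward=tf.ones([batch_size]))

loss = agent.critic_loss(time_steps, actions, next_time_steps, training=False)
print(loss.numpy())  # A scalar critic loss.

During training, Td3Agent._train derives these transitions from replay-buffer
experience and minimizes this loss with the critic optimizer.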