in tf_agents/agents/td3/td3_agent.py [0:0]
  def critic_loss(self,
                  time_steps: ts.TimeStep,
                  actions: types.Tensor,
                  next_time_steps: ts.TimeStep,
                  weights: Optional[types.Tensor] = None,
                  training: bool = False) -> types.Tensor:
    """Computes the critic loss for TD3 training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
    with tf.name_scope('critic_loss'):
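      # Compute the action the target policy selects at the next time step.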
      target_actions, _ = self._target_actor_network(
          next_time_steps.observation, next_time_steps.step_type,
          training=training)
      # Add clipped Gaussian noise to each target action (TD3's target
      # policy smoothing) before computing target q values.
      def add_noise_to_action(action):  # pylint: disable=missing-docstring
        dist = tfp.distributions.Normal(
            loc=tf.zeros_like(action),
            scale=self._target_policy_noise * tf.ones_like(action))
        noise = dist.sample()
        noise = tf.clip_by_value(noise, -self._target_policy_noise_clip,
                                 self._target_policy_noise_clip)
        return action + noise

      noisy_target_actions = tf.nest.map_structure(add_noise_to_action,
                                                   target_actions)
      # Target q-values are the min of the two target critic networks
      # (clipped double-Q learning, which counters Q-value overestimation).
      target_q_input_1 = (next_time_steps.observation, noisy_target_actions)
      target_q_values_1, _ = self._target_critic_network_1(
          target_q_input_1,
          next_time_steps.step_type,
          training=False)
      target_q_input_2 = (next_time_steps.observation, noisy_target_actions)
      target_q_values_2, _ = self._target_critic_network_2(
          target_q_input_2,
          next_time_steps.step_type,
          training=False)
      target_q_values = tf.minimum(target_q_values_1, target_q_values_2)
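
      # One-step TD targets: scaled reward + gamma * discount * target Q.
      # stop_gradient keeps the target computation out of the critics'
      # gradients.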
      td_targets = tf.stop_gradient(
          self._reward_scale_factor * next_time_steps.reward +
          self._gamma * next_time_steps.discount * target_q_values)
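
      # Q-value predictions from the two online critics for the actions
      # actually taken in the batch.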
      pred_input_1 = (time_steps.observation, actions)
      pred_td_targets_1, _ = self._critic_network_1(
          pred_input_1, time_steps.step_type, training=training)
      pred_input_2 = (time_steps.observation, actions)
      pred_td_targets_2, _ = self._critic_network_2(
          pred_input_2, time_steps.step_type, training=training)
      pred_td_targets_all = [pred_td_targets_1, pred_td_targets_2]

      if self._debug_summaries:
        tf.compat.v2.summary.histogram(
            name='td_targets', data=td_targets, step=self.train_step_counter)
        with tf.name_scope('td_targets'):
          tf.compat.v2.summary.scalar(
              name='mean',
              data=tf.reduce_mean(input_tensor=td_targets),
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              name='max',
              data=tf.reduce_max(input_tensor=td_targets),
              step=self.train_step_counter)
          tf.compat.v2.summary.scalar(
              name='min',
              data=tf.reduce_min(input_tensor=td_targets),
              step=self.train_step_counter)

        for td_target_idx in range(2):
          pred_td_targets = pred_td_targets_all[td_target_idx]
          td_errors = td_targets - pred_td_targets
          with tf.name_scope('critic_net_%d' % (td_target_idx + 1)):
            tf.compat.v2.summary.histogram(
                name='td_errors', data=td_errors, step=self.train_step_counter)
            tf.compat.v2.summary.histogram(
                name='pred_td_targets',
                data=pred_td_targets,
                step=self.train_step_counter)
            with tf.name_scope('td_errors'):
              tf.compat.v2.summary.scalar(
                  name='mean',
                  data=tf.reduce_mean(input_tensor=td_errors),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='mean_abs',
                  data=tf.reduce_mean(input_tensor=tf.abs(td_errors)),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='max',
                  data=tf.reduce_max(input_tensor=td_errors),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='min',
                  data=tf.reduce_min(input_tensor=td_errors),
                  step=self.train_step_counter)
            with tf.name_scope('pred_td_targets'):
              tf.compat.v2.summary.scalar(
                  name='mean',
                  data=tf.reduce_mean(input_tensor=pred_td_targets),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='max',
                  data=tf.reduce_max(input_tensor=pred_td_targets),
                  step=self.train_step_counter)
              tf.compat.v2.summary.scalar(
                  name='min',
                  data=tf.reduce_min(input_tensor=pred_td_targets),
                  step=self.train_step_counter)
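
      # Both critics are regressed toward the same TD targets; their losses
      # are summed.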
      critic_loss = (self._td_errors_loss_fn(td_targets, pred_td_targets_1)
                     + self._td_errors_loss_fn(td_targets, pred_td_targets_2))
      if nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2):
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            input_tensor=critic_loss, axis=range(1, critic_loss.shape.rank))
      if weights is not None:
        critic_loss *= weights
      return tf.reduce_mean(input_tensor=critic_loss)
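
    # A minimal usage sketch (illustrative, not part of this module): the
    # critic loss is normally reached through agent.train(experience), but it
    # can also be called directly on a batched transition. The names `agent`
    # and `experience`, and the use of trajectory.to_transition here, are
    # assumptions for the example.
    #
    #   from tf_agents.trajectories import trajectory
    #
    #   time_steps, policy_steps, next_time_steps = trajectory.to_transition(
    #       experience)
    #   loss = agent.critic_loss(
    #       time_steps, policy_steps.action, next_time_steps, training=True)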