# tf_agents/agents/cql/cql_sac_agent.py
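# Module-level imports are assumed from the surrounding file (not shown in
# this excerpt): tensorflow as tf, the tf_agent and object_identity modules
# from tf_agents, and the CqlSacLossInfo tuple defined in this module.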
def _train(self, experience, weights):
  """Returns a train op to update the agent's networks.

  This method trains with the provided batched experience.

  Args:
    experience: A time-stacked trajectory object.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    A train_op.

  Raises:
    ValueError: If optimizers are None and no default value was provided to
      the constructor.
  """
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action
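
  # Critic update: gather the trainable variables of both critic networks,
  # deduplicated by object identity in case the networks share variables.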
  trainable_critic_variables = list(object_identity.ObjectIdentitySet(
      self._critic_network_1.trainable_variables +
      self._critic_network_2.trainable_variables))
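
  # Compute the SAC critic loss and the CQL regularizer under a single tape so
  # that gradients of the combined objective reach both critic networks.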
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_critic_variables, ('No trainable critic variables to '
                                        'optimize.')
    tape.watch(trainable_critic_variables)
    critic_loss = self._critic_loss_with_optional_entropy_term(
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn=self._td_errors_loss_fn,
        gamma=self._gamma,
        reward_scale_factor=self._reward_scale_factor,
        weights=weights,
        training=True)
    critic_loss *= self._critic_loss_weight

    cql_alpha = self._get_cql_alpha()
    cql_loss = self._cql_loss(time_steps, actions, training=True)
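
    # In BC-debug mode the critic is trained on the weighted CQL term alone;
    # otherwise the CQL term is added to the standard critic loss.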
    if self._bc_debug_mode:
      cql_critic_loss = cql_loss * cql_alpha
    else:
      cql_critic_loss = critic_loss + (cql_loss * cql_alpha)
  tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
  tf.debugging.check_numerics(cql_loss, 'CQL loss is inf or nan.')

  critic_grads = tape.gradient(cql_critic_loss, trainable_critic_variables)
  self._apply_gradients(critic_grads, trainable_critic_variables,
                        self._critic_optimizer)
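
  # Actor update: optimize the actor network with the SAC actor loss, scaled
  # by the configured actor loss weight.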
  trainable_actor_variables = self._actor_network.trainable_variables
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_actor_variables, ('No trainable actor variables to '
                                       'optimize.')
    tape.watch(trainable_actor_variables)
    actor_loss = self._actor_loss_weight * self.actor_loss(
        time_steps, actions=actions, weights=weights)
  tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
  actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
  self._apply_gradients(actor_grads, trainable_actor_variables,
                        self._actor_optimizer)
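
  # Entropy-temperature (SAC alpha) update: only the log-alpha variable is
  # watched by the tape.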
  alpha_variable = [self._log_alpha]
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert alpha_variable, 'No alpha variable to optimize.'
    tape.watch(alpha_variable)
    alpha_loss = self._alpha_loss_weight * self.alpha_loss(
        time_steps, weights=weights)
  tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
  alpha_grads = tape.gradient(alpha_loss, alpha_variable)
  self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)
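
  # CQL alpha update (only when the Lagrange formulation is enabled below).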
  # Based on equation (24), which automates CQL alpha with the "budget"
  # parameter tau. CQL(H) is now CQL-Lagrange(H):
  # ```
  # min_Q max_{alpha >= 0} alpha * (log_sum_exp(Q(s, a')) - Q(s, a) - tau)
  # ```
  # If the expected difference in Q-values is less than tau, alpha
  # will adjust to be closer to 0. If the difference is higher than tau,
  # alpha is likely to take on high values and more aggressively penalize
  # Q-values.
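  # Minimizing cql_alpha_loss = -cql_alpha * (cql_loss - tau) below performs
  # gradient ascent on cql_alpha * (cql_loss - tau), i.e. the inner
  # maximization above, while the log parameterization keeps alpha >= 0.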
  cql_alpha_loss = tf.constant(0.)
  if self._use_lagrange_cql_alpha:
    cql_alpha_variable = [self._log_cql_alpha]
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(cql_alpha_variable)
      cql_alpha_loss = -self._get_cql_alpha() * (cql_loss - self._cql_tau)
    tf.debugging.check_numerics(cql_alpha_loss,
                                'CQL alpha loss is inf or nan.')
    cql_alpha_gradients = tape.gradient(cql_alpha_loss, cql_alpha_variable)
    self._apply_gradients(cql_alpha_gradients, cql_alpha_variable,
                          self._cql_alpha_optimizer)
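
  # Record the per-term losses and the current temperature values (cql_alpha
  # and the SAC entropy alpha) as scalar summaries.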
  with tf.name_scope('Losses'):
    tf.compat.v2.summary.scalar(
        name='critic_loss', data=critic_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='actor_loss', data=actor_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='alpha_loss', data=alpha_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='cql_loss', data=cql_loss, step=self.train_step_counter)
    if self._use_lagrange_cql_alpha:
      tf.compat.v2.summary.scalar(
          name='cql_alpha_loss',
          data=cql_alpha_loss,
          step=self.train_step_counter)

  tf.compat.v2.summary.scalar(
      name='cql_alpha', data=cql_alpha, step=self.train_step_counter)
  tf.compat.v2.summary.scalar(
      name='sac_alpha', data=tf.exp(self._log_alpha),
      step=self.train_step_counter)
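
  # Advance the train step counter and update the target critic networks.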
  self.train_step_counter.assign_add(1)
  self._update_target()
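
  # The reported total sums the objectives optimized above for the critic,
  # actor and entropy temperature; cql_alpha_loss is reported in `extra` (and
  # as a summary) but is not folded into the total.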
  total_loss = cql_critic_loss + actor_loss + alpha_loss

  extra = CqlSacLossInfo(
      critic_loss=critic_loss,
      actor_loss=actor_loss,
      alpha_loss=alpha_loss,
      cql_loss=cql_loss,
      cql_alpha=cql_alpha,
      cql_alpha_loss=cql_alpha_loss)

  return tf_agent.LossInfo(loss=total_loss, extra=extra)