def _train()

in tf_agents/agents/cql/cql_sac_agent.py


  def _train(self, experience, weights):
    """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A `LossInfo` tuple whose `loss` is the total loss and whose `extra` is a
      `CqlSacLossInfo` holding the individual loss terms.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
    transition = self._as_transition(experience)
    time_steps, policy_steps, next_time_steps = transition
    actions = policy_steps.action

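    # Critic update: gather the trainable variables of both critic networks
    # (deduplicated via ObjectIdentitySet in case they share weights) and
    # compute the SAC critic loss plus the CQL regularizer under one tape.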
    trainable_critic_variables = list(object_identity.ObjectIdentitySet(
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables))

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)
      critic_loss = self._critic_loss_with_optional_entropy_term(
          time_steps,
          actions,
          next_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights,
          training=True)
      critic_loss *= self._critic_loss_weight

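      # CQL regularizer: cql_loss pushes down Q-values for out-of-distribution
      # actions relative to the logged actions; cql_alpha scales its
      # contribution (fixed, or learned when use_lagrange_cql_alpha is True).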
      cql_alpha = self._get_cql_alpha()
      cql_loss = self._cql_loss(time_steps, actions, training=True)

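      # In BC debug mode only the CQL term is optimized; otherwise it is added
      # to the standard critic loss.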
      if self._bc_debug_mode:
        cql_critic_loss = cql_loss * cql_alpha
      else:
        cql_critic_loss = critic_loss + (cql_loss * cql_alpha)

    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    tf.debugging.check_numerics(cql_loss, 'CQL loss is inf or nan.')
    critic_grads = tape.gradient(cql_critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)

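    # Actor update: compute the actor loss and apply gradients to the actor
    # network's variables.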
    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self._actor_loss_weight * self.actor_loss(
          time_steps, actions=actions, weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)

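    # Entropy temperature update: optimize log_alpha via the SAC alpha loss.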
    alpha_variable = [self._log_alpha]
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert alpha_variable, 'No alpha variable to optimize.'
      tape.watch(alpha_variable)
      alpha_loss = self._alpha_loss_weight * self.alpha_loss(
          time_steps, weights=weights)
    tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
    alpha_grads = tape.gradient(alpha_loss, alpha_variable)
    self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)

    # Based on equation (24), which automates CQL alpha with the "budget"
    # parameter tau. CQL(H) then becomes CQL-Lagrange(H):
    # ```
    # min_Q max_{alpha >= 0} alpha * (log_sum_exp(Q(s, a')) - Q(s, a) - tau)
    # ```
    # If the expected difference in Q-values is less than tau, alpha
    # will adjust to be closer to 0. If the difference is higher than tau,
    # alpha is likely to take on high values and more aggressively penalize
    # Q-values.
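    # The loss below is the negative of the inner term, so gradient descent on
    # log_cql_alpha performs the maximization over alpha >= 0: alpha grows
    # when cql_loss exceeds tau and shrinks otherwise.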
    cql_alpha_loss = tf.constant(0.)
    if self._use_lagrange_cql_alpha:
      cql_alpha_variable = [self._log_cql_alpha]
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(cql_alpha_variable)
        cql_alpha_loss = -self._get_cql_alpha() * (cql_loss - self._cql_tau)
      tf.debugging.check_numerics(cql_alpha_loss,
                                  'CQL alpha loss is inf or nan.')
      cql_alpha_gradients = tape.gradient(cql_alpha_loss, cql_alpha_variable)
      self._apply_gradients(cql_alpha_gradients, cql_alpha_variable,
                            self._cql_alpha_optimizer)

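    # Scalar summaries for monitoring each loss term in TensorBoard.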
    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='alpha_loss', data=alpha_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='cql_loss', data=cql_loss, step=self.train_step_counter)
      if self._use_lagrange_cql_alpha:
        tf.compat.v2.summary.scalar(
            name='cql_alpha_loss',
            data=cql_alpha_loss,
            step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='cql_alpha', data=cql_alpha, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='sac_alpha', data=tf.exp(self._log_alpha),
        step=self.train_step_counter)

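    # Advance the train step counter and update the target critic networks.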
    self.train_step_counter.assign_add(1)
    self._update_target()

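    # The reported loss sums the CQL-regularized critic loss, the actor loss,
    # and the alpha loss; cql_alpha_loss is surfaced separately through extra.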
    total_loss = cql_critic_loss + actor_loss + alpha_loss

    extra = CqlSacLossInfo(
        critic_loss=critic_loss,
        actor_loss=actor_loss,
        alpha_loss=alpha_loss,
        cql_loss=cql_loss,
        cql_alpha=cql_alpha,
        cql_alpha_loss=cql_alpha_loss)

    return tf_agent.LossInfo(loss=total_loss, extra=extra)
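
Usage note: _train is not called directly; it is reached through the public
train() method inherited from TFAgent. The sketch below shows that call path
under some assumptions: agent is an already-constructed CqlSacAgent, dataset
is a tf.data.Dataset of batched two-step Trajectory objects (e.g. from a
replay buffer), and num_iterations / log_interval are illustrative names not
taken from this file.

num_iterations = 10000   # illustrative value
log_interval = 1000      # illustrative value

iterator = iter(dataset)
for _ in range(num_iterations):
  experience, _ = next(iterator)        # replay buffers yield (Trajectory, info)
  loss_info = agent.train(experience)   # dispatches to _train under the hood
  step = agent.train_step_counter.numpy()
  if step % log_interval == 0:
    # loss_info.extra is a CqlSacLossInfo with the individual loss terms.
    print('step', step, 'loss', float(loss_info.loss),
          'cql_loss', float(loss_info.extra.cql_loss))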