def __init__()

in tf_agents/agents/cql/cql_sac_agent.py


  def __init__(self,
               time_step_spec: ts.TimeStep,
               action_spec: types.NestedTensorSpec,
               critic_network: network.Network,
               actor_network: network.Network,
               actor_optimizer: types.Optimizer,
               critic_optimizer: types.Optimizer,
               alpha_optimizer: types.Optimizer,
               cql_alpha: Union[types.Float, tf.Variable],
               num_cql_samples: int,
               include_critic_entropy_term: bool,
               use_lagrange_cql_alpha: bool,
               cql_alpha_learning_rate: Union[types.Float, tf.Variable] = 1e-4,
               cql_tau: Union[types.Float, tf.Variable] = 10.0,
               random_seed: Optional[int] = None,
               reward_noise_variance: Union[types.Float, tf.Variable] = 0.0,
               num_bc_steps: int = 0,
               actor_loss_weight: types.Float = 1.0,
               critic_loss_weight: types.Float = 0.5,
               alpha_loss_weight: types.Float = 1.0,
               actor_policy_ctor: Callable[
                   ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
               critic_network_2: Optional[network.Network] = None,
               target_critic_network: Optional[network.Network] = None,
               target_critic_network_2: Optional[network.Network] = None,
               target_update_tau: types.Float = 1.0,
               target_update_period: types.Int = 1,
               td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
               gamma: types.Float = 1.0,
               reward_scale_factor: types.Float = 1.0,
               initial_log_alpha: types.Float = 0.0,
               use_log_alpha_in_alpha_loss: bool = True,
               target_entropy: Optional[types.Float] = None,
               gradient_clipping: Optional[types.Float] = None,
               log_cql_alpha_clipping: Optional[Tuple[types.Float,
                                                      types.Float]] = None,
               softmax_temperature: types.Float = 1.0,
               bc_debug_mode: bool = False,
               debug_summaries: bool = False,
               summarize_grads_and_vars: bool = False,
               train_step_counter: Optional[tf.Variable] = None,
               name: Optional[Text] = None):
    """Creates a CQL-SAC Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns an action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      alpha_optimizer: The default optimizer to use for the alpha variable.
      cql_alpha: The weight on CQL loss. This can be a tf.Variable.
      num_cql_samples: Number of samples for importance sampling in CQL.
      include_critic_entropy_term: Whether to include the entropy term in the
        target for the critic loss.
      use_lagrange_cql_alpha: Whether to use a Lagrange threshold to
        tune cql_alpha during training.
      cql_alpha_learning_rate: The learning rate to tune cql_alpha.
      cql_tau: The threshold for the expected difference in Q-values which
        determines the tuning of cql_alpha.
      random_seed: Optional seed for tf.random.
      reward_noise_variance: The noise variance to introduce to the rewards.
      num_bc_steps: Number of behavioral cloning steps.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      alpha_loss_weight: The weight on alpha loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) A network similar to
        `target_critic_network`, but for `critic_network_2`. See the
        documentation for `target_critic_network`. Will only be used if
        `critic_network_2` is also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      initial_log_alpha: Initial value for log_alpha.
      use_log_alpha_in_alpha_loss: A boolean specifying whether to use
        log_alpha or alpha in the alpha loss. Certain implementations of SAC
        use log_alpha, as log values are generally nicer to work with.
      target_entropy: The target average policy entropy, for updating alpha.
        The default value is the negative of the total number of actions.
      gradient_clipping: Norm length to clip gradients.
      log_cql_alpha_clipping: (Minimum, maximum) values to clip log CQL alpha.
      softmax_temperature: Temperature value which weights Q-values before
        the `cql_loss` logsumexp calculation.
      bc_debug_mode: Whether to run a behavioral cloning mode where the critic
        loss only depends on CQL loss. Useful when debugging and checking that
        CQL loss can be driven down to zero.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run. Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
    super(CqlSacAgent, self).__init__(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        critic_network=critic_network,
        actor_network=actor_network,
        actor_optimizer=actor_optimizer,
        critic_optimizer=critic_optimizer,
        alpha_optimizer=alpha_optimizer,
        actor_loss_weight=actor_loss_weight,
        critic_loss_weight=critic_loss_weight,
        alpha_loss_weight=alpha_loss_weight,
        actor_policy_ctor=actor_policy_ctor,
        critic_network_2=critic_network_2,
        target_critic_network=target_critic_network,
        target_critic_network_2=target_critic_network_2,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        initial_log_alpha=initial_log_alpha,
        use_log_alpha_in_alpha_loss=use_log_alpha_in_alpha_loss,
        target_entropy=target_entropy,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        name=name)
    self._use_lagrange_cql_alpha = use_lagrange_cql_alpha
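    # With the Lagrange formulation, log(cql_alpha) is a trainable variable
    # tuned by its own Adam optimizer so that the expected Q-value gap stays
    # below the `cql_tau` threshold; otherwise cql_alpha is held fixed.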
    if self._use_lagrange_cql_alpha:
      self._log_cql_alpha = tf.Variable(tf.math.log(cql_alpha), trainable=True)
      self._cql_tau = cql_tau
      self._cql_alpha_optimizer = tf.keras.optimizers.Adam(
          learning_rate=cql_alpha_learning_rate)
    else:
      self._cql_alpha = cql_alpha

    self._num_cql_samples = num_cql_samples
    self._include_critic_entropy_term = include_critic_entropy_term
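    # Separate seed streams (same base seed, different salts) keep the random
    # action sampling used for CQL and the reward-noise injection statistically
    # independent of each other.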
    self._action_seed_stream = tfp.util.SeedStream(
        seed=random_seed, salt='random_actions')
    self._reward_seed_stream = tfp.util.SeedStream(
        seed=random_seed, salt='random_reward_noise')
    self._reward_noise_variance = reward_noise_variance
    self._num_bc_steps = num_bc_steps
    self._log_cql_alpha_clipping = log_cql_alpha_clipping
    self._softmax_temperature = softmax_temperature
    self._bc_debug_mode = bc_debug_mode
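
A minimal construction sketch, not taken from this file: it assumes a continuous-control task with a 4-dimensional observation and a 2-dimensional action, builds the critic and actor with the standard TF-Agents network helpers, and uses illustrative hyperparameter values only.

import tensorflow as tf

from tf_agents.agents.cql import cql_sac_agent
from tf_agents.agents.ddpg import critic_network as critic_net
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.networks import actor_distribution_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Hypothetical specs for a small continuous-control task.
observation_spec = tensor_spec.TensorSpec([4], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec([2], tf.float32, -1.0, 1.0)
time_step_spec = ts.time_step_spec(observation_spec)

# Critic takes (observation, action) pairs; actor outputs a tanh-squashed
# normal distribution over actions.
critic = critic_net.CriticNetwork(
    (observation_spec, action_spec),
    joint_fc_layer_params=(256, 256))
actor = actor_distribution_network.ActorDistributionNetwork(
    observation_spec,
    action_spec,
    fc_layer_params=(256, 256),
    continuous_projection_net=(
        tanh_normal_projection_network.TanhNormalProjectionNetwork))

# Fixed CQL alpha (use_lagrange_cql_alpha=False); values are illustrative.
agent = cql_sac_agent.CqlSacAgent(
    time_step_spec,
    action_spec,
    critic_network=critic,
    actor_network=actor,
    actor_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    alpha_optimizer=tf.keras.optimizers.Adam(3e-4),
    cql_alpha=5.0,
    num_cql_samples=10,
    include_critic_entropy_term=False,
    use_lagrange_cql_alpha=False,
    train_step_counter=tf.Variable(0))
agent.initialize()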