in tf_agents/agents/sac/sac_agent.py [0:0]
def __init__(self,
time_step_spec: ts.TimeStep,
action_spec: types.NestedTensorSpec,
critic_network: network.Network,
actor_network: network.Network,
actor_optimizer: types.Optimizer,
critic_optimizer: types.Optimizer,
alpha_optimizer: types.Optimizer,
actor_loss_weight: types.Float = 1.0,
critic_loss_weight: types.Float = 0.5,
alpha_loss_weight: types.Float = 1.0,
actor_policy_ctor: Callable[
..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
critic_network_2: Optional[network.Network] = None,
target_critic_network: Optional[network.Network] = None,
target_critic_network_2: Optional[network.Network] = None,
target_update_tau: types.Float = 1.0,
target_update_period: types.Int = 1,
td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
gamma: types.Float = 1.0,
reward_scale_factor: types.Float = 1.0,
initial_log_alpha: types.Float = 0.0,
use_log_alpha_in_alpha_loss: bool = True,
target_entropy: Optional[types.Float] = None,
gradient_clipping: Optional[types.Float] = None,
debug_summaries: bool = False,
summarize_grads_and_vars: bool = False,
train_step_counter: Optional[tf.Variable] = None,
name: Optional[Text] = None):
"""Creates a SAC Agent.
Args:
time_step_spec: A `TimeStep` spec of the expected time_steps.
action_spec: A nest of BoundedTensorSpec representing the actions.
critic_network: A function critic_network((observations, actions)) that
returns the q_values for each observation and action.
actor_network: A function actor_network(observation, action_spec) that
returns an action distribution.
actor_optimizer: The optimizer to use for the actor network.
critic_optimizer: The default optimizer to use for the critic network.
alpha_optimizer: The default optimizer to use for the alpha variable.
actor_loss_weight: The weight on actor loss.
critic_loss_weight: The weight on critic loss.
alpha_loss_weight: The weight on alpha loss.
actor_policy_ctor: The policy class to use.
critic_network_2: (Optional.) A `tf_agents.network.Network` to be used as
the second critic network during Q learning. The weights from
`critic_network` are copied if this is not provided.
target_critic_network: (Optional.) A `tf_agents.network.Network` to be
used as the target critic network during Q learning. Every
`target_update_period` train steps, the weights from `critic_network`
are copied (possibly with smoothing via `target_update_tau`) to
`target_critic_network`. If `target_critic_network` is not provided, it
is created by making a copy of `critic_network`, which initializes a new
network with the same structure and its own layers and weights.
Performing a `Network.copy` does not work when the network instance
already has trainable parameters (e.g., has already been built, or when
the network is sharing layers with another). In these cases, it is up
to you to build a copy having weights that are not shared with the
original `critic_network`, so that this can be used as a target network.
If you provide a `target_critic_network` that shares any weights with
`critic_network`, a warning will be logged but no exception is thrown.
target_critic_network_2: (Optional.) Similar to `target_critic_network`,
but used for `critic_network_2`. See the documentation for
`target_critic_network`. Only used if `critic_network_2` is also
specified.
target_update_tau: Factor for soft update of the target networks.
target_update_period: Period for soft update of the target networks.
td_errors_loss_fn: A function for computing the elementwise TD errors
loss.
gamma: A discount factor for future rewards.
reward_scale_factor: Multiplicative scale for the reward.
initial_log_alpha: Initial value for log_alpha.
use_log_alpha_in_alpha_loss: A boolean indicating whether to use
log_alpha or alpha in the alpha loss. Certain implementations of SAC
use log_alpha, as log values are generally nicer to work with.
target_entropy: The target average policy entropy, for updating alpha. The
default value is the negative of the total number of action dimensions.
gradient_clipping: Norm length to clip gradients.
debug_summaries: A bool to gather debug summaries.
summarize_grads_and_vars: If True, gradient and network variable summaries
will be written during training.
train_step_counter: An optional counter to increment every time the train
op is run. Defaults to the global_step.
name: The name of this agent. All variables in this module will fall under
that name. Defaults to the class name.
"""
tf.Module.__init__(self, name=name)
self._check_action_spec(action_spec)
net_observation_spec = time_step_spec.observation
critic_spec = (net_observation_spec, action_spec)
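# SAC uses two critics (clipped double Q-learning). If a second critic is
# not provided, copy the first so it has the same structure but its own,
# freshly initialized weights.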
self._critic_network_1 = critic_network
if critic_network_2 is not None:
self._critic_network_2 = critic_network_2
else:
self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
# Do not use target_critic_network_2 if critic_network_2 is None.
target_critic_network_2 = None
# Wait until critic_network_2 has been copied from critic_network_1 before
# creating variables on both.
self._critic_network_1.create_variables(critic_spec)
self._critic_network_2.create_variables(critic_spec)
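# Build (or validate and reuse) the target critics that produce the TD
# targets; they trail the online critics via periodic soft updates.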
if target_critic_network:
target_critic_network.create_variables(critic_spec)
self._target_critic_network_1 = (
common.maybe_copy_target_network_with_checks(
self._critic_network_1,
target_critic_network,
input_spec=critic_spec,
name='TargetCriticNetwork1'))
if target_critic_network_2:
target_critic_network_2.create_variables(critic_spec)
self._target_critic_network_2 = (
common.maybe_copy_target_network_with_checks(
self._critic_network_2,
target_critic_network_2,
input_spec=critic_spec,
name='TargetCriticNetwork2'))
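# Build the actor network and two policies that share its weights:
# `policy` (training=False) is exposed for evaluation/collection, while
# `_train_policy` (training=True) is used inside the loss computation.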
if actor_network:
actor_network.create_variables(net_observation_spec)
self._actor_network = actor_network
policy = actor_policy_ctor(
time_step_spec=time_step_spec,
action_spec=action_spec,
actor_network=self._actor_network,
training=False)
self._train_policy = actor_policy_ctor(
time_step_spec=time_step_spec,
action_spec=action_spec,
actor_network=self._actor_network,
training=True)
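# The entropy temperature alpha is learned through a trainable log_alpha,
# which keeps alpha = exp(log_alpha) positive during optimization.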
self._log_alpha = common.create_variable(
'initial_log_alpha',
initial_value=initial_log_alpha,
dtype=tf.float32,
trainable=True)
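# Default to the standard SAC heuristic: a target entropy equal to the
# negative of the number of action dimensions.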
if target_entropy is None:
target_entropy = self._get_default_target_entropy(action_spec)
self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
self._target_update_tau = target_update_tau
self._target_update_period = target_update_period
self._actor_optimizer = actor_optimizer
self._critic_optimizer = critic_optimizer
self._alpha_optimizer = alpha_optimizer
self._actor_loss_weight = actor_loss_weight
self._critic_loss_weight = critic_loss_weight
self._alpha_loss_weight = alpha_loss_weight
self._td_errors_loss_fn = td_errors_loss_fn
self._gamma = gamma
self._reward_scale_factor = reward_scale_factor
self._target_entropy = target_entropy
self._gradient_clipping = gradient_clipping
self._debug_summaries = debug_summaries
self._summarize_grads_and_vars = summarize_grads_and_vars
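# Soft target-update op: every `target_update_period` train steps, move the
# target critic weights toward the online critics with interpolation
# factor `tau`.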
self._update_target = self._get_target_updater(
tau=self._target_update_tau, period=self._target_update_period)
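# Stateless critics train on length-2 windows (single transitions);
# stateful (RNN) critics consume full sequences instead.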
train_sequence_length = 2 if not critic_network.state_spec else None
super(SacAgent, self).__init__(
time_step_spec,
action_spec,
policy=policy,
collect_policy=policy,
train_sequence_length=train_sequence_length,
debug_summaries=debug_summaries,
summarize_grads_and_vars=summarize_grads_and_vars,
train_step_counter=train_step_counter,
)
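# Convert incoming experience into Transitions, squeezing the time
# dimension when training on length-2 windows.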
self._as_transition = data_converter.AsTransition(
self.data_context, squeeze_time_dim=(train_sequence_length == 2))
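# A minimal construction sketch (illustrative only; `observation_spec`,
# `action_spec`, `critic_net`, and `actor_net` are assumed to be built
# elsewhere, e.g. from tf_agents network classes):
#
#   agent = SacAgent(
#       time_step_spec=ts.time_step_spec(observation_spec),
#       action_spec=action_spec,
#       critic_network=critic_net,
#       actor_network=actor_net,
#       actor_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
#       critic_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
#       alpha_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
#       target_update_tau=0.005,
#       target_update_period=1,
#       gamma=0.99)
#   agent.initialize()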