in tf_agents/agents/td3/td3_agent.py [0:0]
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             actor_network: network.Network,
             critic_network: network.Network,
             actor_optimizer: types.Optimizer,
             critic_optimizer: types.Optimizer,
             exploration_noise_std: types.Float = 0.1,
             critic_network_2: Optional[network.Network] = None,
             target_actor_network: Optional[network.Network] = None,
             target_critic_network: Optional[network.Network] = None,
             target_critic_network_2: Optional[network.Network] = None,
             target_update_tau: types.Float = 1.0,
             target_update_period: types.Int = 1,
             actor_update_period: types.Int = 1,
             td_errors_loss_fn: Optional[types.LossFn] = None,
             gamma: types.Float = 1.0,
             reward_scale_factor: types.Float = 1.0,
             target_policy_noise: types.Float = 0.2,
             target_policy_noise_clip: types.Float = 0.5,
             gradient_clipping: Optional[types.Float] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             train_step_counter: Optional[tf.Variable] = None,
             name: Optional[Text] = None):
"""Creates a Td3Agent Agent.
Args:
time_step_spec: A `TimeStep` spec of the expected time_steps.
action_spec: A nest of BoundedTensorSpec representing the actions.
actor_network: A tf_agents.network.Network to be used by the agent. The
network will be called with call(observation, step_type).
critic_network: A tf_agents.network.Network to be used by the agent. The
network will be called with call(observation, action, step_type).
actor_optimizer: The default optimizer to use for the actor network.
critic_optimizer: The default optimizer to use for the critic network.
exploration_noise_std: Scale factor on exploration policy noise.
critic_network_2: (Optional.) A `tf_agents.network.Network` to be used as
the second critic network during Q learning. The weights from
`critic_network` are copied if this is not provided.
target_actor_network: (Optional.) A `tf_agents.network.Network` to be
used as the target actor network during Q learning. Every
`target_update_period` train steps, the weights from `actor_network` are
copied (possibly withsmoothing via `target_update_tau`) to `
target_actor_network`. If `target_actor_network` is not provided, it is
created by making a copy of `actor_network`, which initializes a new
network with the same structure and its own layers and weights.
Performing a `Network.copy` does not work when the network instance
already has trainable parameters (e.g., has already been built, or when
the network is sharing layers with another). In these cases, it is up
to you to build a copy having weights that are not shared with the
original `actor_network`, so that this can be used as a target network.
If you provide a `target_actor_network` that shares any weights with
`actor_network`, a warning will be logged but no exception is thrown.
target_critic_network: (Optional.) Similar network as target_actor_network
but for the critic_network. See documentation for target_actor_network.
target_critic_network_2: (Optional.) Similar network as
target_actor_network but for the critic_network_2. See documentation for
target_actor_network. Will only be used if 'critic_network_2' is also
specified.
target_update_tau: Factor for soft update of the target networks.
target_update_period: Period for soft update of the target networks.
actor_update_period: Period for the optimization step on actor network.
td_errors_loss_fn: A function for computing the TD errors loss. If None,
a default value of elementwise huber_loss is used.
gamma: A discount factor for future rewards.
reward_scale_factor: Multiplicative scale for the reward.
target_policy_noise: Scale factor on target action noise
target_policy_noise_clip: Value to clip noise.
gradient_clipping: Norm length to clip gradients.
debug_summaries: A bool to gather debug summaries.
summarize_grads_and_vars: If True, gradient and network variable summaries
will be written during training.
train_step_counter: An optional counter to increment every time the train
op is run. Defaults to the global_step.
name: The name of this agent. All variables in this module will fall
under that name. Defaults to the class name.
"""
  tf.Module.__init__(self, name=name)
  self._actor_network = actor_network
  actor_network.create_variables()
  if target_actor_network:
    target_actor_network.create_variables()
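  # If a target network is not supplied, a copy of the source network (with
  # its own, unshared weights) is created; a supplied target is checked for
  # weight sharing with the source, and a warning is logged if weights are
  # shared. The same applies to the target critic networks below.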
  self._target_actor_network = common.maybe_copy_target_network_with_checks(
      self._actor_network, target_actor_network, 'TargetActorNetwork')
  self._critic_network_1 = critic_network
  critic_network.create_variables()
  if target_critic_network:
    target_critic_network.create_variables()
  self._target_critic_network_1 = (
      common.maybe_copy_target_network_with_checks(
          self._critic_network_1, target_critic_network,
          'TargetCriticNetwork1'))
  if critic_network_2 is not None:
    self._critic_network_2 = critic_network_2
  else:
    self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
    # Do not use target_critic_network_2 if critic_network_2 is None.
    target_critic_network_2 = None
  self._critic_network_2.create_variables()
  if target_critic_network_2:
    target_critic_network_2.create_variables()
  self._target_critic_network_2 = (
      common.maybe_copy_target_network_with_checks(
          self._critic_network_2, target_critic_network_2,
          'TargetCriticNetwork2'))
  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._exploration_noise_std = exploration_noise_std
  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_update_period = actor_update_period
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_policy_noise = target_policy_noise
  self._target_policy_noise_clip = target_policy_noise_clip
  self._gradient_clipping = gradient_clipping
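
  # Soft target updates, applied every `target_update_period` train steps:
  #   target_var <- target_update_tau * source_var
  #                 + (1 - target_update_tau) * target_var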
  self._update_target = self._get_target_updater(
      target_update_tau, target_update_period)

  policy = actor_policy.ActorPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec,
      actor_network=self._actor_network, clip=True)
  collect_policy = actor_policy.ActorPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec,
      actor_network=self._actor_network, clip=False)
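  # For data collection, wrap the (unclipped) actor policy in Gaussian
  # exploration noise with std `exploration_noise_std`; the wrapper clips the
  # noisy action back to the action spec.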
  collect_policy = gaussian_policy.GaussianPolicy(
      collect_policy,
      scale=self._exploration_noise_std,
      clip=True)
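  # Feed-forward actor networks train on length-2 sequences (single
  # transitions); recurrent actors (non-empty `state_spec`) train on full
  # sequences.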
  train_sequence_length = 2 if not self._actor_network.state_spec else None
  super(Td3Agent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
  )

  self._as_transition = data_converter.AsTransition(
      self.data_context, squeeze_time_dim=(train_sequence_length == 2))
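
# --- Usage sketch (not part of td3_agent.py) ---
# A minimal, illustrative construction of a Td3Agent, assuming the DDPG
# actor/critic networks bundled with TF-Agents and hand-built specs; the
# layer sizes and hyperparameters below are placeholders, not recommendations.
import tensorflow as tf

from tf_agents.agents.ddpg import actor_network
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.td3 import td3_agent
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec([4], tf.float32, name='observation')
action_spec = tensor_spec.BoundedTensorSpec(
    [1], tf.float32, minimum=-1.0, maximum=1.0, name='action')
time_step_spec = ts.time_step_spec(observation_spec)

actor_net = actor_network.ActorNetwork(
    observation_spec, action_spec, fc_layer_params=(400, 300))
critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec), joint_fc_layer_params=(400, 300))

agent = td3_agent.Td3Agent(
    time_step_spec,
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    critic_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    exploration_noise_std=0.1,
    target_update_tau=0.005,
    target_update_period=2,
    actor_update_period=2,  # delayed actor updates, as in the TD3 paper
    gamma=0.99,
    train_step_counter=tf.Variable(0))
agent.initialize()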