in tf_agents/agents/qtopt/qtopt_agent.py [0:0]
def __init__(
self,
time_step_spec,
action_spec,
q_network,
optimizer,
actions_sampler,
epsilon_greedy=0.1,
n_step_update=1,
emit_log_probability=False,
in_graph_bellman_update=True,
# Params for cem
init_mean_cem=None,
init_var_cem=None,
num_samples_cem=32,
num_elites_cem=4,
num_iter_cem=3,
# Params for target network updates
target_q_network=None,
target_update_tau=1.0,
target_update_period=1,
enable_td3=True,
target_q_network_delayed=None,
target_q_network_delayed_2=None,
delayed_target_update_period=5,
# Params for training.
td_errors_loss_fn=None,
auxiliary_loss_fns=None,
gamma=1.0,
reward_scale_factor=1.0,
gradient_clipping=None,
# Params for debugging
debug_summaries=False,
summarize_grads_and_vars=False,
train_step_counter=None,
info_spec=None,
name=None):
"""Creates a Qtopt Agent.
Args:
time_step_spec: A `TimeStep` spec of the expected time_steps.
action_spec: A nest of BoundedTensorSpec representing the actions.
q_network: A tf_agents.network.Network to be used by the agent. The
network is called as q_network((observation, action), step_type). This
q_network differs from the one used in DQN, whose input is the state
alone and whose output has multiple dimensions representing the Q values
of the different actions. Here the input is a tuple of state and action,
and the output is a single dimension representing the Q value of that
specific action. A DDPG critic network can be used directly here (see
the example at the end of this docstring).
optimizer: The optimizer to use for training.
actions_sampler: A tf_agents.policies.sampler.ActionsSampler to be used to
sample actions in CEM.
epsilon_greedy: probability of choosing a random action in the default
epsilon-greedy collect policy (used only if a wrapper is not provided to
the collect_policy method).
n_step_update: Currently, only n_step_update == 1 is supported.
emit_log_probability: Whether policies emit log probabilities or not.
in_graph_bellman_update: If False, configures the agent to expect
experience containing computed q_values in the policy_step's info field.
This simplifies splitting the loss calculation across several jobs.
init_mean_cem: Initial mean value of the Gaussian distribution to sample
actions for CEM.
init_var_cem: Initial variance value of the Gaussian distribution to
sample actions for CEM.
num_samples_cem: Number of candidate actions sampled in each CEM iteration.
num_elites_cem: Number of elites to select for each iteration in CEM.
num_iter_cem: Number of iterations in CEM.
target_q_network: (Optional.) A `tf_agents.network.Network`
to be used as the target network during Q learning. Every
`target_update_period` train steps, the weights from
`q_network` are copied (possibly with smoothing via
`target_update_tau`) to `target_q_network`.
If `target_q_network` is not provided, it is created by
making a copy of `q_network`, which initializes a new
network with the same structure and its own layers and weights.
Network copying is performed via the `Network.copy` superclass method,
with the same arguments used during the original network's construction,
and may inadvertently lead to weights being shared between the two
networks. This can happen if, for example, the original network accepted
a pre-built Keras layer in its `__init__`, or accepted an unbuilt Keras
layer but did not create a new copy of it.
In these cases, it is up to you to provide a target Network having
weights that are not shared with the original `q_network`.
If you provide a `target_q_network` that shares any
weights with `q_network`, an exception is thrown.
target_update_tau: Factor for soft update of the target networks.
target_update_period: Period for soft update of the target networks.
enable_td3: Whether to use delayed target networks to compute the target
Q value, assigning min(q_delayed, q_delayed_2) as q_next_state.
target_q_network_delayed: (Optional.) A network similar to
target_q_network, but whose weights lag further behind. See the
documentation for target_q_network. Only used if `enable_td3` is True.
target_q_network_delayed_2: (Optional.) A network similar to
target_q_network_delayed, but whose weights lag behind even more. See
the documentation for target_q_network. Only used if `enable_td3` is
True.
delayed_target_update_period: Period for soft update of the delayed
target networks. Only used if `enable_td3` is True.
td_errors_loss_fn: A function for computing the TD errors loss. If None, a
default value of element_wise_huber_loss is used. This function takes as
input the target and the estimated Q values and returns the loss for
each element of the batch.
auxiliary_loss_fns: An optional list of functions for computing auxiliary
losses. Each auxiliary_loss_fn takes the network and a transition as
input and should return a tuple (auxiliary_loss, auxiliary_reg_loss).
gamma: A discount factor for future rewards.
reward_scale_factor: Multiplicative scale for the reward.
gradient_clipping: Norm length to clip gradients.
debug_summaries: A bool to gather debug summaries.
summarize_grads_and_vars: If True, gradient and network variable summaries
will be written during training.
train_step_counter: An optional counter to increment every time the train
op is run. Defaults to the global_step.
info_spec: If not None, the policy info spec is set to this spec.
name: The name of this agent. All variables in this module will fall under
that name. Defaults to the class name.
Raises:
ValueError: If the action spec contains more than one action or action
spec minimum is not equal to 0.
NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
RNN is provided) and `n_step_update > 1`.
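For example, since the network receives (observation, action) pairs and
returns a single Q value per pair, a DDPG-style critic can be plugged in
directly. A minimal sketch (the specs, batch tensors, and layer sizes below
are illustrative placeholders, not part of this API):
  from tf_agents.agents.ddpg import critic_network
  critic_net = critic_network.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=(256, 256))
  q_values, _ = critic_net((observations, actions), step_type)
  # q_values has shape [batch_size]: one Q value per (observation, action).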
"""
tf.Module.__init__(self, name=name)
self._sampler = actions_sampler
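# Cross-entropy method (CEM) settings used for action optimization: each of
# the `num_iter_cem` iterations samples `num_samples_cem` candidate actions
# from a Gaussian initialized with (`init_mean_cem`, `init_var_cem`), scores
# them with the Q-network, keeps the `num_elites_cem` best candidates, and
# refits the Gaussian to them.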
self._init_mean_cem = init_mean_cem
self._init_var_cem = init_var_cem
self._num_samples_cem = num_samples_cem
self._num_elites_cem = num_elites_cem
self._num_iter_cem = num_iter_cem
self._in_graph_bellman_update = in_graph_bellman_update
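# With out-of-graph Bellman updates, the collect policy must emit the
# precomputed target Q value under the 'target_q' key of its info field
# (or under the caller-provided `info_spec`).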
if not in_graph_bellman_update:
if info_spec is not None:
self._info_spec = info_spec
else:
self._info_spec = {
'target_q': tensor_spec.TensorSpec((), tf.float32),
}
else:
self._info_spec = ()
self._q_network = q_network
net_observation_spec = (time_step_spec.observation, action_spec)
q_network.create_variables(net_observation_spec)
if target_q_network:
target_q_network.create_variables(net_observation_spec)
self._target_q_network = common.maybe_copy_target_network_with_checks(
self._q_network, target_q_network, input_spec=net_observation_spec,
name='TargetQNetwork')
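# Periodic soft update of the target network: every `target_update_period`
# train steps the online q_network weights are moved into target_q_network,
# smoothed by `target_update_tau`
# (target <- tau * online + (1 - tau) * target).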
self._target_updater = self._get_target_updater(target_update_tau,
target_update_period)
self._enable_td3 = enable_td3
if (not self._enable_td3 and
(target_q_network_delayed or target_q_network_delayed_2)):
raise ValueError('enable_td3 is set to False but target_q_network_delayed'
' or target_q_network_delayed_2 is passed.')
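# TD3-style clipped double-Q: keep two delayed copies of the target network
# and use min(q_delayed, q_delayed_2) as q_next_state to reduce
# overestimation of the bootstrap target.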
if self._enable_td3:
if target_q_network_delayed:
target_q_network_delayed.create_variables()
self._target_q_network_delayed = (
common.maybe_copy_target_network_with_checks(
self._q_network, target_q_network_delayed,
'TargetQNetworkDelayed'))
self._target_updater_delayed = self._get_target_updater_delayed(
1.0, delayed_target_update_period)
if target_q_network_delayed_2:
target_q_network_delayed_2.create_variables()
self._target_q_network_delayed_2 = (
common.maybe_copy_target_network_with_checks(
self._q_network, target_q_network_delayed_2,
'TargetQNetworkDelayed2'))
self._target_updater_delayed_2 = self._get_target_updater_delayed_2(
1.0, delayed_target_update_period)
self._update_target = self._update_both
else:
self._update_target = self._target_updater
self._target_q_network_delayed = None
self._target_q_network_delayed_2 = None
self._check_network_output(self._q_network, 'q_network')
self._check_network_output(self._target_q_network, 'target_q_network')
self._epsilon_greedy = epsilon_greedy
self._n_step_update = n_step_update
self._optimizer = optimizer
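# td_errors_loss_fn is called as loss_fn(td_targets, predicted_q_values) and
# must return a per-element loss; common.element_wise_squared_loss is a
# drop-in alternative to the element-wise Huber default.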
self._td_errors_loss_fn = (
td_errors_loss_fn or common.element_wise_huber_loss)
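# Each auxiliary loss fn is called with the network and a transition and is
# expected to return an (auxiliary_loss, auxiliary_reg_loss) pair.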
self._auxiliary_loss_fns = auxiliary_loss_fns
self._gamma = gamma
self._reward_scale_factor = reward_scale_factor
self._gradient_clipping = gradient_clipping
policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
emit_log_probability)
if q_network.state_spec and n_step_update != 1:
raise NotImplementedError(
'QtOptAgent does not currently support n-step updates with stateful '
'networks (i.e., RNNs), but n_step_update = {}'.format(n_step_update))
# Bypass the train_sequence_length check when RNN is used.
train_sequence_length = (
n_step_update + 1 if not q_network.state_spec else None)
super(QtOptAgent, self).__init__(
time_step_spec,
action_spec,
policy,
collect_policy,
train_sequence_length=train_sequence_length,
debug_summaries=debug_summaries,
summarize_grads_and_vars=summarize_grads_and_vars,
train_step_counter=train_step_counter,
)
self._setup_data_converter(q_network, gamma, n_step_update)