def __init__()

in tf_agents/agents/qtopt/qtopt_agent.py [0:0]


  def __init__(
      self,
      time_step_spec,
      action_spec,
      q_network,
      optimizer,
      actions_sampler,
      epsilon_greedy=0.1,
      n_step_update=1,
      emit_log_probability=False,
      in_graph_bellman_update=True,
      # Params for cem
      init_mean_cem=None,
      init_var_cem=None,
      num_samples_cem=32,
      num_elites_cem=4,
      num_iter_cem=3,
      # Params for target network updates
      target_q_network=None,
      target_update_tau=1.0,
      target_update_period=1,
      enable_td3=True,
      target_q_network_delayed=None,
      target_q_network_delayed_2=None,
      delayed_target_update_period=5,
      # Params for training.
      td_errors_loss_fn=None,
      auxiliary_loss_fns=None,
      gamma=1.0,
      reward_scale_factor=1.0,
      gradient_clipping=None,
      # Params for debugging
      debug_summaries=False,
      summarize_grads_and_vars=False,
      train_step_counter=None,
      info_spec=None,
      name=None):
    """Creates a Qtopt Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call((observation, action), step_type).
        This q_network differs from the one used in DQN, whose input is the
        state and whose output has one dimension per action, representing the
        Q values of the different actions. The input of this q_network is a
        tuple of state and action, and its output is a single dimension
        representing the Q value of that specific action. A DDPG critic
        network can be used directly here.
      optimizer: The optimizer to use for training.
      actions_sampler: A tf_agents.policies.sampler.ActionsSampler to be used to
        sample actions in CEM.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: Currently, only n_step_update == 1 is supported.
      emit_log_probability: Whether policies emit log probabilities or not.
      in_graph_bellman_update: If False, configures the agent to expect
        experience containing computed q_values in the policy_step's info
        field. This simplifies splitting the loss calculation across several
        jobs.
      init_mean_cem: Initial mean value of the Gaussian distribution to sample
        actions for CEM.
      init_var_cem: Initial variance value of the Gaussian distribution to
        sample actions for CEM.
      num_samples_cem: Number of samples to sample for each iteration in CEM.
      num_elites_cem: Number of elites to select for each iteration in CEM.
      num_iter_cem: Number of iterations in CEM.
      target_q_network: (Optional.)  A `tf_agents.network.Network`
        to be used as the target network during Q learning.  Every
        `target_update_period` train steps, the weights from
        `q_network` are copied (possibly with smoothing via
        `target_update_tau`) to `target_q_network`.

        If `target_q_network` is not provided, it is created by
        making a copy of `q_network`, which initializes a new
        network with the same structure and its own layers and weights.

        Network copying is performed via the `Network.copy` superclass method,
        with the same arguments used during the original network's
        construction, and may inadvertently lead to weights being shared
        between networks.
        This can happen if, for example, the original
        network accepted a pre-built Keras layer in its `__init__`, or
        accepted a Keras layer that wasn't built, but neglected to create
        a new copy.

        In these cases, it is up to you to provide a target Network having
        weights that are not shared with the original `q_network`.
        If you provide a `target_q_network` that shares any
        weights with `q_network`, an exception is thrown.

      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      enable_td3: Whether to use delayed target networks to calculate the
        Q value, assigning min(q_delayed, q_delayed_2) as q_next_state.
      target_q_network_delayed: (Optional.) A network similar to
        target_q_network, but one that lags further behind. See the
        documentation for target_q_network. Only used if `enable_td3` is True.
      target_q_network_delayed_2: (Optional.) A network similar to
        target_q_network_delayed, but one that lags even further behind. See
        the documentation for target_q_network. Only used if `enable_td3` is
        True.
      delayed_target_update_period: Only used when `enable_td3` is True.
        Period for soft update of the delayed target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      auxiliary_loss_fns: An optional list of functions for computing auxiliary
        losses. Each auxiliary_loss_fn expects network and transition as
        input and should output auxiliary_loss and auxiliary_reg_loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      info_spec: If not None, the policy info spec is set to this spec.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or action
        spec minimum is not equal to 0.
      NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
        RNN is provided) and `n_step_update > 1`.
    """
    tf.Module.__init__(self, name=name)

    self._sampler = actions_sampler
    self._init_mean_cem = init_mean_cem
    self._init_var_cem = init_var_cem
    self._num_samples_cem = num_samples_cem
    self._num_elites_cem = num_elites_cem
    self._num_iter_cem = num_iter_cem
    self._in_graph_bellman_update = in_graph_bellman_update
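    # When the Bellman update is computed outside the graph, the policy info
    # carries the precomputed target Q values for each transition.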
    if not in_graph_bellman_update:
      if info_spec is not None:
        self._info_spec = info_spec
      else:
        self._info_spec = {
            'target_q': tensor_spec.TensorSpec((), tf.float32),
        }
    else:
      self._info_spec = ()

    self._q_network = q_network
    net_observation_spec = (time_step_spec.observation, action_spec)

    q_network.create_variables(net_observation_spec)

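    # Build (or copy) the target Q-network used to compute the Bellman targets.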
    if target_q_network:
      target_q_network.create_variables(net_observation_spec)

    self._target_q_network = common.maybe_copy_target_network_with_checks(
        self._q_network, target_q_network, input_spec=net_observation_spec,
        name='TargetQNetwork')

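    # Periodic soft update: every `target_update_period` train steps, blend the
    # online Q-network weights into the target network with factor
    # `target_update_tau`.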
    self._target_updater = self._get_target_updater(target_update_tau,
                                                    target_update_period)

    self._enable_td3 = enable_td3

    if (not self._enable_td3 and
        (target_q_network_delayed or target_q_network_delayed_2)):
      raise ValueError('enable_td3 is set to False but target_q_network_delayed'
                       ' or target_q_network_delayed_2 is passed.')

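    # TD3-style targets: maintain two delayed target networks and use
    # min(q_delayed, q_delayed_2) as the next-state Q value.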
    if self._enable_td3:
      if target_q_network_delayed:
        target_q_network_delayed.create_variables()
      self._target_q_network_delayed = (
          common.maybe_copy_target_network_with_checks(
              self._q_network, target_q_network_delayed,
              'TargetQNetworkDelayed'))
      self._target_updater_delayed = self._get_target_updater_delayed(
          1.0, delayed_target_update_period)

      if target_q_network_delayed_2:
        target_q_network_delayed_2.create_variables()
      self._target_q_network_delayed_2 = (
          common.maybe_copy_target_network_with_checks(
              self._q_network, target_q_network_delayed_2,
              'TargetQNetworkDelayed2'))
      self._target_updater_delayed_2 = self._get_target_updater_delayed_2(
          1.0, delayed_target_update_period)

      self._update_target = self._update_both
    else:
      self._update_target = self._target_updater
      self._target_q_network_delayed = None
      self._target_q_network_delayed_2 = None

    self._check_network_output(self._q_network, 'q_network')
    self._check_network_output(self._target_q_network, 'target_q_network')

    self._epsilon_greedy = epsilon_greedy
    self._n_step_update = n_step_update
    self._optimizer = optimizer
    self._td_errors_loss_fn = (
        td_errors_loss_fn or common.element_wise_huber_loss)
    self._auxiliary_loss_fns = auxiliary_loss_fns
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping

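    # Build the CEM-based greedy policy and the epsilon-greedy collect policy.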
    policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                                emit_log_probability)

    if q_network.state_spec and n_step_update != 1:
      raise NotImplementedError(
          'QtOptAgent does not currently support n-step updates with stateful '
          'networks (i.e., RNNs), but n_step_update = {}'.format(n_step_update))

    # Bypass the train_sequence_length check when RNN is used.
    train_sequence_length = (
        n_step_update + 1 if not q_network.state_spec else None)

    super(QtOptAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
    )

    self._setup_data_converter(q_network, gamma, n_step_update)
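
Usage sketch (not part of qtopt_agent.py): the snippet below shows one
plausible way to construct a QtOptAgent for a continuous action space, using a
DDPG-style critic as the (observation, action) -> Q-value network, as the
docstring above suggests. The sampler import
(cem_actions_sampler_continuous.GaussianActionsSampler) and all hyperparameter
values are illustrative assumptions and may differ in your TF-Agents version.

import tensorflow as tf

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.qtopt import qtopt_agent
# Assumed sampler module/class for continuous-action CEM; verify the exact
# path and class name in your TF-Agents version.
from tf_agents.policies.samplers import cem_actions_sampler_continuous
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Hypothetical specs: a flat 10-d observation and a 3-d bounded action.
observation_spec = tensor_spec.TensorSpec([10], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec(
    [3], tf.float32, minimum=-1.0, maximum=1.0)
time_step_spec = ts.time_step_spec(observation_spec)

# A DDPG critic maps (observation, action) to a scalar Q value, matching the
# call((observation, action), step_type) contract described in the docstring.
q_net = critic_network.CriticNetwork(
    (observation_spec, action_spec),
    joint_fc_layer_params=(256, 256))

sampler = cem_actions_sampler_continuous.GaussianActionsSampler(
    action_spec=action_spec)

agent = qtopt_agent.QtOptAgent(
    time_step_spec,
    action_spec,
    q_network=q_net,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    actions_sampler=sampler,
    num_samples_cem=64,   # CEM samples per iteration (illustrative)
    num_elites_cem=6,     # elites kept per iteration (illustrative)
    num_iter_cem=3,
    enable_td3=True,
    gamma=0.99)
agent.initialize()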