def __init__()

in hanabi-learning-environment/agents/rainbow/dqn_agent.py


  def __init__(self,
               num_actions=None,
               observation_size=None,
               num_players=None,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=500,
               update_period=4,
               stack_size=1,
               target_update_period=500,
               epsilon_fn=linearly_decaying_epsilon,
               epsilon_train=0.02,
               epsilon_eval=0.001,
               epsilon_decay_period=1000,
               graph_template=dqn_template,
               tf_device='/cpu:*',
               use_staging=True,
               optimizer=tf.train.RMSPropOptimizer(
                   learning_rate=.0025,
                   decay=0.95,
                   momentum=0.0,
                   epsilon=1e-6,
                   centered=True)):
    """Initializes the agent and constructs its graph.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      observation_size: int, size of observation vector.
      num_players: int, number of players playing this game.
      gamma: float, discount factor as commonly used in the RL literature.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of stored transitions before training.
      update_period: int, period between DQN updates.
      stack_size: int, number of observations to use as state.
      target_update_period: int, update period for the target network.
      epsilon_fn: Function expecting 4 parameters: (decay_period, step,
        warmup_steps, epsilon), and which returns the epsilon value used for
        exploration during training.
      epsilon_train: float, final epsilon for training.
      epsilon_eval: float, epsilon during evaluation.
      epsilon_decay_period: int, number of steps for epsilon to decay.
      graph_template: function for building the neural network graph.
      tf_device: str, Tensorflow device on which to run computations.
      use_staging: bool, when True use a staging area to prefetch the next
        sampling batch.
      optimizer: Optimizer instance used for learning.
    """

    tf.logging.info('Creating %s agent with the following parameters:',
                    self.__class__.__name__)
    tf.logging.info('\t gamma: %f', gamma)
    tf.logging.info('\t update_horizon: %f', update_horizon)
    tf.logging.info('\t min_replay_history: %d', min_replay_history)
    tf.logging.info('\t update_period: %d', update_period)
    tf.logging.info('\t target_update_period: %d', target_update_period)
    tf.logging.info('\t epsilon_train: %f', epsilon_train)
    tf.logging.info('\t epsilon_eval: %f', epsilon_eval)
    tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
    tf.logging.info('\t tf_device: %s', tf_device)
    tf.logging.info('\t use_staging: %s', use_staging)
    tf.logging.info('\t optimizer: %s', optimizer)

    # Global variables.
    self.num_actions = num_actions
    self.observation_size = observation_size
    self.num_players = num_players
    self.gamma = gamma
    self.update_horizon = update_horizon
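    # Discount accumulated across an n-step transition: gamma ** update_horizon.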
    self.cumulative_gamma = math.pow(gamma, update_horizon)
    self.min_replay_history = min_replay_history
    self.target_update_period = target_update_period
    self.epsilon_fn = epsilon_fn
    self.epsilon_train = epsilon_train
    self.epsilon_eval = epsilon_eval
    self.epsilon_decay_period = epsilon_decay_period
    self.update_period = update_period
    self.eval_mode = False
    self.training_steps = 0
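    # Whether an initial sample batch has been staged for prefetching; only
    # relevant when use_staging is True.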
    self.batch_staged = False
    self.optimizer = optimizer

    with tf.device(tf_device):
      # Calling online_convnet will generate a new graph as defined in
      # graph_template using whatever input is passed, but will always share
      # the same weights.
      online_convnet = tf.make_template('Online', graph_template)
      target_convnet = tf.make_template('Target', graph_template)
      # The state of the agent. The last axis is the number of past observations
      # that make up the state.
      states_shape = (1, observation_size, stack_size)
      self.state = np.zeros(states_shape)
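      # Placeholders fed at action-selection time: the current state and a
      # mask over the currently legal actions.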
      self.state_ph = tf.placeholder(tf.uint8, states_shape, name='state_ph')
      self.legal_actions_ph = tf.placeholder(tf.float32,
                                             [self.num_actions],
                                             name='legal_actions_ph')
      self._q = online_convnet(
          state=self.state_ph, num_actions=self.num_actions)
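      # Replay memory from which training batches are sampled; with staging
      # enabled, the next batch is prefetched onto the compute device.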
      self._replay = self._build_replay_memory(use_staging)
      self._replay_qs = online_convnet(self._replay.states, self.num_actions)
      self._replay_next_qt = target_convnet(self._replay.next_states,
                                            self.num_actions)
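      # Op that trains the online network and ops that periodically copy its
      # weights to the target network.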
      self._train_op = self._build_train_op()
      self._sync_qt_ops = self._build_sync_op()

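      # legal_actions_ph is expected to hold 0 for legal actions and a large
      # negative value for illegal ones, so adding it to the Q-values masks out
      # illegal actions before taking the argmax.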
      self._q_argmax = tf.argmax(self._q + self.legal_actions_ph, axis=1)[0]

    # Set up a session and initialize variables.
    self._sess = tf.Session(
        '', config=tf.ConfigProto(allow_soft_placement=True))
    self._init_op = tf.global_variables_initializer()
    self._sess.run(self._init_op)

    self._saver = tf.train.Saver(max_to_keep=3)

    # This keeps track of the observed transitions during play, for each
    # player.
    self.transitions = [[] for _ in range(num_players)]
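
The epsilon_fn argument accepts any callable with the four-parameter signature documented above: (decay_period, step, warmup_steps, epsilon). Below is a minimal sketch of a custom schedule with that signature, plus an illustrative construction of the agent. It assumes the enclosing class is named DQNAgent and that the module is importable as dqn_agent; the num_actions and observation_size values are placeholders for a 2-player game, not values taken from this file.

import numpy as np

import dqn_agent  # assumes agents/rainbow is on the Python path


def custom_linear_epsilon(decay_period, step, warmup_steps, epsilon):
  """Hypothetical schedule: epsilon stays at 1.0 for warmup_steps steps, then
  decays linearly to `epsilon` over `decay_period` steps and holds there."""
  steps_left = decay_period + warmup_steps - step
  bonus = (1.0 - epsilon) * steps_left / decay_period
  bonus = np.clip(bonus, 0.0, 1.0 - epsilon)
  return epsilon + bonus


# Placeholder values for a 2-player game; not taken from this file.
agent = dqn_agent.DQNAgent(num_actions=20,
                           observation_size=658,
                           num_players=2,
                           epsilon_fn=custom_linear_epsilon)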