in hanabi-learning-environment/agents/rainbow/dqn_agent.py [0:0]
def __init__(self,
             num_actions=None,
             observation_size=None,
             num_players=None,
             gamma=0.99,
             update_horizon=1,
             min_replay_history=500,
             update_period=4,
             stack_size=1,
             target_update_period=500,
             epsilon_fn=linearly_decaying_epsilon,
             epsilon_train=0.02,
             epsilon_eval=0.001,
             epsilon_decay_period=1000,
             graph_template=dqn_template,
             tf_device='/cpu:*',
             use_staging=True,
             optimizer=tf.train.RMSPropOptimizer(
                 learning_rate=.0025,
                 decay=0.95,
                 momentum=0.0,
                 epsilon=1e-6,
                 centered=True)):
"""Initializes the agent and constructs its graph.
Args:
num_actions: int, number of actions the agent can take at any state.
observation_size: int, size of observation vector.
num_players: int, number of players playing this game.
gamma: float, discount factor as commonly used in the RL literature.
update_horizon: int, horizon at which updates are performed, the 'n' in
n-step update.
min_replay_history: int, number of stored transitions before training.
update_period: int, period between DQN updates.
stack_size: int, number of observations to use as state.
target_update_period: Update period for the target network.
epsilon_fn: Function expecting 4 parameters: (decay_period, step,
warmup_steps, epsilon), and which returns the epsilon value used for
exploration during training.
epsilon_train: float, final epsilon for training.
epsilon_eval: float, epsilon during evaluation.
epsilon_decay_period: int, number of steps for epsilon to decay.
graph_template: function for building the neural network graph.
tf_device: str, Tensorflow device on which to run computations.
use_staging: bool, when True use a staging area to prefetch the next
sampling batch.
optimizer: Optimizer instance used for learning.
"""
  tf.logging.info('Creating %s agent with the following parameters:',
                  self.__class__.__name__)
  tf.logging.info('\t gamma: %f', gamma)
  tf.logging.info('\t update_horizon: %f', update_horizon)
  tf.logging.info('\t min_replay_history: %d', min_replay_history)
  tf.logging.info('\t update_period: %d', update_period)
  tf.logging.info('\t target_update_period: %d', target_update_period)
  tf.logging.info('\t epsilon_train: %f', epsilon_train)
  tf.logging.info('\t epsilon_eval: %f', epsilon_eval)
  tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
  tf.logging.info('\t tf_device: %s', tf_device)
  tf.logging.info('\t use_staging: %s', use_staging)
  tf.logging.info('\t optimizer: %s', optimizer)
  # Global variables.
  self.num_actions = num_actions
  self.observation_size = observation_size
  self.num_players = num_players
  self.gamma = gamma
  self.update_horizon = update_horizon
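  # Discount applied across one full n-step transition: gamma ** update_horizon
  # (0.99 ** 1 = 0.99 with the defaults above).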
  self.cumulative_gamma = math.pow(gamma, update_horizon)
  self.min_replay_history = min_replay_history
  self.target_update_period = target_update_period
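  # epsilon_fn follows the signature (decay_period, step, warmup_steps,
  # epsilon) described in the docstring; linearly_decaying_epsilon is the
  # default exploration schedule.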
  self.epsilon_fn = epsilon_fn
  self.epsilon_train = epsilon_train
  self.epsilon_eval = epsilon_eval
  self.epsilon_decay_period = epsilon_decay_period
  self.update_period = update_period
  self.eval_mode = False
  self.training_steps = 0
  self.batch_staged = False
  self.optimizer = optimizer
  with tf.device(tf_device):
    # Calling online_convnet will generate a new graph as defined in
    # graph_template using whatever input is passed, but will always share
    # the same weights.
    online_convnet = tf.make_template('Online', graph_template)
    target_convnet = tf.make_template('Target', graph_template)
    # The state of the agent. The last axis is the number of past observations
    # that make up the state.
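    # With the default stack_size=1 the state is a single observation vector,
    # giving a shape of (1, observation_size, 1).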
    states_shape = (1, observation_size, stack_size)
    self.state = np.zeros(states_shape)
    self.state_ph = tf.placeholder(tf.uint8, states_shape, name='state_ph')
    self.legal_actions_ph = tf.placeholder(tf.float32,
                                           [self.num_actions],
                                           name='legal_actions_ph')
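    # The mask fed here is expected to hold 0 for legal moves and -inf for
    # illegal ones, so the masked argmax below can never pick an illegal
    # action.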
    self._q = online_convnet(
        state=self.state_ph, num_actions=self.num_actions)
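    # Replay buffer of past transitions; when use_staging is True, a staging
    # area prefetches the next sampled batch onto the compute device.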
    self._replay = self._build_replay_memory(use_staging)
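    # Q-values from the online network for the sampled replay states, and from
    # the target network for the corresponding next states.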
    self._replay_qs = online_convnet(self._replay.states, self.num_actions)
    self._replay_next_qt = target_convnet(self._replay.next_states,
                                          self.num_actions)
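    # _build_train_op builds the loss on the Bellman error between the online
    # Q-values and targets formed from the target-network Q-values;
    # _build_sync_op returns the ops that copy the online weights into the
    # target network, intended to run every target_update_period steps.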
    self._train_op = self._build_train_op()
    self._sync_qt_ops = self._build_sync_op()
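    # Greedy action for play: adding the legal-actions mask (0 / -inf, as
    # noted above) leaves legal Q-values unchanged and sends illegal ones to
    # -inf, so the argmax only returns legal moves.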
    self._q_argmax = tf.argmax(self._q + self.legal_actions_ph, axis=1)[0]

  # Set up a session and initialize variables.
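  # allow_soft_placement lets TensorFlow fall back to a supported device when
  # an op has no kernel for tf_device (e.g. a GPU device string on a CPU-only
  # machine).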
  self._sess = tf.Session(
      '', config=tf.ConfigProto(allow_soft_placement=True))
  self._init_op = tf.global_variables_initializer()
  self._sess.run(self._init_op)
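  # Keep only the three most recent checkpoints on disk.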
  self._saver = tf.train.Saver(max_to_keep=3)

  # This keeps track of the observed transitions during play, for each
  # player.
  self.transitions = [[] for _ in range(num_players)]
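  # Construction sketch (hypothetical names, shown for illustration only; the
  # actual sizes come from the Hanabi environment):
  #
  #   agent = DQNAgent(num_actions=environment.num_moves(),
  #                    observation_size=obs_stacker.observation_size(),
  #                    num_players=environment.players)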