in hanabi-learning-environment/agents/rainbow/replay_memory.py [0:0]
def __init__(self,
num_actions,
observation_size,
stack_size,
use_staging=True,
replay_capacity=1000000,
batch_size=32,
update_horizon=1,
gamma=1.0,
wrapped_memory=None):
"""Initializes a graph wrapper for the python replay memory.
Args:
num_actions: int, number of possible actions.
observation_size: int, size of an input frame.
stack_size: int, number of frames to use in state stack.
use_staging: bool, when True it would use a staging area to prefetch the
next sampling batch.
replay_capacity: int, number of transitions to keep in memory.
batch_size: int.
update_horizon: int, length of update ('n' in n-step update).
gamma: int, the discount factor.
wrapped_memory: The 'inner' memory data structure. Defaults to None, which
creates the standard DQN replay memory.
Raises:
ValueError: If update_horizon is not positive.
ValueError: If discount factor is not in [0, 1].
"""
if replay_capacity < update_horizon + 1:
raise ValueError('Update horizon (%i) should be significantly smaller '
'than replay capacity (%i).'
% (update_horizon, replay_capacity))
if update_horizon < 1:
raise ValueError('Update horizon must be positive.')
if not 0.0 <= gamma <= 1.0:
raise ValueError('Discount factor (gamma) must be in [0, 1].')
# Allow subclasses to create self.memory.
if wrapped_memory is not None:
self.memory = wrapped_memory
else:
self.memory = OutOfGraphReplayMemory(
num_actions, observation_size, stack_size,
replay_capacity, batch_size, update_horizon, gamma)
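# The (default) OutOfGraphReplayMemory stores transitions in numpy arrays
# outside the TF graph; it is bridged into the graph below via tf.py_func ops.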
with tf.name_scope('replay'):
with tf.name_scope('add_placeholders'):
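# Placeholders used to feed a single transition (observation, action,
# reward, terminal flag, legal-action mask) into the python memory.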
self.add_obs_ph = tf.placeholder(
tf.uint8, [observation_size], name='add_obs_ph')
self.add_action_ph = tf.placeholder(tf.int32, [], name='add_action_ph')
self.add_reward_ph = tf.placeholder(
tf.float32, [], name='add_reward_ph')
self.add_terminal_ph = tf.placeholder(
tf.uint8, [], name='add_terminal_ph')
self.add_legal_actions_ph = tf.placeholder(
tf.float32, [num_actions], name='add_legal_actions_ph')
add_transition_ph = [
self.add_obs_ph, self.add_action_ph, self.add_reward_ph,
self.add_terminal_ph, self.add_legal_actions_ph
]
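# The order of this list must match the positional arguments of
# self.memory.add, since tf.py_func forwards its inputs positionally.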
with tf.device('/cpu:*'):
self.add_transition_op = tf.py_func(
self.memory.add, add_transition_ph, [], name='replay_add_py_func')
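# add_transition_op produces no outputs; it is run purely for its side
# effect of writing the fed transition into the python replay memory.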
self.transition = tf.py_func(
self.memory.sample_transition_batch, [],
[tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8, tf.int32,
tf.float32],
name='replay_sample_py_func')
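# The dtype list must match, element for element, the tuple returned by
# sample_transition_batch: (states, actions, rewards, next_states,
# terminals, indices, next_legal_actions).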
if use_staging:
# To hide the py_func latency, use a staging area to prefetch the next
# batch of transitions.
(states, actions, rewards, next_states,
terminals, indices, next_legal_actions) = self.transition
# StagingArea requires all the shapes to be defined.
states.set_shape([batch_size, observation_size, stack_size])
actions.set_shape([batch_size])
rewards.set_shape([batch_size])
next_states.set_shape(
[batch_size, observation_size, stack_size])
terminals.set_shape([batch_size])
indices.set_shape([batch_size])
next_legal_actions.set_shape([batch_size, num_actions])
# Create the staging area in CPU.
prefetch_area = tf.contrib.staging.StagingArea(
[tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8, tf.int32,
tf.float32])
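# put() stages the sampled batch on the CPU; the matching get() below
# retrieves it, so sampling the next batch can overlap with training on
# the current one.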
self.prefetch_batch = prefetch_area.put(
(states, actions, rewards, next_states, terminals, indices,
next_legal_actions))
else:
self.prefetch_batch = tf.no_op()
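# With staging disabled, prefetch_batch is a no-op, so callers can run it
# unconditionally.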
if use_staging:
# Get the sampled transition batch on the GPU; this performs the copy
# from CPU to GPU.
self.transition = prefetch_area.get()
(self.states, self.actions, self.rewards, self.next_states,
self.terminals, self.indices, self.next_legal_actions) = self.transition
# Since these are py_func tensors, no information about their shape is
# present. Set the shape only for the tensors that are needed downstream.
self.states.set_shape([None, observation_size, stack_size])
self.next_states.set_shape([None, observation_size, stack_size])
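# Illustrative usage sketch (an assumption-laden example, not part of the
# class): it assumes this __init__ belongs to WrappedReplayMemory, a TF1-style
# tf.Session named `sess`, and caller-supplied data for `observation`,
# `action`, `reward`, `is_terminal` and `legal_actions_mask`.
#
#   replay = WrappedReplayMemory(num_actions, observation_size, stack_size)
#   sess.run(replay.add_transition_op,
#            feed_dict={replay.add_obs_ph: observation,
#                       replay.add_action_ph: action,
#                       replay.add_reward_ph: reward,
#                       replay.add_terminal_ph: is_terminal,
#                       replay.add_legal_actions_ph: legal_actions_mask})
#   # With use_staging=True, also run replay.prefetch_batch (e.g. grouped with
#   # the train op) so a staged batch exists before reading the sample tensors.
#   states, actions = sess.run([replay.states, replay.actions])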