def __init__()

in hanabi-learning-environment/agents/rainbow/replay_memory.py


  def __init__(self,
               num_actions,
               observation_size,
               stack_size,
               use_staging=True,
               replay_capacity=1000000,
               batch_size=32,
               update_horizon=1,
               gamma=1.0,
               wrapped_memory=None):
    """Initializes a graph wrapper for the python replay memory.

    Args:
      num_actions: int, number of possible actions.
      observation_size: int, size of an input frame.
      stack_size: int, number of frames to use in state stack.
      use_staging: bool, when True, uses a staging area to prefetch the next
        sampling batch.
      replay_capacity: int, number of transitions to keep in memory.
      batch_size: int, number of transitions per sampled batch.
      update_horizon: int, length of update ('n' in n-step update).
      gamma: float, the discount factor.
      wrapped_memory: The 'inner' memory data structure. Defaults to None, which
        creates the standard DQN replay memory.

    Raises:
      ValueError: If update_horizon is not positive.
      ValueError: If discount factor is not in [0, 1].
      ValueError: If replay_capacity is not larger than update_horizon.
    """
    if replay_capacity < update_horizon + 1:
      raise ValueError('Update horizon (%i) should be smaller than '
                       'replay capacity (%i).'
                       % (update_horizon, replay_capacity))
    if update_horizon < 1:
      raise ValueError('Update horizon must be positive.')
    if not 0.0 <= gamma <= 1.0:
      raise ValueError('Discount factor (gamma) must be in [0, 1].')

    # Allow subclasses to create self.memory.
    if wrapped_memory is not None:
      self.memory = wrapped_memory
    else:
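      # The default is OutOfGraphReplayMemory, a pure-python store that keeps
      # transitions in numpy arrays outside the TF graph.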
      self.memory = OutOfGraphReplayMemory(
          num_actions, observation_size, stack_size,
          replay_capacity, batch_size, update_horizon, gamma)

    with tf.name_scope('replay'):
      with tf.name_scope('add_placeholders'):
        self.add_obs_ph = tf.placeholder(
            tf.uint8, [observation_size], name='add_obs_ph')
        self.add_action_ph = tf.placeholder(tf.int32, [], name='add_action_ph')
        self.add_reward_ph = tf.placeholder(
            tf.float32, [], name='add_reward_ph')
        self.add_terminal_ph = tf.placeholder(
            tf.uint8, [], name='add_terminal_ph')
        self.add_legal_actions_ph = tf.placeholder(
            tf.float32, [num_actions], name='add_legal_actions_ph')

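      # The placeholder order below must match the argument order expected by
      # self.memory.add, since tf.py_func forwards its inputs positionally.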
      add_transition_ph = [
          self.add_obs_ph, self.add_action_ph, self.add_reward_ph,
          self.add_terminal_ph, self.add_legal_actions_ph
      ]

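      # tf.py_func runs arbitrary Python on the host, so pin these ops to the
      # CPU.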
      with tf.device('/cpu:*'):
        self.add_transition_op = tf.py_func(
            self.memory.add, add_transition_ph, [], name='replay_add_py_func')

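        # The sampled batch comes back as a tuple of (states, actions,
        # rewards, next_states, terminals, indices, next_legal_actions);
        # the dtype list below must match that order.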
        self.transition = tf.py_func(
            self.memory.sample_transition_batch, [],
            [tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8, tf.int32,
             tf.float32],
            name='replay_sample_py_func')

        if use_staging:
          # To hide the py_func latency, use a staging area to prefetch the
          # next batch of transitions.
          (states, actions, rewards, next_states,
           terminals, indices, next_legal_actions) = self.transition
          # StagingArea requires all the shapes to be defined.
          states.set_shape([batch_size, observation_size, stack_size])
          actions.set_shape([batch_size])
          rewards.set_shape([batch_size])
          next_states.set_shape(
              [batch_size, observation_size, stack_size])
          terminals.set_shape([batch_size])
          indices.set_shape([batch_size])
          next_legal_actions.set_shape([batch_size, num_actions])

          # Create the staging area on the CPU.
          prefetch_area = tf.contrib.staging.StagingArea(
              [tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8, tf.int32,
               tf.float32])

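          # put() enqueues one sampled batch; the caller must run
          # prefetch_batch regularly so that the get() below always has an
          # element to dequeue.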
          self.prefetch_batch = prefetch_area.put(
              (states, actions, rewards, next_states, terminals, indices,
               next_legal_actions))
        else:
          self.prefetch_batch = tf.no_op()

      if use_staging:
        # Get the sampled transition batch on the GPU; this performs the copy
        # from CPU to GPU.
        self.transition = prefetch_area.get()

      (self.states, self.actions, self.rewards, self.next_states,
       self.terminals, self.indices, self.next_legal_actions) = self.transition

      # Since these are py_func tensors, no shape information is available.
      # Set the shape only for the tensors that need it.
      self.states.set_shape([None, observation_size, stack_size])
      self.next_states.set_shape([None, observation_size, stack_size])
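
A minimal usage sketch (hypothetical: it assumes this __init__ belongs to a
class named WrappedReplayMemory and uses a TF1-style session; the sizes are
illustrative, not the real Hanabi observation size):

  import numpy as np
  import tensorflow as tf

  OBS_SIZE, NUM_ACTIONS = 128, 10
  replay = WrappedReplayMemory(
      num_actions=NUM_ACTIONS, observation_size=OBS_SIZE, stack_size=1,
      use_staging=False)

  with tf.Session() as sess:
    # Store one transition by feeding the add placeholders.
    sess.run(replay.add_transition_op,
             feed_dict={
                 replay.add_obs_ph: np.zeros(OBS_SIZE, dtype=np.uint8),
                 replay.add_action_ph: 0,
                 replay.add_reward_ph: 0.0,
                 replay.add_terminal_ph: 0,
                 replay.add_legal_actions_ph:
                     np.ones(NUM_ACTIONS, dtype=np.float32),
             })
    # Once enough transitions have been stored, sample a batch:
    # states, actions = sess.run([replay.states, replay.actions])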