def run()

in tf_agents/environments/trajectory_replay.py [0:0]


  def run(self, trajectory, policy_state=None):
    """Apply the policy to trajectory steps and store actions/info.

    If `self.time_major == True`, the tensors in `trajectory` are assumed to
    have shape `[time, batch, ...]`.  Otherwise they are assumed to
    have shape `[batch, time, ...]`.

    Args:
      trajectory: The `Trajectory` to run against.
        If the replay class was created with `time_major=True`, then
        the tensors in trajectory must be shaped `[time, batch, ...]`.
        Otherwise they must be shaped `[batch, time, ...]`.
      policy_state: (optional) A nest Tensor with initial step policy state.
        When omitted, `self._policy.get_initial_state(batch_size)` is used.

    Returns:
      output_actions: A nest of the actions that the policy took.
        If the replay class was created with `time_major=True`, then
        the tensors here will be shaped `[time, batch, ...]`.  Otherwise
        they'll be shaped `[batch, time, ...]`.
      output_policy_info: A nest of the policy info that the policy emitted.
        If the replay class was created with `time_major=True`, then
        the tensors here will be shaped `[time, batch, ...]`.  Otherwise
        they'll be shaped `[batch, time, ...]`.
      policy_state: A nest Tensor with final step policy state.

    Raises:
      TypeError: If `policy_state` structure doesn't match
        `self.policy.policy_state_spec`, or `trajectory` structure doesn't
        match `self.policy.trajectory_spec`.
      ValueError: If `policy_state` doesn't match
        `self.policy.policy_state_spec`, or `trajectory` structure doesn't
        match `self.policy.trajectory_spec`.
      ValueError: If `trajectory` lacks two outer dims.
    """
    trajectory_spec = self._policy.trajectory_spec
    outer_dims = nest_utils.get_outer_shape(trajectory, trajectory_spec)

    # The static length of `outer_dims` is the number of outer dimensions;
    # exactly two (time and batch, in either order) are required.
    num_outer_dims = tf.compat.dimension_value(outer_dims.shape[0])
    if num_outer_dims != 2:
      raise ValueError(
          "Expected two outer dimensions, but saw '{}' dimensions.\n"
          "Trajectory:\n{}.\nTrajectory spec from policy:\n{}.".format(
              num_outer_dims, trajectory, trajectory_spec))
    # `sequence_length` / `batch_size` are entries of `outer_dims`, while
    # `static_batch_size` comes from the static shape of `discount` and is
    # used only to build the element shape of the output TensorArrays.
    if self._time_major:
      sequence_length = outer_dims[0]
      batch_size = outer_dims[1]
      static_batch_size = tf.compat.dimension_value(
          trajectory.discount.shape[1])
    else:
      batch_size = outer_dims[0]
      sequence_length = outer_dims[1]
      static_batch_size = tf.compat.dimension_value(
          trajectory.discount.shape[0])

    if policy_state is None:
      policy_state = self._policy.get_initial_state(batch_size)
    else:
      nest_utils.assert_same_structure(policy_state,
                                       self._policy.policy_state_spec)

    if not self._time_major:
      # Make trajectory time-major.
      trajectory = tf.nest.map_structure(common.transpose_batch_time,
                                         trajectory)

    # From here on `trajectory` is time-major; unstack every tensor into a
    # TensorArray so that individual time steps can be read inside the loop.
    trajectory_tas = tf.nest.map_structure(
        lambda t: tf.TensorArray(t.dtype, size=sequence_length).unstack(t),
        trajectory)

    def create_output_ta(spec):
      """Builds an output TensorArray whose elements are `[batch] + spec.shape`."""
      return tf.TensorArray(
          spec.dtype, size=sequence_length,
          element_shape=(tf.TensorShape([static_batch_size])
                         .concatenate(spec.shape)))

    output_action_tas = tf.nest.map_structure(create_output_ta,
                                              trajectory_spec.action)
    output_policy_info_tas = tf.nest.map_structure(create_output_ta,
                                                   trajectory_spec.policy_info)

    # Initial `TimeStep`: step type and observation come from row 0 of the
    # trajectory; there is no preceding row to take reward/discount from, so
    # reward is filled with zeros and discount with ones of the row-0 shape.
    read0 = lambda ta: ta.read(0)
    zeros_like0 = lambda t: tf.zeros_like(t[0])
    ones_like0 = lambda t: tf.ones_like(t[0])
    time_step = ts.TimeStep(
        step_type=read0(trajectory_tas.step_type),
        reward=tf.nest.map_structure(zeros_like0, trajectory.reward),
        discount=ones_like0(trajectory.discount),
        observation=tf.nest.map_structure(read0, trajectory_tas.observation))

    def process_step(time, time_step, policy_state,
                     output_action_tas, output_policy_info_tas):
      """Take an action on the given step, and update output TensorArrays.

      Args:
        time: Loop counter.  The policy outputs for `time_step` are written
          to index `time - 1` of the output TensorArrays.
        time_step: The `TimeStep` passed to the policy.
        policy_state: Policy state tensor or nested structure of tensors.
        output_action_tas: Nest of `tf.TensorArray` containing new actions.
        output_policy_info_tas: Nest of `tf.TensorArray` containing new
          policy info.

      Returns:
        policy_state: The next policy state (`action_step.state`).
        next_output_action_tas: Updated `output_action_tas`.
        next_output_policy_info_tas: Updated `output_policy_info_tas`.
      """
      action_step = self._policy.action(time_step, policy_state)
      write_ta = lambda ta, t: ta.write(time - 1, t)
      next_output_action_tas = tf.nest.map_structure(
          write_ta, output_action_tas, action_step.action)
      next_output_policy_info_tas = tf.nest.map_structure(
          write_ta, output_policy_info_tas, action_step.info)

      return (action_step.state,
              next_output_action_tas,
              next_output_policy_info_tas)

    def loop_body(time, time_step, policy_state,
                  output_action_tas, output_policy_info_tas):
      """Replays one trajectory step through the policy.

      Used as the `body` of the `tf.while_loop` below.

      Args:
        time: Loop counter; row of the trajectory used for the next
          `TimeStep`'s step type and observation.
        time_step: The `TimeStep` to feed to the policy in this iteration.
        policy_state: Policy state tensor or nested structure of tensors.
        output_action_tas: Nest of `tf.TensorArray` holding the actions
          written so far.
        output_policy_info_tas: Nest of `tf.TensorArray` holding the policy
          info written so far.

      Returns:
        loop_vars for next iteration of tf.while_loop.
      """
      policy_state, next_output_action_tas, next_output_policy_info_tas = (
          process_step(time, time_step, policy_state,
                       output_action_tas,
                       output_policy_info_tas))

      # The next `TimeStep` pairs step type and observation from row `time`
      # with the reward and discount stored in row `time - 1`.
      ta_read = lambda ta: ta.read(time)
      ta_read_prev = lambda ta: ta.read(time - 1)
      time_step = ts.TimeStep(
          step_type=ta_read(trajectory_tas.step_type),
          observation=tf.nest.map_structure(ta_read,
                                            trajectory_tas.observation),
          reward=tf.nest.map_structure(ta_read_prev, trajectory_tas.reward),
          discount=ta_read_prev(trajectory_tas.discount))

      return (time + 1, time_step, policy_state,
              next_output_action_tas, next_output_policy_info_tas)

    # `time` starts at 1 because the `TimeStep` for row 0 was built above;
    # each iteration writes output row `time - 1` and prepares row `time`.
    time = tf.constant(1)
    time, time_step, policy_state, output_action_tas, output_policy_info_tas = (
        tf.while_loop(
            cond=lambda time, *_: time < sequence_length,
            body=loop_body,
            loop_vars=[time, time_step, policy_state,
                       output_action_tas, output_policy_info_tas],
            back_prop=False,
            name="trajectory_replay_loop"))

    # Run the last time step: the loop leaves the final `TimeStep` built but
    # not yet passed to the policy, so its outputs are written here.
    last_policy_state, output_action_tas, output_policy_info_tas = (
        process_step(time, time_step, policy_state,
                     output_action_tas, output_policy_info_tas))

    def stack_ta(ta):
      """Stacks a TensorArray, restoring batch-major layout if required."""
      t = ta.stack()
      if not self._time_major:
        t = common.transpose_batch_time(t)
      return t

    stacked_output_actions = tf.nest.map_structure(stack_ta, output_action_tas)
    stacked_output_policy_info = tf.nest.map_structure(stack_ta,
                                                       output_policy_info_tas)

    return (stacked_output_actions,
            stacked_output_policy_info,
            last_policy_state)