in agents/phyre_simulator.py [0:0]
def _roll_fwd_obs(self, task_indices, actions, roll_fwd_ratio, nframes):
    """
    Use the simulator to roll forward the simulation (for that action) and
    return the output. To get an upper bound on how fwd might work.

    Reads self.simulator, self.stride, self.stop_after_solved and
    self.prepend_empty_frames.

    Args:
        task_indices: This is the *integer* indexes into the IDs that the
            simulator is initialized with. So if the simulator was init
            with 3 tasks ['0001:224', '...' , '...'], then 0 will run the
            0th task -- 0001:224.
        actions: The actual actions you want to take in each task; indexed
            in lockstep with task_indices. Must be a cpu numpy matrix.
        roll_fwd_ratio: 0-1 for where to pick the start point.
            -1 => pick randomly such that nframes is satisfied. In that
            case nframes can not be -1 as well.
        nframes (T): -1 => all the frames following that point
    Returns:
        new_observations: [TxHxW] -- the rolled out video for each batch
            elt (So total B elements in the list). We don't make a np.array
            of it since each might have diff number of frames, dep on how
            much it rolled out at test/train time etc.
        new_obj_observations: The featurized-object clip for each batch
            elt, sliced/padded frame-for-frame like the matching video
            (presumably [T x num_objects x feature_dim] -- confirm against
            the simulator's featurized_objects.features).
    Raises:
        AssertionError: if roll_fwd_ratio is outside [0, 1) and not -1; if
            roll_fwd_ratio == -1 together with nframes == -1; if
            prepend_empty_frames is negative and the rollout is not longer
            than the number of frames to drop; or if a clip comes back
            shorter than requested while stop_after_solved is False.
    """
    assert 0.0 <= roll_fwd_ratio < 1.0 or roll_fwd_ratio == -1, 'Limits'

    # The simulation arguments do not depend on the task, so build them once.
    sim_kwargs = dict(
        need_images=True,
        need_featurized_objects=True,
        stride=self.stride,
        perturb_step=-1,
        stop_after_solved=self.stop_after_solved,
    )
    # If the simulation has to start at the beginning and for fixed set
    # of frames, then might as well specify that and make it run faster.
    # Will specially help the test time performance.
    if roll_fwd_ratio == 0 and nframes != -1:
        frames_to_ask = nframes
        if self.prepend_empty_frames < 0:
            # This many frames will be dropped, so ask for those many extra
            frames_to_ask += -self.prepend_empty_frames
        sim_kwargs['nframes'] = frames_to_ask * self.stride

    new_observations = []
    new_obj_observations = []
    for i, task_idx in enumerate(task_indices):
        simulation = self.simulator.simulate_action(
            task_idx, actions[i], **sim_kwargs)
        images = simulation.images
        objs = None
        if simulation.featurized_objects is not None:
            objs = simulation.featurized_objects.features
        if images is None:
            # This means the action was invalid. This should only happen at
            # test time when we don't filter for the valid actions only.
            # For now, returning the original obs as a single frame clip
            images = np.expand_dims(
                self.simulator.initial_scenes[task_idx], 0)
            objs = self.simulator.initial_featurized_objects[
                task_idx].features
            if 'nframes' in sim_kwargs:
                # Repeat the single frame so the clip is as long as the
                # simulator would have been asked to produce.
                n_repeat = sim_kwargs['nframes'] // self.stride
                images = np.tile(images, (n_repeat, 1, 1))
                objs = np.tile(objs, (n_repeat, 1, 1))

        if self.prepend_empty_frames > 0:
            empty_frames = np.zeros(
                (self.prepend_empty_frames, ) + images.shape[1:],
                dtype=images.dtype)
            images = np.concatenate([empty_frames, images], axis=0)
            # Use the features' own dtype: padding with the image dtype
            # can make np.concatenate promote the features to a wider type.
            empty_obj_frames = np.zeros(
                (self.prepend_empty_frames, ) + objs.shape[1:],
                dtype=objs.dtype)
            objs = np.concatenate([empty_obj_frames, objs], axis=0)
        elif self.prepend_empty_frames < 0:
            # Drop these many frames from the beginning
            assert images.shape[0] > -self.prepend_empty_frames
            assert objs.shape[0] > -self.prepend_empty_frames
            images = images[-self.prepend_empty_frames:]
            objs = objs[-self.prepend_empty_frames:]

        if roll_fwd_ratio == -1:
            assert nframes != -1, 'Cant pick start point randomly...'
            # Pick a random start such that nframes still fit after it;
            # clamp to 0 when the rollout is shorter than nframes.
            this_roll_fwd_ratio = max(
                np.random.random() * (1 - (nframes / images.shape[0])), 0)
        else:
            this_roll_fwd_ratio = roll_fwd_ratio
        split_pt = int(images.shape[0] * this_roll_fwd_ratio)
        if nframes == -1:  # select all following images
            this_nframes = images.shape[0] - split_pt
        else:
            this_nframes = nframes
        clip = images[split_pt:split_pt + this_nframes, ...]
        obj_clip = objs[split_pt:split_pt + this_nframes, ...]

        # Pad with the last frame repeated if nframes too less
        if this_nframes != clip.shape[0]:
            assert self.stop_after_solved, (
                f'If stop_after_solved is False, then it should always '
                f'return enough frames! Returned clip of shape '
                f'{clip.shape} while expected {this_nframes}')
            logging.debug('Have to pad with %d frames to meet %d nframes',
                          this_nframes - clip.shape[0], this_nframes)
            clip = np.concatenate([
                clip,
                np.tile(clip[-1:, ...],
                        [this_nframes - clip.shape[0], 1, 1])
            ], 0)
            obj_clip = np.concatenate([
                obj_clip,
                np.tile(obj_clip[-1:, ...],
                        [this_nframes - obj_clip.shape[0], 1, 1])
            ], 0)

        new_observations.append(clip)
        new_obj_observations.append(obj_clip)
    return new_observations, new_obj_observations