in MTRF/algorithms/softlearning/samplers/utils.py [0:0]
def rollout(env,
            policy,
            path_length,
            sampler_class=simple_sampler.SimpleSampler,
            algorithm=None,
            extra_fields=None,
            sampler_kwargs=None,
            callback=None,
            render_kwargs=None,
            break_on_terminal=True):
    """Sample a single path of at most ``path_length`` steps from ``env``.

    A fresh ``SimpleReplayPool`` sized to ``path_length`` is created, a
    sampler of type ``sampler_class`` is initialized with ``env``, ``policy``
    and that pool, and ``sampler.sample()`` is called once per step.

    Args:
        env: Environment to roll out in. Also used for ``env.render`` when
            rendering is enabled.
        policy: Policy handed to the sampler; ``policy.reset()`` is called
            whenever a sampled step is terminal.
        path_length: Maximum number of steps. Also used as the pool's
            ``max_size`` and the sampler's ``max_path_length``.
        sampler_class: Sampler type, constructed with
            ``max_path_length=path_length, min_pool_size=None,
            batch_size=None`` plus any ``sampler_kwargs``.
        algorithm: Passed to ``sampler.set_algorithm`` if the sampler
            defines that method.
        extra_fields: Forwarded to ``SimpleReplayPool``.
        sampler_kwargs: Optional extra keyword arguments for
            ``sampler_class``.
        callback: Optional callable invoked with each sampled observation.
        render_kwargs: Optional dict whose ``'mode'`` entry selects
            rendering: ``'rgb_array'`` collects one frame per step into
            ``path['images']``, ``'human'`` calls ``env.render`` each step.
            The dict is merged over ``DEFAULT_PIXEL_RENDER_KWARGS`` or
            ``DEFAULT_HUMAN_RENDER_KWARGS`` respectively (caller values
            win). Any other or missing mode disables rendering.
        break_on_terminal: If True, stop after the first terminal step.

    Returns:
        dict: ``pool.batch_by_indices`` over every stored sample, plus
        ``'infos'`` (a ``defaultdict(list)`` mapping each info key to its
        per-step values) and, in ``'rgb_array'`` mode when at least one
        frame was captured, ``'images'`` (frames stacked along a new
        leading axis).
    """
    from collections.abc import Mapping

    pool = replay_pools.SimpleReplayPool(
        env,
        extra_fields=extra_fields,
        max_size=path_length)
    sampler = sampler_class(
        max_path_length=path_length,
        min_pool_size=None,
        batch_size=None,
        **(sampler_kwargs or {}))
    if hasattr(sampler, 'set_algorithm'):
        sampler.set_algorithm(algorithm)
    sampler.initialize(env, policy, pool)

    # Resolve rendering: caller-provided kwargs override the mode defaults.
    render_mode = (render_kwargs or {}).get('mode', None)
    if render_mode == 'rgb_array':
        render_kwargs = {**DEFAULT_PIXEL_RENDER_KWARGS, **render_kwargs}
    elif render_mode == 'human':
        render_kwargs = {**DEFAULT_HUMAN_RENDER_KWARGS, **render_kwargs}
    else:
        render_kwargs = None

    images = []
    infos = defaultdict(list)
    # Count steps explicitly so the pool-size check below is also correct
    # when the loop never runs (path_length == 0).
    num_steps = 0
    for _ in range(path_length):
        observation, _, terminal, info = sampler.sample()
        num_steps += 1
        for key, value in info.items():
            infos[key].append(value)
        if callback is not None:
            callback(observation)

        if render_mode == 'rgb_array':
            pixels = (observation.get('pixels')
                      if isinstance(observation, Mapping) else None)
            if pixels is not None and pixels.shape[-1] == 6:
                # Presumably two RGB images stacked on the channel axis
                # (e.g. goal-conditioned observation + goal); place them
                # side by side instead of rendering the env.
                image = np.concatenate(
                    [pixels[:, :, :3], pixels[:, :, 3:]], axis=1)
            else:
                # NOTE: rgb_array rendering was noted by the original
                # author to only work for mujoco-py environments.
                image = env.render(**render_kwargs)
            images.append(image)
        elif render_mode == 'human':
            env.render(**render_kwargs)

        if terminal:
            policy.reset()
            if break_on_terminal:
                break

    # The sampler is expected to have stored exactly one sample per step.
    assert pool._size == num_steps, (pool._size, num_steps)
    path = pool.batch_by_indices(np.arange(pool._size))
    path['infos'] = infos
    if render_mode == 'rgb_array' and images:
        path['images'] = np.stack(images, axis=0)
    return path