def rollout()

in MTRF/algorithms/softlearning/samplers/utils.py [0:0]


def rollout(env,
            policy,
            path_length,
            sampler_class=simple_sampler.SimpleSampler,
            algorithm=None,
            extra_fields=None,
            sampler_kwargs=None,
            callback=None,
            render_kwargs=None,
            break_on_terminal=True):
    """Roll out ``policy`` in ``env`` for at most ``path_length`` steps.

    A fresh ``replay_pools.SimpleReplayPool`` (with ``max_size=path_length``)
    and a fresh sampler are built for this rollout. Each ``sampler.sample()``
    call is expected to add exactly one entry to the pool; this is checked
    by an assertion after the loop.

    Args:
        env: Environment passed to the replay pool and the sampler; its
            ``render`` method is called when rendering is enabled.
        policy: Policy passed to ``sampler.initialize``. ``policy.reset()``
            is called on every terminal step.
        path_length: Maximum number of steps. Also used as the pool's
            ``max_size`` and the sampler's ``max_path_length``.
        sampler_class: Sampler constructor. It is called with
            ``max_path_length=path_length``, ``min_pool_size=None``,
            ``batch_size=None`` plus any ``sampler_kwargs``.
        algorithm: Passed to ``sampler.set_algorithm`` when the sampler
            defines that method.
        extra_fields: Forwarded to ``SimpleReplayPool``.
        sampler_kwargs: Optional extra keyword arguments for
            ``sampler_class``.
        callback: Optional callable invoked with each step's observation.
        render_kwargs: Optional rendering options. Rendering is enabled only
            when ``render_kwargs['mode']`` is ``'rgb_array'`` or ``'human'``.
            The given options are layered over
            ``DEFAULT_PIXEL_RENDER_KWARGS`` or ``DEFAULT_HUMAN_RENDER_KWARGS``
            respectively. Any other (or missing) mode disables rendering.
        break_on_terminal: If True, stop after the first terminal step.

    Returns:
        The dict produced by ``pool.batch_by_indices`` over every stored
        entry, with two additions:

        - ``'infos'``: a ``defaultdict(list)`` mapping each info key to its
          per-step values.
        - ``'images'`` (``'rgb_array'`` mode only): the per-step frames
          stacked along a new leading axis.
    """
    pool = replay_pools.SimpleReplayPool(
        env,
        extra_fields=extra_fields,
        max_size=path_length)

    sampler = sampler_class(
        max_path_length=path_length,
        min_pool_size=None,
        batch_size=None,
        **(sampler_kwargs or {}))

    if hasattr(sampler, 'set_algorithm'):
        sampler.set_algorithm(algorithm)

    sampler.initialize(env, policy, pool)

    render_mode = (render_kwargs or {}).get('mode', None)
    if render_mode == 'rgb_array':
        render_kwargs = {**DEFAULT_PIXEL_RENDER_KWARGS, **render_kwargs}
    elif render_mode == 'human':
        render_kwargs = {**DEFAULT_HUMAN_RENDER_KWARGS, **render_kwargs}
    else:
        # Unknown or missing mode: rendering is disabled entirely.
        render_kwargs = None

    images = []
    infos = defaultdict(list)
    # Count steps explicitly. A leftover loop index cannot represent the
    # zero-step case (path_length == 0).
    num_steps = 0
    for _ in range(path_length):
        observation, reward, terminal, info = sampler.sample()
        num_steps += 1
        for key, value in info.items():
            infos[key].append(value)

        if callback is not None:
            callback(observation)

        if render_kwargs is not None:
            if (render_mode == 'rgb_array'
                    and 'pixels' in observation
                    and observation['pixels'].shape[-1] == 6):
                # A 6-channel 'pixels' observation is split into its first
                # and last 3 channels, and the two halves are placed side by
                # side along axis 1. This presumably shows two stacked camera
                # images (e.g. image + goal); confirm against the env.
                pixels = observation['pixels']
                image = np.concatenate(
                    [pixels[:, :, :3], pixels[:, :, 3:]], axis=1)
            else:
                image = env.render(**render_kwargs)
            images.append(image)

        if terminal:
            policy.reset()
            if break_on_terminal:
                break

    # Internal invariant: the pool holds one entry per sampled step.
    assert pool._size == num_steps, (
        'pool size {} does not match number of sampled steps {}'.format(
            pool._size, num_steps))

    path = pool.batch_by_indices(np.arange(pool._size))
    path['infos'] = infos

    if render_mode == 'rgb_array':
        # np.stack raises on an empty sequence, which happens when no step
        # was taken (path_length == 0). Fall back to an empty array then.
        path['images'] = (
            np.stack(images, axis=0) if images else np.asarray(images))

    return path