in rlkit/launchers/skewfit_experiments.py [0:0]
def generate_vae_dataset(variant):
    """Load or generate a flattened image dataset and split it in two.

    The dataset comes from one of three sources, tried in this order:

    1. ``variant['dataset_path']``, loaded with
       ``load_local_or_remote_file``; ``N`` is then taken from the loaded
       array's first dimension.
    2. A cached ``.npy`` file under ``/tmp`` whose name is built from the
       save prefix, ``N``, the camera function name, ``imsize``, the
       random/oracle split and ``tag`` (only when ``use_cached`` is true).
    3. Fresh collection from an environment, which is then saved to that
       same ``/tmp`` file.

    Recognised ``variant`` keys (defaults in parentheses):
        env_class (None), env_kwargs (None), env_id (None): how to build
            the environment; ``env_id`` takes precedence and is created
            through ``gym.make`` after registering the multiworld envs.
        vae_dataset_specific_env_kwargs (None): overrides applied on top
            of ``env_kwargs`` when instantiating ``env_class``.
        N (10000): number of images to collect.
        test_p (0.9): fraction of rows returned as the FIRST (train)
            split, despite the key's name.
        use_cached (True), save_file_prefix (None), tag (''): cache
            control and cache file naming.
        imsize (84), num_channels (3), init_camera (None): image
            geometry / camera used when wrapping the env in ``ImageEnv``.
        show (False): display each collected image with OpenCV.
        dataset_path (None): pre-built dataset to load instead.
        oracle_dataset_using_set_to_goal (False), random_rollout_data
            (False), random_and_oracle_policy_data (False),
            random_and_oracle_policy_data_split (0), policy_file (None),
            n_random_steps (100): select and tune the collection mode.
        non_presampled_goal_img_is_garbage (None): forwarded to the
            ``ImageEnv``.

    Returns:
        ``(train_dataset, test_dataset, info)`` where the first
        ``int(N * test_p)`` rows form ``train_dataset`` and the remainder
        ``test_dataset``. ``info['env']`` holds the environment, but only
        when data was freshly collected; otherwise ``info`` is empty.
    """
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')

    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        # Cache prefix falls back from explicit prefix -> env id -> class name.
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        # NOTE(review): the cache name does not encode the collection mode
        # (oracle_dataset_using_set_to_goal / random_rollout_data), so with
        # use_cached=True a file produced by a different mode can be reused.
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            # Simulator / torch helpers are only needed when collecting, so
            # import them here rather than on the load-only paths.
            from multiworld.core.image_env import ImageEnv, unormalize_image
            import rlkit.torch.pytorch_util as ptu

            now = time.time()

            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                # Dataset-specific kwargs win over the generic env kwargs.
                # Merge into a fresh dict so the caller's variant is not
                # modified as a side effect.
                merged_env_kwargs = dict(env_kwargs)
                if vae_dataset_specific_env_kwargs is not None:
                    merged_env_kwargs.update(vae_dataset_specific_env_kwargs)
                env = env_class(**merged_env_kwargs)

            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=non_presampled_goal_img_is_garbage,
                )
            else:
                # An env that is already image-based dictates the image size.
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env

            if random_and_oracle_policy_data:
                policy_data = load_local_or_remote_file(policy_file)
                policy = policy_data['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                # Takes precedence over a loaded policy if both flags are set.
                policy = OUStrategy(env.action_space)

            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)
            # Samples with index below this threshold use random actions;
            # the rest use the loaded policy (mixed mode only).
            num_random_steps = int(N * random_and_oracle_policy_data_split)

            for i in range(N):
                if random_and_oracle_policy_data:
                    if i < num_random_steps:
                        # Keep the reset observation so that `obs` is
                        # defined even when n_random_steps == 0.
                        obs = env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif oracle_dataset_using_set_to_goal:
                    print(i)
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                elif random_rollout_data:
                    # Every n_random_steps samples, jump to a newly sampled
                    # goal state and reset the exploration noise.
                    if i % n_random_steps == 0:
                        g = dict(
                            state_desired_goal=env.sample_goal_for_rollout())
                        env.set_to_goal(g)
                        policy.reset()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    obs = env.step(u)[0]
                else:
                    # Default: fresh episode of purely random actions; keep
                    # the reset observation in case n_random_steps == 0.
                    obs = env.reset()
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]

                img = obs['image_observation']
                # presumably maps the normalized image back to uint8 range
                # -- confirm against multiworld's unormalize_image.
                dataset[i, :] = unormalize_image(img)
                if show:
                    # Use the configured channel count rather than assuming 3
                    # so the preview matches the dataset layout.
                    img = img.reshape(
                        num_channels, imsize, imsize).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)

    # `test_p` is the fraction of rows that go to the first (train) split.
    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info