in scripts/run_goal_conditioned_policy.py [0:0]
def simulate_policy(args):
    """Roll out a snapshotted goal-conditioned policy in its environment.

    Loads the policy and env from the snapshot file, optionally moves the
    policy to GPU and switches the env's mode/rendering, then collects
    goal-conditioned rollouts in an infinite loop, dumping per-rollout
    diagnostics to the logger. Runs until interrupted.

    Args:
        args: parsed CLI namespace; this function reads ``args.file``
            (snapshot path), ``args.gpu``, ``args.mode``,
            ``args.enable_render``, ``args.H`` (max path length), and
            ``args.hide`` (suppress rendering during rollouts).
    """
    if args.gpu:
        # Configure the global device before loading so tensors can be
        # deserialized onto the GPU.
        ptu.set_gpu_mode(True)
    # map_location='cpu' lets CPU-only machines load GPU-trained snapshots
    # (the bare torch.load would fail trying to restore CUDA tensors).
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # snapshot files from trusted sources.
    data = torch.load(args.file, map_location=None if args.gpu else 'cpu')
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy and environment loaded")
    if args.gpu:
        policy.to(ptu.device)
    if isinstance(env, VAEWrappedEnv) and hasattr(env, 'mode'):
        env.mode(args.mode)
    # NOTE(review): with `or`, passing --enable_render for an env that has
    # no enable_render() raises AttributeError; confirm whether this should
    # be `and`, or whether all supported envs expose the method.
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    paths = []
    while True:  # collect rollouts until the user interrupts
        paths.append(multitask_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=not args.hide,
            observation_key='observation',
            desired_goal_key='desired_goal',
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()