in MTRF/r3l/r3l/r3l_agents/softlearning/evaluation_scripts/phased_evals_reposition.py [0:0]
import glob
import os
from pathlib import Path

import numpy as np

# NOTE: the import paths below are assumptions based on the repo layout;
# `load_policy_from_checkpoint` is a project-local helper whose real module
# is not shown in this file, so its import path here is hypothetical.
from softlearning.environments.adapters.gym_adapter import GymAdapter
from softlearning.utils.video import save_video
from r3l.r3l_agents.softlearning.evaluation_scripts.utils import load_policy_from_checkpoint


def do_evals(seed_dir):
    print(seed_dir, "\n")
    path = Path(seed_dir)
    # Gather checkpoint_* subdirectories, sorted newest-first by iteration number.
    checkpoint_dirs = [
        d for d in glob.glob(str(path / "*"))
        if "checkpoint" in d and os.path.isdir(d)
    ]
    checkpoint_dirs.sort(key=lambda s: int(s.split("_")[-1]), reverse=True)

    N_EVAL_EPISODES = 1  # evaluation episodes per checkpoint
    T = 50  # steps per episode
    EVAL_EVERY_N = 2  # evaluate every other checkpoint
    # Fixed initial object xy position; initial yaw sampled uniformly in [-pi, pi].
    env = GymAdapter(
        "SawyerDhandInHandValve3",
        "RepositionFixed-v0",
        init_xyz_range_params={
            # "type": "UniformRange",
            # "values": [np.array([0.72 - 0.25, 0.15 - 0.25, 0.75]), np.array([0.72 + 0.25, 0.15 + 0.25, 0.75])],
            "type": "DiscreteRange",
            "values": [np.array([0.72 - 0.25, 0.15 - 0.25, 0.75])],
        },
        init_euler_range_params={
            "type": "UniformRange",
            "values": [np.array([0, 0, -np.pi]), np.array([0, 0, np.pi])],
        },
        reset_every_n_episodes=1,
        reset_robot=True,
        readjust_to_object_in_reset=True,
    )
    env.reset()

    success_rates = []
    ckpt_numbers = []
    obs_dicts_per_policy = []
    rew_dicts_per_policy = []
    returns_per_policy = []
    for ckpt_dir in checkpoint_dirs[::EVAL_EVERY_N]:
        ckpt_number = ckpt_dir.split("_")[-1]
        print("EVALUATING CHECKPOINT:", ckpt_number)
        policy = load_policy_from_checkpoint(ckpt_dir, env)

        successes = []
        obs_dicts = []
        rew_dicts = []
        returns = []
        frames = []
        for ep in range(N_EVAL_EPISODES):
            env.reset()
            # Re-sample the initial state until the object starts at least
            # 10 cm (xy) from the target, so success is not trivially satisfied.
            while env.get_obs_dict()["object_to_target_xy_distance"] < 0.1:
                env.reset()
            ret = 0
            for t in range(T):
                _, rew, done, info = env.step(policy(env.get_obs_dict()))
                ret += rew
                frames.append(env.render(mode="rgb_array", width=480, height=480))
            # Score the episode by its final state.
            obs_dict = env.get_obs_dict()
            rew_dict = env.get_reward_dict(None, obs_dict)
            success = obs_dict["object_to_target_xy_distance"] < 0.1
            successes.append(success)
            returns.append(ret)
            obs_dicts.append(obs_dict)
            rew_dicts.append(rew_dict)
        video_name = f"./videos/repos/phased_checkpoint_{ckpt_number}.mp4"
        os.makedirs(os.path.dirname(video_name), exist_ok=True)
        save_video(video_name, np.asarray(frames), fps=40)

        ckpt_numbers.append(ckpt_number)
        success_rate = np.array(successes).astype(int).mean()
        print("success rate =", success_rate)
        success_rates.append(success_rate)
        obs_dicts_per_policy.append(obs_dicts)
        rew_dicts_per_policy.append(rew_dicts)
        returns_per_policy.append(np.mean(returns))
        # NOTE: this break limits evaluation to the most recent checkpoint only;
        # remove it to sweep every EVAL_EVERY_N-th checkpoint.
        break
    return {
        "iters": ckpt_numbers,
        "success": success_rates,
        "obs": obs_dicts_per_policy,
        "rew": rew_dicts_per_policy,
        "returns": returns_per_policy,
    }
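

# A minimal driver sketch, assuming seed directories live under a results root;
# the "./results/reposition/seed_*" glob and the pickle output filename are
# hypothetical stand-ins, not part of the original script.
if __name__ == "__main__":
    import pickle

    for seed_dir in sorted(glob.glob("./results/reposition/seed_*")):
        results = do_evals(seed_dir)
        out_path = os.path.join(seed_dir, "phased_eval_results.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(results, f)
        print("wrote eval results to", out_path)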