in scripts/eval_jat.py [0:0]
def eval_rl(model, processor, task, eval_args):
    # Create the environment
    env_kwargs = {}
    if task.startswith("atari"):
        env_kwargs["clip_reward"] = False
    if eval_args.save_video:
        env_kwargs["render_mode"] = "rgb_array"
    env = make(task, **env_kwargs)

    context_window = 32 if task.startswith("atari") else 256

    scores = []
    frames = []
    for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
        observation, _ = env.reset()
        reward = None
        rewards = []
        done = False
        model.reset_rl()  # remove KV Cache
        while not done:
            action = model.get_next_action(
                processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window
            )
            observation, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
# Handle "fake done" for atari
if done and task.startswith("atari"):
if "episode" not in info:
observation, info = env.reset()
done = False
# Update the return
rewards.append(reward)
# Render the environment
if eval_args.save_video:
frames.append(np.array(env.render(), dtype=np.uint8))
scores.append(sum(rewards))
    env.close()

    raw_mean, raw_std = np.mean(scores), np.std(scores)
    # Normalize the scores against the expert baseline
    norm_scores = normalize(scores, task, "expert")
    if norm_scores is not None:  # Can be None when the random baseline already beats the expert
        norm_mean, norm_std = np.mean(norm_scores), np.std(norm_scores)
        tqdm.write(
            f"Task {task} Raw score: {raw_mean:.2f} ± {raw_std:.2f}\t"
            f"Normalized score: {norm_mean:.2f} ± {norm_std:.2f}"
        )
    else:
        tqdm.write(f"Task {task} Raw score: {raw_mean:.2f} ± {raw_std:.2f}")
    # Resize frames to one third of their size to limit memory usage (the video is downscaled
    # anyway when it is aggregated with the other tasks)
    if eval_args.save_video:
        import cv2

        frames = [cv2.resize(frame, (0, 0), fx=1 / 3, fy=1 / 3) for frame in frames]

    return scores, frames, env.metadata["render_fps"]
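
# Illustrative usage sketch (not part of eval_jat.py): the checkpoint name, the use of
# `AutoProcessor`, and the `EvalArguments` container below are assumptions for demonstration
# only. The function itself relies on module-level imports defined elsewhere in the script
# (numpy as np, tqdm, and the JAT helpers `make` and `normalize`).
#
#     from dataclasses import dataclass
#     from transformers import AutoModelForCausalLM, AutoProcessor
#
#     @dataclass
#     class EvalArguments:
#         num_episodes: int = 2
#         save_video: bool = False
#
#     model = AutoModelForCausalLM.from_pretrained("jat-project/jat", trust_remote_code=True)
#     processor = AutoProcessor.from_pretrained("jat-project/jat", trust_remote_code=True)
#     scores, frames, fps = eval_rl(model, processor, "atari-pong", EvalArguments())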