def eval_rl()

in scripts/eval_jat.py
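Runs a trained JAT model on a single RL task: it builds the environment, rolls out eval_args.num_episodes episodes with the model acting step by step, and returns the per-episode returns, any rendered frames, and the environment's render FPS. The helpers make and normalize, along with np and tqdm, come from the script's module-level imports.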


def eval_rl(model, processor, task, eval_args):
    # Create the environment
    env_kwargs = {}
    if task.startswith("atari"):
        # Evaluate with raw (unclipped) rewards; reward clipping is a training-time aid
        env_kwargs["clip_reward"] = False
    if eval_args.save_video:
        env_kwargs["render_mode"] = "rgb_array"

    env = make(task, **env_kwargs)

    # Atari observations are images, so use a shorter context window to bound memory
    context_window = 32 if task.startswith("atari") else 256

    scores = []
    frames = []
    for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
        observation, _ = env.reset()
        reward = None
        rewards = []
        done = False
        model.reset_rl()  # clear the KV cache so the episode starts from a fresh context
        while not done:
            action = model.get_next_action(
                processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window
            )
            observation, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            # Handle "fake done" for Atari: the env terminates on life loss, but the
            # true episode end is signaled by the "episode" key in `info`, so reset
            # and keep playing until it appears
            if done and task.startswith("atari"):
                if "episode" not in info:
                    observation, info = env.reset()
                    done = False

            # Accumulate the reward; the episode return is summed at the end
            rewards.append(reward)

            # Store the rendered frame for the output video
            if eval_args.save_video:
                frames.append(np.array(env.render(), dtype=np.uint8))

        scores.append(sum(rewards))
    env.close()

    raw_mean, raw_std = np.mean(scores), np.std(scores)

    # Normalize the scores against the expert baseline
    norm_scores = normalize(scores, task, "expert")
    if norm_scores is not None:  # None when the random baseline outperforms the expert
        norm_mean, norm_std = np.mean(norm_scores), np.std(norm_scores)
        tqdm.write(
            f"Task {task} Raw score: {raw_mean:.2f} ± {raw_std:.2f}\t"
            f"Normalized score: {norm_mean:.2f} ± {norm_std:.2f}"
        )
    else:
        tqdm.write(f"Task {task} Raw score: {raw_mean:.2f} ± {raw_std:.2f}")

    # Downscale frames to one third of their size to limit memory usage (the video
    # is reduced anyway when aggregated with the others)
    if eval_args.save_video:
        import cv2  # imported lazily so OpenCV is only required when saving videos

        frames = [cv2.resize(frame, (0, 0), fx=1 / 3, fy=1 / 3) for frame in frames]

    return scores, frames, env.metadata["render_fps"]
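
For orientation, here is a minimal sketch of how eval_rl might be driven. The checkpoint id, the JatModel and AutoProcessor entry points, and the EvalArgs container are assumptions standing in for the script's real argument parsing, not its actual CLI:

from dataclasses import dataclass

from transformers import AutoProcessor

from jat.modeling_jat import JatModel  # assumed import path for the model class


@dataclass
class EvalArgs:  # hypothetical stand-in for the parsed eval arguments
    num_episodes: int = 10
    save_video: bool = False


model = JatModel.from_pretrained("jat-project/jat")  # hypothetical checkpoint id
processor = AutoProcessor.from_pretrained("jat-project/jat", trust_remote_code=True)

scores, frames, fps = eval_rl(model, processor, "atari-pong", EvalArgs())
print(f"atari-pong mean return: {sum(scores) / len(scores):.2f}")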