# eval_sample
#
# Source: evals/elsuite/incontext_rl/eval.py
    def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
        """Run one in-context RL rollout against the gym environment in `sample`.

        The solver is repeatedly shown the trajectory so far and asked for the
        next action, for up to `self.max_steps` environment steps. Whenever an
        episode ends (terminated or truncated) the environment is reset and the
        rollout continues. Per-sample metrics are recorded at the end.

        Args:
            solver: the solver being evaluated.
            sample: dict with keys "env" (a gym.Env), "env_id" (str), and
                "explanations" (str).
            rng: per-sample RNG (unused here; resets use a fixed seed for
                reproducibility).
        """
        # --- Validate the sample shape up front ---
        required_keys = ["env", "env_id", "explanations"]
        has_all_keys = all(key in sample for key in required_keys)
        assert has_all_keys, f"Sample missing required keys: {required_keys}"
        for field, expected_type in (
            ("env", gym.Env),
            ("env_id", str),
            ("explanations", str),
        ):
            assert isinstance(sample[field], expected_type)

        env = sample["env"]

        # Build the task state; the solver sees the trajectory via `messages`
        # and the raw rollout data via `current_state`.
        state = CurrentState(
            action_space=env.action_space,
            observation_space=env.observation_space,
            # TODO might not be available for all envs, check when adding a continuous env
            action_space_n=env.action_space.n,
            # TODO might not be available for all envs, check when adding a continuous env
            observation_space_n=env.observation_space.n,
        )
        ts = TaskState(
            task_description=self._generate_task_description(env, sample),
            messages=[],
            current_state=state,
        )

        # Initial reset; fixed seed keeps rollouts reproducible across runs.
        observation, _ = env.reset(seed=42)
        state.observations.append(observation)

        # Tell the model the starting observation and ask for its first action.
        self._add_reset_message_to_task_state(ts, observation, sample)

        for _step in range(self.max_steps):
            # Recap the trajectory so far before asking for the next action.
            self._add_recap_message_to_task_state(ts, state.actions, state.rewards)

            action = self._try_get_valid_action(solver, ts, env.action_space.n)
            if action is None:
                # Solver never produced a parseable action; stop this sample.
                logger.info("Ending sample since couldn't parse an action.")
                break

            obs_after, reward, terminated, truncated, _ = env.step(action)
            state.actions.append(action)
            state.rewards.append(float(reward))
            state.observations.append(obs_after)

            if terminated or truncated:
                # Report the final step with terminated=True so the model
                # knows the episode ended and what reward it received.
                content = self._format_step_message(
                    action, obs_after, reward, sample, terminated=True
                )
                ts.messages.append(Message(role="user", content=content))

                # Record which step this episode ended on.
                state.episode_end_steps.append(len(state.actions))

                # Fresh episode: reset and show the model the new observation.
                observation, _ = env.reset(seed=42)
                state.observations.append(observation)
                self._add_reset_message_to_task_state(ts, observation, sample)
            else:
                content = self._format_step_message(action, obs_after, reward, sample)
                ts.messages.append(Message(role="user", content=content))

        env.close()

        # --- Aggregate and record per-sample metrics ---
        episode_rewards = self._calculate_episode_rewards(
            state.episode_end_steps, state.rewards
        )
        total_responses = state.total_responses
        invalid_rate = (
            state.invalid_responses / total_responses if total_responses > 0 else 0
        )
        evals.record.record_metrics(
            environment=f"{env.spec.id} {env.spec.kwargs}",
            explanations=self.use_explanations,
            total_return=sum(state.rewards),
            total_steps=len(state.actions),
            actions=state.actions,
            rewards=state.rewards,
            episode_rewards=episode_rewards,
            average_episode_reward=float(np.mean(episode_rewards)),
            average_reward_last_5_episodes=float(np.mean(episode_rewards[-5:])),
            average_reward_last_10_episodes=float(np.mean(episode_rewards[-10:])),
            average_reward_last_20_episodes=float(np.mean(episode_rewards[-20:])),
            average_reward_last_50_episodes=float(np.mean(episode_rewards[-50:])),
            invalid_response_rate=invalid_rate,
            episode_end_steps=state.episode_end_steps,
        )