in evals/elsuite/incontext_rl/eval.py [0:0]
def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
    """Run one in-context RL rollout and record per-sample metrics.

    The solver interacts with the gym environment carried in ``sample`` for
    at most ``self.max_steps`` steps. Whenever an episode terminates or is
    truncated, the environment is reset (with the same fixed seed) and the
    rollout continues; the run ends early if the solver's response cannot
    be parsed into a valid action.

    Args:
        solver: The solver that picks actions from the conversation state.
        sample: Dict carrying ``env`` (a gym.Env), ``env_id`` and
            ``explanations``.
        rng: Per-sample RNG supplied by the eval framework (unused here).
    """
    # Validate the sample payload up front.
    required_keys = ["env", "env_id", "explanations"]
    has_required = all(key in sample for key in required_keys)
    assert has_required, f"Sample missing required keys: {required_keys}"
    assert isinstance(sample["env"], gym.Env)
    assert isinstance(sample["env_id"], str)
    assert isinstance(sample["explanations"], str)

    env = sample["env"]
    state = TaskState(
        task_description=self._generate_task_description(env, sample),
        messages=[],
        current_state=CurrentState(
            action_space=env.action_space,
            observation_space=env.observation_space,
            # NOTE: `.n` assumes discrete spaces — revisit when a
            # continuous env is added (original TODO preserved).
            action_space_n=env.action_space.n,
            observation_space_n=env.observation_space.n,
        ),
    )
    cs = state.current_state

    # Fixed seed so every sample starts every episode from the same state.
    obs, _ = env.reset(seed=42)
    cs.observations.append(obs)
    # Tell the model the starting observation and ask it to pick an action.
    self._add_reset_message_to_task_state(state, obs, sample)

    for _ in range(self.max_steps):
        self._add_recap_message_to_task_state(state, cs.actions, cs.rewards)
        action = self._try_get_valid_action(solver, state, env.action_space.n)
        if action is None:
            logger.info("Ending sample since couldn't parse an action.")
            break

        next_obs, reward, terminated, truncated, _ = env.step(action)
        cs.actions.append(action)
        cs.rewards.append(float(reward))
        cs.observations.append(next_obs)

        if not (terminated or truncated):
            # Episode still running: report the step outcome and loop.
            content = self._format_step_message(action, next_obs, reward, sample)
            state.messages += [Message(role="user", content=content)]
            continue

        # Episode over: tell the model the terminal outcome, record where
        # the episode ended, then reset and prompt for the next episode.
        content = self._format_step_message(
            action, next_obs, reward, sample, terminated=True
        )
        state.messages += [Message(role="user", content=content)]
        cs.episode_end_steps.append(len(cs.actions))
        obs, _ = env.reset(seed=42)
        cs.observations.append(obs)
        self._add_reset_message_to_task_state(state, obs, sample)

    env.close()

    episode_rewards = self._calculate_episode_rewards(
        cs.episode_end_steps, cs.rewards
    )

    def trailing_mean(k: int) -> float:
        # Mean reward over the last k completed episodes.
        return float(np.mean(episode_rewards[-k:]))

    invalid_rate = (
        cs.invalid_responses / cs.total_responses if cs.total_responses > 0 else 0
    )
    evals.record.record_metrics(
        environment=f"{env.spec.id} {env.spec.kwargs}",
        explanations=self.use_explanations,
        total_return=sum(cs.rewards),
        total_steps=len(cs.actions),
        actions=cs.actions,
        rewards=cs.rewards,
        episode_rewards=episode_rewards,
        average_episode_reward=float(np.mean(episode_rewards)),
        average_reward_last_5_episodes=trailing_mean(5),
        average_reward_last_10_episodes=trailing_mean(10),
        average_reward_last_20_episodes=trailing_mean(20),
        average_reward_last_50_episodes=trailing_mean(50),
        invalid_response_rate=invalid_rate,
        episode_end_steps=cs.episode_end_steps,
    )