in evals/elsuite/incontext_rl/scripts/plot_experiments.py [0:0]
def _truncate_to_supported_episodes(rewards_df, min_runs=3):
    """Trim each model's data to the episodes backed by at least ``min_runs`` runs.

    For every model, the last episode number that still has ``min_runs`` or more
    rows is located and all later episodes are dropped. A model that never reaches
    ``min_runs`` rows for any episode is kept in full.

    Args:
        rewards_df (pd.DataFrame): One row per (run, episode) with at least the
            columns 'model' and 'episode'.
        min_runs (int): Minimum number of rows an episode needs to be kept.

    Returns:
        pd.DataFrame: The trimmed rows of all models with a fresh integer index.
        If ``rewards_df`` contains no models it is returned unchanged.
    """
    kept_groups = []
    for _, group in rewards_df.groupby('model'):
        # Count the number of runs available for each episode number
        episode_counts = group.groupby('episode').size()
        supported = episode_counts[episode_counts >= min_runs]
        if not supported.empty:
            # Keep only the data up to the last sufficiently supported episode
            group = group[group['episode'] <= supported.index.max()]
        kept_groups.append(group)

    if not kept_groups:
        # Nothing to concatenate; keep the (empty) frame so its columns survive
        return rewards_df
    # Concatenate once instead of growing a DataFrame inside the loop
    return pd.concat(kept_groups, ignore_index=True)


def plot_rewards(df, environment, reward_type, out_dir, window_size=None):
    """
    Plot episode, cumulative, or running average rewards for all models on one graph
    for a specific environment, save it as a PNG and show it.

    A line plot (mean with 95% confidence interval per model) is drawn when the data
    spans more than one episode; otherwise a scatter plot coloured by model is drawn.
    The figure is written to ``out_dir / reward_type / f'{environment}.png'`` and is
    closed before returning.

    Args:
        df (pd.DataFrame): DataFrame containing the experiment results. Must provide the
            columns 'environment', 'model' and the list-valued column named by `reward_type`.
        environment (str): Name of the environment to plot.
        reward_type (str): Type of reward to plot. Must be one of 'episode_rewards',
            'cumulative_episode_rewards', or 'rolling_average_rewards'.
        out_dir (Path): Path to the directory to save the plots.
        window_size (int): Window size of the rolling average. Within this function it is
            only shown in the plot title when `reward_type` is 'rolling_average_rewards'.

    Raises:
        ValueError: If `reward_type` is not one of the supported reward types.
    """
    valid_reward_types = ['episode_rewards', 'cumulative_episode_rewards', 'rolling_average_rewards']
    if reward_type not in valid_reward_types:
        raise ValueError(f"Invalid reward_type. Expected one of {valid_reward_types}, got {reward_type}")

    # Filter the DataFrame for the specific environment
    filtered_df = df[df['environment'] == environment]

    # Explode the reward lists so that each row is a single episode of a single run
    rewards_df = filtered_df.explode(reward_type).reset_index()
    # 'index' holds the original row label, so (model, index) identifies one run
    rewards_df['episode'] = rewards_df.groupby(['model', 'index']).cumcount() + 1
    rewards_df['reward'] = rewards_df[reward_type]

    # HACK: skip truncation on CliffWalking to make better plots, since some models
    # only have 1 episode there
    if environment != "CliffWalking-v0 {}":
        rewards_df = _truncate_to_supported_episodes(rewards_df)

    fig = plt.figure(figsize=(10, 5))
    try:
        ax = plt.gca()

        # Determine the plot type based on the number of episodes
        num_episodes = rewards_df['episode'].nunique()
        if num_episodes > 1:
            for model in rewards_df['model'].unique():
                model_df = rewards_df[rewards_df['model'] == model]
                custom_style = MODEL_STYLES.get(model, MODEL_STYLES['default'])
                pretty_model_name = PRETTY_MODEL_NAMES.get(model, model)

                # Remember how many lines exist so only this model's new line(s) get
                # annotated; iterating over all axes lines would re-label every
                # previously drawn model on each pass.
                lines_before = len(ax.get_lines())
                sns.lineplot(data=model_df, x='episode', y='reward', estimator='mean', errorbar=('ci', 95),
                             linestyle=custom_style['line_style'], color=custom_style['color'],
                             alpha=custom_style['alpha'], label=pretty_model_name, ax=ax,
                             err_kws={'alpha': 0.035})

                # Label the final value of the freshly drawn line
                for line in ax.get_lines()[lines_before:]:
                    x, y = line.get_data()
                    if len(x) > 0:  # Check if there is data to plot
                        ax.text(x[-1], y[-1], f"{y[-1]:.2f}", color=line.get_color(), fontsize=9)
        else:
            # For a single episode, use scatter plot, differentiating models by color
            scatterplot = sns.scatterplot(data=rewards_df, x='episode', y='reward', hue='model', ax=ax)
            # Label the last point of each drawn collection
            for collection in scatterplot.collections:
                offsets = collection.get_offsets()
                if offsets.size > 0:  # Check if there are points to plot
                    last_point = offsets[-1]
                    ax.text(last_point[0], last_point[1], f"{last_point[1]:.2f}", fontsize=9)

        pretty_env_title = PRETTY_ENV_TITLES.get(environment, environment)
        title = f'{reward_type.replace("_", " ").title()} in {pretty_env_title}'
        if reward_type == 'rolling_average_rewards':
            title += f' (Window Size: {window_size})'
        plt.title(title)
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
        if num_episodes > 1:
            # Identical low/high limits (single episode) would make the axis singular,
            # so only pin the range when it is non-degenerate.
            plt.xlim(1, num_episodes)
        plt.tight_layout()

        plot_dir = out_dir / reward_type
        plot_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(plot_dir / f'{environment}.png')
        plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)