def plot_rewards()

in evals/elsuite/incontext_rl/scripts/plot_experiments.py [0:0]


def plot_rewards(df, environment, reward_type, out_dir, window_size=None):
    """
    Plot episode, cumulative, or rolling-average rewards for every model on one graph
    for a single environment, and save it to ``out_dir / reward_type / f"{environment}.png"``.

    If more than one episode number is present, one line per model is drawn (mean across
    runs with a 95% confidence band) and the final mean value of each line is annotated.
    If only a single episode number is present, a scatter plot coloured by model is drawn.

    Except for the "CliffWalking-v0 {}" environment, each model's data is truncated at the
    last episode number for which at least 3 runs exist, so late episodes backed by fewer
    than 3 runs are dropped.

    Args:
    df (pd.DataFrame): Experiment results. Must contain 'environment' and 'model' columns
        and a column named ``reward_type`` holding one sequence of rewards per run (row).
    environment (str): Name of the environment to plot.
    reward_type (str): Type of reward to plot. Must be one of 'episode_rewards',
        'cumulative_episode_rewards', or 'rolling_average_rewards'.
    out_dir (Path): Directory to save the plots in; a ``reward_type`` subdirectory is created.
    window_size (int, optional): Window size that was used when the rolling averages were
        computed. It is only used to annotate the title for 'rolling_average_rewards' and
        is not used in any computation here. If None, it is omitted from the title.

    Raises:
    ValueError: If ``reward_type`` is invalid or ``df`` has no rows for ``environment``.
    """
    valid_reward_types = ['episode_rewards', 'cumulative_episode_rewards', 'rolling_average_rewards']
    if reward_type not in valid_reward_types:
        raise ValueError(f"Invalid reward_type. Expected one of {valid_reward_types}, got {reward_type}")

    # Filter the DataFrame for the specific environment
    filtered_df = df[df['environment'] == environment]
    if filtered_df.empty:
        # Fail with a clear message instead of an obscure KeyError further down.
        raise ValueError(f"No results found for environment {environment!r}")

    # Explode each run's reward sequence into one row per episode. Resetting to a positional
    # index first gives every run a unique id even if `df` has a duplicated or named index;
    # after explode() the index repeats that id for all episodes of the same run.
    rewards_df = filtered_df.reset_index(drop=True).explode(reward_type)
    rewards_df['_run_id'] = rewards_df.index
    rewards_df = rewards_df.reset_index(drop=True)
    rewards_df['episode'] = rewards_df.groupby('_run_id').cumcount() + 1  # 1-based episode number per run
    # explode() leaves an object-dtype column; cast so aggregation/plotting is numeric.
    rewards_df['reward'] = rewards_df[reward_type].astype(float)

    min_runs = 3
    # Hacky workaround to make better plots since some models only have 1 episode on CliffWalking
    truncate_per_model = environment != "CliffWalking-v0 {}"

    if truncate_per_model:
        kept_groups = []
        for _, group in rewards_df.groupby('model'):
            # Number of runs that reached each episode number for this model.
            runs_per_episode = group.groupby('episode').size()
            well_sampled = runs_per_episode[runs_per_episode >= min_runs]
            if not well_sampled.empty:
                # Keep episodes up to the last one that has at least `min_runs` runs.
                group = group[group['episode'] <= well_sampled.index.max()]
            # Otherwise no episode has `min_runs` runs: keep all of this model's data.
            kept_groups.append(group)
        # Concatenate once rather than growing a DataFrame inside the loop.
        rewards_df = pd.concat(kept_groups, ignore_index=True)

    fig, ax = plt.subplots(figsize=(10, 5))

    # Determine the plot type based on the number of episodes
    num_episodes = rewards_df['episode'].nunique()
    if num_episodes > 1:
        for model in rewards_df['model'].unique():
            model_df = rewards_df[rewards_df['model'] == model]
            custom_style = MODEL_STYLES.get(model, MODEL_STYLES['default'])
            pretty_model_name = PRETTY_MODEL_NAMES.get(model, model)
            sns.lineplot(data=model_df, x='episode', y='reward', estimator='mean', errorbar=('ci', 95),
                         linestyle=custom_style['line_style'], color=custom_style['color'],
                         alpha=custom_style['alpha'], label=pretty_model_name, ax=ax,
                         err_kws={'alpha': 0.035})
        # Annotate the final value of each line once, after all models are drawn.
        # ax.get_lines() returns every line on the shared axes, so doing this inside the
        # loop above would re-annotate earlier models' lines on every iteration.
        for line in ax.get_lines():
            x, y = line.get_data()
            if len(x) > 0:  # Check if there is data to plot
                ax.text(x[-1], y[-1], f"{y[-1]:.2f}", color=line.get_color(), fontsize=9)
    else:
        # For a single episode, use a scatter plot differentiating models by colour; the legend
        # uses the same pretty model names as the line-plot branch.
        scatter_df = rewards_df.assign(
            model=rewards_df['model'].map(lambda m: PRETTY_MODEL_NAMES.get(m, m))
        )
        scatterplot = sns.scatterplot(data=scatter_df, x='episode', y='reward', hue='model', ax=ax)
        # Label the last point of each point collection.
        # NOTE(review): seaborn may draw all hue levels into a single collection, in which
        # case only one point ends up labelled — confirm this is the intended annotation.
        for collection in scatterplot.collections:
            offsets = collection.get_offsets()
            if offsets.size > 0:  # Check if there are points to plot
                last_point = offsets[-1]
                ax.text(last_point[0], last_point[1], f"{last_point[1]:.2f}", fontsize=9)

    pretty_env_title = PRETTY_ENV_TITLES.get(environment, environment)
    title = f'{reward_type.replace("_", " ").title()} in {pretty_env_title}'
    if reward_type == 'rolling_average_rewards' and window_size is not None:
        title += f' (Window Size: {window_size})'
    ax.set_title(title)
    ax.set_xlabel('Episode')
    ax.set_ylabel('Reward')
    ax.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    if num_episodes > 1:
        # Identical limits (1, 1) would be singular; let matplotlib autoscale instead.
        ax.set_xlim(1, num_episodes)
    fig.tight_layout()
    plot_dir = out_dir / reward_type
    plot_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(plot_dir / f'{environment}.png')
    plt.show()
    # Release the figure so repeated calls (e.g. looping over environments) don't keep
    # every figure open.
    plt.close(fig)