def plot_df_list()

in sagemaker/source/visualization/model_visualisation_utils.py [0:0]


def plot_df_list(df_list, metric_name, y_label, min_final_value, x_column="Epoch"):
    '''
    Helper function to plot the performance of a list of jobs on one figure.

    Each selected job is drawn as one line of ``metric_name`` against
    ``x_column``, labelled with its job name. No legend is drawn here; call
    ``legend()`` on the current axes afterwards if one is wanted.

    Parameters:
    -----------
    df_list: [(str, pd.DataFrame)]
        A list of (job_name, dataframe) pairs, where the dataframe holds the
        corresponding job's metric data.

    metric_name: str
        Name of the column plotted on the y axis. Jobs whose dataframe lacks
        this column, or whose column has no rows, are skipped.

    y_label: str
        y_label for the plot

    min_final_value: float
        Only plots jobs whose last ``metric_name`` value is strictly greater
        than this threshold. A NaN last value never passes the check.

    x_column: str, optional
        Name of the column used for the x axis; also used as the x-axis
        label. Defaults to "Epoch".
    '''
    my_dpi = 108
    # 1000 x 800 pixels at the chosen dpi.
    fig = plt.figure(figsize=(1000 / my_dpi, 800 / my_dpi), dpi=my_dpi)
    # Draw on an explicit axes of the new figure instead of pyplot's implicit
    # "current axes" state.
    ax = fig.add_subplot(111)

    linewidth = 3
    font_size = 24

    for job_name, job_df in df_list:
        if metric_name not in job_df.columns:
            continue
        metric = job_df[metric_name]
        # An empty column has no final value; indexing [-1] would raise IndexError.
        if metric.empty:
            continue
        final_value = metric.values[-1]
        if final_value > min_final_value:
            ax.plot(job_df[x_column], metric, label=job_name, linewidth=linewidth)

    ax.set_xlabel(x_column, fontsize=font_size)
    ax.set_ylabel(y_label, fontsize=font_size)
    ax.grid(color="0.9", linestyle='-', linewidth=3)

    fig.tight_layout()