in sagemaker/source/visualization/model_visualisation_utils.py [0:0]
def plot_df_list(df_list, metric_name, y_label, min_final_value):
    '''
    Helper function to plot the performance of a list of jobs.

    A new pyplot figure is created and one line is drawn per job whose
    last recorded value of ``metric_name`` is strictly greater than
    ``min_final_value``. Jobs whose dataframe lacks the metric column, or
    has no rows at all, are skipped silently.

    Each line is labelled with its job name, but no legend is drawn here;
    the drawing goes through the ``plt`` interface, so the caller can add
    a legend (or show / save the figure) afterwards.

    Parameters:
    -----------
    df_list: [(str, pd.DataFrame)]
        A list of (job name, dataframe) pairs, where the dataframe holds the
        corresponding data. Every dataframe that gets plotted must contain
        an "Epoch" column, which is used for the x axis.
    metric_name: str
        Name of the dataframe column holding the metric to plot
    y_label: str
        y_label for the plot
    min_final_value: float
        Only plots training jobs whose final metric value exceeds
        this threshold (strict comparison)
    '''
    my_dpi = 108
    linewidth = 3
    font_size = 24
    x = "Epoch"

    # figsize is in inches, so dividing the pixel size by the dpi yields a
    # 1000x800 pixel canvas.
    fig = plt.figure(figsize=(1000 / my_dpi, 800 / my_dpi), dpi=my_dpi)

    for job_name, job_df in df_list:
        if metric_name not in job_df.columns:
            continue

        metric_values = job_df[metric_name].values
        if len(metric_values) == 0:
            # A job with no recorded rows has no final value to compare;
            # indexing [-1] here would raise IndexError and abort the plot.
            continue

        final_value = metric_values[-1]
        if final_value > min_final_value:
            plt.plot(job_df[x], job_df[metric_name], label=job_name, linewidth=linewidth)

    plt.xlabel(x, fontsize=font_size)
    plt.ylabel(y_label, fontsize=font_size)
    plt.grid(color="0.9", linestyle='-', linewidth=3)
    plt.tight_layout()