in pycls/utils/plotting.py [0:0]
def plot_error_curves_plotly(log_files, names, filename, key='top1_err'):
    """Plot error curves using plotly and save to file.

    For each entry returned by ``prepare_plot_data`` three traces are added,
    in this order: the train curve without a legend entry, the test curve
    with a legend entry, and a hidden copy of the train curve with a legend
    entry. A dropdown menu toggles between showing 'all', 'train' or 'test'
    curves by switching trace visibility.

    Args:
        log_files: log files to plot, forwarded to ``prepare_plot_data``.
        names: display names for the runs, forwarded to ``prepare_plot_data``.
        filename: output file passed to ``offline.plot``.
        key: metric plotted on the y axis (default ``'top1_err'``).
    """
    plot_data = prepare_plot_data(log_files, names, key)
    colors = get_plot_colors(len(plot_data), 'plotly')
    # Prepare data for plots (3 sets, train duplicated w and w/o legend).
    # The duplicate train trace lets the 'train' view show a legend entry
    # for train curves while the 'all' view only lists the test curves.
    data = []
    for i, d in enumerate(plot_data):
        group = str(i)
        line_train = {'color': colors[i], 'dash': 'dashdot', 'width': 1.5}
        line_test = {'color': colors[i], 'dash': 'solid', 'width': 1.5}
        data.append(go.Scatter(
            x=d['x_train'], y=d['y_train'], mode='lines',
            name=d['train_label'], line=line_train, legendgroup=group,
            visible=True, showlegend=False,
        ))
        data.append(go.Scatter(
            x=d['x_test'], y=d['y_test'], mode='lines',
            name=d['test_label'], line=line_test, legendgroup=group,
            visible=True, showlegend=True,
        ))
        data.append(go.Scatter(
            x=d['x_train'], y=d['y_train'], mode='lines',
            name=d['train_label'], line=line_train, legendgroup=group,
            visible=False, showlegend=True,
        ))
    # Prepare layout w ability to toggle 'all', 'train', 'test'.
    # Each visibility list has one entry per trace of a single run. This
    # relies on plotly's restyle/update cycling a short array across all
    # traces, so the same 3-entry pattern is applied to every run.
    title_font = {'size': 18, 'color': '#7f7f7f'}
    vis = [[True, True, False], [False, False, True], [False, True, False]]
    buttons = [
        {'label': label, 'args': [{'visible': v}], 'method': 'update'}
        for label, v in zip(['all', 'train', 'test'], vis)
    ]
    layout = go.Layout(
        title=key + ' vs. epoch<br>[dash=train, solid=test]',
        # Axis title font lives under title.font. The flat 'titlefont'
        # attribute is deprecated and was removed in plotly.js v3 / plotly 6.
        xaxis={'title': {'text': 'epoch', 'font': title_font}},
        yaxis={'title': {'text': key, 'font': title_font}},
        showlegend=True,
        # -1 shows full trace names in hover labels instead of truncating.
        hoverlabel={'namelength': -1},
        updatemenus=[{
            'buttons': buttons, 'direction': 'down', 'showactive': True,
            'x': 1.02, 'xanchor': 'left', 'y': 1.08, 'yanchor': 'top',
        }],
    )
    # Create plotly plot
    offline.plot({'data': data, 'layout': layout}, filename=filename)