in sockeye_contrib/plot_metrics.py [0:0]
def plot_metrics(args):
    """Plot a metric (and optionally a second one) from one or more metrics files.

    For each file in ``args.input`` the metrics are read with
    ``read_metrics_file``. The X/Y series are optionally transformed
    (inverted, averaged, points-since-improvement, window improvement, slope)
    and subsampled, then plotted on a shared figure. The figure is saved to
    ``args.output``.

    :param args: Parsed command-line arguments. Attributes read here:
        ``input``, ``legend``, ``skip``, ``every``, ``x``, ``y``, ``y2``,
        ``y_invert``, ``y2_invert``, ``y_average``, ``y_since_best``,
        ``y_window_improvement``, ``y_slope``, ``best``, ``y_line``,
        ``title``, ``paper`` and ``output``. When ``args.skip`` or
        ``args.every`` holds a single value, it is expanded in place to one
        value per input file.
    """
    fig, ax = plt.subplots()
    ax2 = None
    if args.y2:
        # Create axis for second Y metric (shares the X axis with ``ax``)
        ax2 = ax.twinx()
    overall_best_y = None
    # A single skip/every value applies to every input file
    if len(args.skip) == 1:
        args.skip *= len(args.input)
    if len(args.every) == 1:
        args.every *= len(args.input)
    # Paper scaling
    linewidth = 1.25 if args.paper else 1.0
    label_size = 12 if args.paper else None
    title_size = 16 if args.paper else None
    legend_size = 12 if args.paper else None
    tick_size = 12 if args.paper else None
    # Legend labels default to the metrics file names
    labels = (args.legend if args.legend is not None
              else [path.basename(fname) for fname in args.input])
    for fname, label, skip, every in zip(args.input, labels, args.skip, args.every):
        # Read metrics file to dict
        metrics = read_metrics_file(fname)
        x_vals = metrics[args.x][skip:]
        y_vals = metrics[args.y][skip:]
        y2_vals = metrics[args.y2][skip:] if args.y2 else None
        x_label = ax_label(args.x)
        y_label = ax_label(args.y)
        # Only build a second-axis label when a second metric was requested
        y2_label = ax_label(args.y2) if args.y2 else None
        # Spread points that collapse into one significant digit (ex: epochs).
        # Empty series (e.g. skip >= number of points) have nothing to spread.
        for i_label, i_vals in zip([args.x, args.y], [x_vals, y_vals]):
            if i_label in ['epoch'] and len(i_vals) > 0:
                i_vals[:] = np.linspace(i_vals[0], i_vals[-1], len(i_vals))
        # Optionally invert Y values
        if args.y_invert:
            y_vals = [val * -1 for val in y_vals]
        if args.y2_invert and y2_vals is not None:
            y2_vals = [val * -1 for val in y2_vals]
        # Optionally average best points so far for each Y point
        if args.y_average is not None:
            y_vals = average_points(y_vals, args.y_average, cmp=FIND_BEST[args.y])
            y_label = '{} (Average of {} Points)'.format(y_label, args.y_average)
        # Optionally count points since last improvement for each Y point
        if args.y_since_best:
            y_vals = points_since_improvement(y_vals, cmp=FIND_BEST[args.y])
            y_label = '{} (Checkpoints Since Improvement)'.format(y_label)
        # Optionally compute the window improvement for each Y point
        if args.y_window_improvement is not None:
            y_vals = window_improvement(y_vals, args.y_window_improvement, cmp=FIND_BEST[args.y])
            # Don't plot points for which window improvement is unreliable
            # (fewer than number points used for window). The second Y metric
            # is trimmed too so it stays aligned with the X values.
            x_vals = x_vals[args.y_window_improvement - 1:]
            y_vals = y_vals[args.y_window_improvement - 1:]
            if y2_vals is not None:
                y2_vals = y2_vals[args.y_window_improvement - 1:]
            y_label = '{} (Window Improvement over {} Points)'.format(y_label, args.y_window_improvement)
        # Optionally compute current slope for each Y point
        if args.y_slope is not None:
            y_vals = slope(y_vals, args.y_slope)
            # Don't plot points for which slope is unreliable (fewer than number
            # points used to compute slope); keep the second Y metric aligned.
            x_vals = x_vals[args.y_slope - 1:]
            y_vals = y_vals[args.y_slope - 1:]
            if y2_vals is not None:
                y2_vals = y2_vals[args.y_slope - 1:]
            y_label = '{} (Slope of {} Points)'.format(y_label, args.y_slope)
        # Only plot every N values
        x_vals = x_vals[::every]
        y_vals = y_vals[::every]
        if y2_vals is not None:
            y2_vals = y2_vals[::every]
        # Plot values for this metrics file
        ax.plot(x_vals, y_vals, linewidth=linewidth, alpha=0.75, label=label)
        ax.set_xlabel(x_label, fontsize=label_size)
        ax.set_ylabel(y_label, fontsize=label_size)
        plt.title(args.title, fontsize=title_size)
        plt.xticks(fontsize=tick_size)
        plt.yticks(fontsize=tick_size)
        # If present, plot and label second Y axis metric
        if args.y2:
            ax2.plot(x_vals, y2_vals, linewidth=linewidth / 2, alpha=0.75, label=label)
            ax2.set_ylabel(y2_label, fontsize=label_size)
        # Optionally track best point so far.
        # NOTE(review): this uses the *transformed* Y values (after
        # invert/average/slope/...) with the untransformed metric's
        # FIND_BEST direction; confirm this is intended when combined with
        # --y-invert or the derived-value options.
        if args.best:
            best_y = FIND_BEST[args.y](y_vals)
            if overall_best_y is None:
                overall_best_y = best_y
            else:
                overall_best_y = FIND_BEST[args.y](best_y, overall_best_y)
    # Optionally mark best Y point across metrics files
    if args.best:
        ax.axhline(y=overall_best_y, color='gray', linewidth=linewidth, linestyle='--', zorder=999)
    # Optionally draw user specified Y line
    if args.y_line is not None:
        ax.axhline(y=args.y_line, color='gray', linewidth=linewidth, linestyle='--', zorder=999)
    ax.grid()
    ax.legend(fontsize=legend_size)
    fig.tight_layout()
    fig.savefig(args.output)