def plot_metrics()

in sockeye_contrib/plot_metrics.py [0:0]


def _broadcast_per_input(values, num_inputs, option_name):
    """Return one value per input file for an option given once or per input.

    A single value is repeated for every input file; otherwise the number of
    values must equal the number of input files.

    :param values: Sequence of option values (e.g. ``args.skip``).
    :param num_inputs: Number of input metrics files.
    :param option_name: Option name used in the error message.
    :return: A new list with exactly ``num_inputs`` entries.
    :raises ValueError: If the number of values is neither 1 nor ``num_inputs``.
    """
    if len(values) == 1:
        return list(values) * num_inputs
    if len(values) != num_inputs:
        raise ValueError('Expected 1 or {} values for {}, got {}'.format(num_inputs, option_name, len(values)))
    return list(values)


def plot_metrics(args):
    """Plot one metric (optionally two) from one or more metrics files and save the figure.

    Reads each file in ``args.input`` with ``read_metrics_file``, applies the
    optional Y transformations requested in ``args`` (invert, average, points
    since improvement, window improvement, slope), subsamples every N points,
    and draws the series on a shared axis (plus a twin axis for ``args.y2``).
    The figure is written to ``args.output``.

    :param args: Parsed options. The attributes read here are ``input``,
        ``legend``, ``skip``, ``every``, ``paper``, ``x``, ``y``, ``y2``,
        ``y_invert``, ``y2_invert``, ``y_average``, ``y_since_best``,
        ``y_window_improvement``, ``y_slope``, ``best``, ``y_line``, ``title``
        and ``output``. ``args`` is not modified.
    :raises ValueError: If ``skip``/``every`` are not given once or once per
        input, or if ``legend`` does not have one entry per input.
    """
    num_inputs = len(args.input)
    # Work on local copies so the caller's namespace is left untouched
    skips = _broadcast_per_input(args.skip, num_inputs, 'skip')
    everys = _broadcast_per_input(args.every, num_inputs, 'every')
    if args.legend is not None:
        if len(args.legend) != num_inputs:
            raise ValueError('Expected {} legend labels, got {}'.format(num_inputs, len(args.legend)))
        labels = list(args.legend)
    else:
        labels = [path.basename(fname) for fname in args.input]

    fig, ax = plt.subplots()
    # Create axis for second Y metric
    ax2 = ax.twinx() if args.y2 else None
    overall_best_y = None

    # Paper scaling
    linewidth = 1.25 if args.paper else 1.0
    label_size = 12 if args.paper else None
    title_size = 16 if args.paper else None
    legend_size = 12 if args.paper else None
    tick_size = 12 if args.paper else None

    # Axis labels depend only on the options, not on the individual file
    x_label = ax_label(args.x)
    y_label = ax_label(args.y)
    y2_label = ax_label(args.y2) if args.y2 else None
    if args.y_average is not None:
        y_label = '{} (Average of {} Points)'.format(y_label, args.y_average)
    if args.y_since_best:
        y_label = '{} (Checkpoints Since Improvement)'.format(y_label)
    if args.y_window_improvement is not None:
        y_label = '{} (Window Improvement over {} Points)'.format(y_label, args.y_window_improvement)
    if args.y_slope is not None:
        y_label = '{} (Slope of {} Points)'.format(y_label, args.y_slope)

    for fname, label, skip, every in zip(args.input, labels, skips, everys):
        # Read metrics file to dict
        metrics = read_metrics_file(fname)
        x_vals = metrics[args.x][skip:]
        y_vals = metrics[args.y][skip:]
        y2_vals = metrics[args.y2][skip:] if args.y2 else None
        # Spread points that collapse into one significant digit (ex: epochs)
        for i_label, i_vals in ((args.x, x_vals), (args.y, y_vals)):
            if i_label == 'epoch' and len(i_vals) > 0:
                i_vals[:] = np.linspace(i_vals[0], i_vals[-1], len(i_vals))
        # Optionally invert Y values
        if args.y_invert:
            y_vals = [-val for val in y_vals]
        if args.y2_invert and y2_vals is not None:
            y2_vals = [-val for val in y2_vals]
        # Optionally average best points so far for each Y point
        if args.y_average is not None:
            y_vals = average_points(y_vals, args.y_average, cmp=FIND_BEST[args.y])
        # Optionally count points since last improvement for each Y point
        if args.y_since_best:
            y_vals = points_since_improvement(y_vals, cmp=FIND_BEST[args.y])
        # Optionally compute the window improvement for each Y point
        if args.y_window_improvement is not None:
            y_vals = window_improvement(y_vals, args.y_window_improvement, cmp=FIND_BEST[args.y])
            # Don't plot points for which window improvement is unreliable
            # (fewer than number points used for window). Trim every series
            # so x, y and y2 stay aligned.
            start = args.y_window_improvement - 1
            x_vals = x_vals[start:]
            y_vals = y_vals[start:]
            if y2_vals is not None:
                y2_vals = y2_vals[start:]
        # Optionally compute current slope for each Y point
        if args.y_slope is not None:
            y_vals = slope(y_vals, args.y_slope)
            # Don't plot points for which slope is unreliable (fewer than number
            # points used to compute slope)
            start = args.y_slope - 1
            x_vals = x_vals[start:]
            y_vals = y_vals[start:]
            if y2_vals is not None:
                y2_vals = y2_vals[start:]
        # Only plot every N values
        x_vals = x_vals[::every]
        y_vals = y_vals[::every]
        if y2_vals is not None:
            y2_vals = y2_vals[::every]
        # Plot values for this metrics file
        ax.plot(x_vals, y_vals, linewidth=linewidth, alpha=0.75, label=label)
        # If present, plot second Y axis metric
        if ax2 is not None:
            ax2.plot(x_vals, y2_vals, linewidth=linewidth / 2, alpha=0.75, label=label)
        # Optionally track best point so far (skip empty series: min/max of [] fails)
        if args.best and len(y_vals) > 0:
            best_y = FIND_BEST[args.y](y_vals)
            if overall_best_y is None:
                overall_best_y = best_y
            else:
                overall_best_y = FIND_BEST[args.y](best_y, overall_best_y)

    # Style axes explicitly: after twinx() the pyplot "current axes" is the
    # twin, so plt.title/plt.xticks/plt.yticks would miss the primary axis.
    ax.set_xlabel(x_label, fontsize=label_size)
    ax.set_ylabel(y_label, fontsize=label_size)
    ax.set_title(args.title, fontsize=title_size)
    if tick_size is not None:
        ax.tick_params(labelsize=tick_size)
    if ax2 is not None:
        ax2.set_ylabel(y2_label, fontsize=label_size)
        if tick_size is not None:
            ax2.tick_params(labelsize=tick_size)

    # Optionally mark best Y point across metrics files
    if args.best and overall_best_y is not None:
        ax.axhline(y=overall_best_y, color='gray', linewidth=linewidth, linestyle='--', zorder=999)
    # Optionally draw user specified Y line
    if args.y_line is not None:
        ax.axhline(y=args.y_line, color='gray', linewidth=linewidth, linestyle='--', zorder=999)

    ax.grid()
    ax.legend(fontsize=legend_size)

    fig.tight_layout()
    fig.savefig(args.output)