def plot_path_stats()

in plot_path_tools.py [0:0]


def plot_path_stats(hist, out_dir=None, summary_writer=None, step=0):
    """Plot statistics recorded along an interpolation path, then save and/or log them.

    Args:
        hist: mapping of per-point sequences, presumably produced by
            `compute_path_stats` (not visible here -- confirm). Keys read:
            'alpha', 'cos_sim', 'grad_total_norm', 'dot_prod', 'gen_loss',
            'dis_loss', 'penalty', 'grad_gen_norm', 'grad_dis_norm'. Every
            sequence is plotted against hist['alpha'], so they must have the
            same length.
        out_dir: if not None, each figure is saved there as
            '<name>_<step formatted with %06d>.png'.
        summary_writer: if not None, each figure is logged with
            `summary_writer.add_figure(name, fig, step)`.
        step: number used in the file names and as the logging step.

    Figure names: cos_sim (cos_sim over grad_total_norm), dot_prod, gen_loss,
    dis_loss, penalty, grad_gen_norm, grad_dis_norm, grad_direction.

    All figures are closed before returning, so repeated calls (e.g. once per
    training step) do not accumulate open matplotlib figures.

    Raises:
        AssertionError: if both out_dir and summary_writer are None.
    """
    assert out_dir is not None or summary_writer is not None, 'save results either as files in out_dir or in tensorboard!'

    alpha = hist['alpha']
    # insertion order is the order figures are saved / logged in
    figures = {}
    try:
        figures['cos_sim'] = _plot_cos_sim_and_grad_norm(hist)
        figures['dot_prod'] = _plot_dot_prod(hist)
        figures['gen_loss'] = _plot_series_with_endpoints(alpha, hist['gen_loss'])
        figures['dis_loss'] = _plot_series_with_endpoints(alpha, hist['dis_loss'])
        figures['penalty'] = _plot_series_with_endpoints(alpha, hist['penalty'])
        figures['grad_gen_norm'] = _plot_series_with_endpoints(alpha, hist['grad_gen_norm'], log_scale=True)
        figures['grad_dis_norm'] = _plot_series_with_endpoints(alpha, hist['grad_dis_norm'], log_scale=True)
        figures['grad_direction'] = _plot_grad_direction(hist)

        if out_dir is not None:
            for name, fig in figures.items():
                fig.savefig(os.path.join(out_dir, '%s_%06d.png' % (name, step)))

        if summary_writer is not None:
            for name, fig in figures.items():
                summary_writer.add_figure(name, fig, step)
    finally:
        # Release every figure; without this each call leaks 8 open pyplot
        # figures. add_figure may already have closed some of them -- closing
        # an already-closed figure again is a no-op.
        for fig in figures.values():
            plt.close(fig)


def _mark_path_endpoints(ax):
    """Draw dashed black vertical lines on `ax` at alpha=0 and alpha=1."""
    # presumably the two endpoints of the interpolation path -- TODO confirm
    for x in (0, 1):
        ax.axvline(x=x, color='black', linestyle='--')


def _mark_sign_changes(ax, alpha, values):
    """Draw a dashed black vertical line on `ax` at each alpha where `values` changes sign.

    Exact zeros do not replace the reference sign, so a crossing that passes
    through 0.0 is still marked (at the first value with the opposite sign).
    Empty inputs draw nothing.
    """
    previous = 0.0
    for a, value in zip(alpha, values):
        if value * previous < 0:
            ax.axvline(x=a, color='black', linestyle='--')
        if value != 0:
            previous = value


def _plot_series_with_endpoints(alpha, values, log_scale=False):
    """Return a new single-axes figure of `values` vs `alpha` with alpha=0/1 markers."""
    fig, ax = plt.subplots()
    ax.plot(alpha, values)
    _mark_path_endpoints(ax)
    if log_scale:
        ax.set_yscale('log')
    return fig


def _plot_cos_sim_and_grad_norm(hist):
    """Return a two-panel figure: cos_sim (top) over grad_total_norm on a log y-axis (bottom)."""
    fig = plt.figure()
    # two stacked panels of equal height; the bottom one shares the top's x-axis
    gs = gridspec.GridSpec(2, 1)

    ax_top = fig.add_subplot(gs[0])
    cos_sim_line, = ax_top.plot(hist['alpha'], hist['cos_sim'], color='r')
    _mark_sign_changes(ax_top, hist['alpha'], hist['cos_sim'])
    ax_top.axhline(y=0, color='black', linestyle='--')

    ax_bottom = fig.add_subplot(gs[1], sharex=ax_top)
    ax_bottom.set_yscale('log')
    grad_norm_line, = ax_bottom.plot(hist['alpha'], hist['grad_total_norm'])
    _mark_path_endpoints(ax_bottom)

    # The panels touch (hspace=0 below): hide the top panel's x tick labels and
    # the bottom panel's last major y tick label (presumably the topmost one on
    # a non-inverted axis) so they do not collide at the seam.
    ax_top.tick_params(labelbottom=False)
    yticks = ax_bottom.yaxis.get_major_ticks()
    if yticks:
        yticks[-1].label1.set_visible(False)

    # a single legend for both panels, placed on the top one
    ax_top.legend((cos_sim_line, grad_norm_line), ('cos_sim', 'grad_total_norm'), loc='upper right')

    # remove the vertical gap between the panels
    fig.subplots_adjust(hspace=0.0)
    return fig


def _plot_dot_prod(hist):
    """Return a figure of dot_prod vs alpha with its sign changes and y=0 marked."""
    fig, ax = plt.subplots()
    ax.plot(hist['alpha'], hist['dot_prod'])
    _mark_sign_changes(ax, hist['alpha'], hist['dot_prod'])
    ax.axhline(y=0, color='black', linestyle='--')
    return fig


def _plot_grad_direction(hist):
    """Return a quiver figure of the gradient direction at every second alpha.

    Each arrow starts at (alpha, 0), has horizontal component dot_prod and
    vertical component sqrt(|total_grad_norm**2 - dot_prod**2|), so when
    |dot_prod| <= total_grad_norm its length equals total_grad_norm.
    """
    # NOTE(review): sqrt(gen + dis) is a total norm only if grad_gen_norm and
    # grad_dis_norm hold *squared* norms. If they are plain norms this should
    # be sqrt(gen**2 + dis**2) (or hist['grad_total_norm']) -- confirm against
    # compute_path_stats.
    total_grad_norm = np.sqrt(np.array(hist['grad_gen_norm']) + np.array(hist['grad_dis_norm']))
    # abs() keeps the sqrt real if dot_prod**2 exceeds the squared norm
    y_component = np.sqrt(np.abs(total_grad_norm ** 2 - np.array(hist['dot_prod']) ** 2))

    fig, ax = plt.subplots()
    ax.quiver(hist['alpha'][::2], 0, hist['dot_prod'][::2], y_component[::2],
              width=0.003, scale=np.max(total_grad_norm) * 2)
    _mark_path_endpoints(ax)
    ax.set_ylim(0, 1)
    ax.set_xlim(-0.5, 1.5)
    return fig