def main()

in plot-demo.py [0:0]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('exp_root', type=str)
    args = parser.parse_args()

    fname = f"{args.exp_root}/latest.pkl"
    with open(fname, 'rb') as f:
        W = pkl.load(f)

    nrows, ncols = 1, 1
    fig, ax = plt.subplots(
        nrows, ncols, figsize=(6*ncols, 4*nrows),
        subplot_kw={'projection': 'mollweide'}
    )

    all_xs, _, _, Fs, _ = W.flow.apply(
        W.optimizer.target, utils.spherical_to_euclidean(tp), debug=True)
    plot_heatmap(Fs[0].reshape(2*NUM_POINTS, NUM_POINTS), ax)

    fname = f"{args.exp_root}/potential.png"
    print(f'Saving to {fname}')
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(fname)
    os.system(f"convert {fname} -trim {fname}")


    nrows, ncols = 1, 1
    fig, ax = plt.subplots(
        nrows, ncols, figsize=(6*ncols, 4*nrows),
        subplot_kw={'projection': 'mollweide'}
    )

    def plot_grid(x,y, ax=None, **kwargs):
        ax = ax or plt.gca()
        segs1 = np.stack((x,y), axis=2)
        segs2 = segs1.transpose(1,0,2)
        ax.add_collection(LineCollection(segs1, **kwargs))
        ax.add_collection(LineCollection(segs2, **kwargs))

    b = 0.2
    lw = 0.5
    grid_x, grid_y = np.meshgrid(
        np.linspace(-np.pi+b, np.pi-b, 50),
        np.linspace(-np.pi+b, np.pi-b, 50))
    plot_grid(grid_x, grid_y, ax, color='lightgrey', lw=lw)

    grid_sphere = utils.spherical_to_euclidean(
        jnp.stack((grid_x+np.pi, (grid_y+np.pi)/2.)).reshape(2, -1).T
    )
    F_grid_sphere, _ = W.flow.apply(W.optimizer.target, grid_sphere)
    F_grid = utils.euclidean_to_spherical(F_grid_sphere)
    F_grid_x = F_grid[:,0].reshape(grid_x.shape) - np.pi
    F_grid_y = F_grid[:,1].reshape(grid_x.shape)*2. - np.pi
    plot_grid(F_grid_x, F_grid_y, color='C0', lw=lw)

    ax.set_axis_off()
    fname = f"{args.exp_root}/grid.png"
    print(f'Saving to {fname}')
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(fname)
    os.system(f"convert {fname} -trim {fname}")