def visualize_interpolation()

in plot_path_tools.py [0:0]


def _interpolate_state_dicts(base_state, state_1, state_2, alpha):
    """Overwrite entries of ``base_state`` with a linear interpolation of two state dicts.

    For every key ``p`` of ``state_1`` this sets
    ``base_state[p] = alpha * state_2[p] + (1 - alpha) * state_1[p]``; every such
    key must also exist in ``state_2``. Keys present only in ``base_state`` are
    left untouched. An ``alpha`` outside ``[0, 1]`` extrapolates along the same
    line. ``base_state`` is modified in place and returned for convenience.
    """
    for name, value_1 in state_1.items():
        base_state[name] = alpha * state_2[name] + (1 - alpha) * value_1
    return base_state


def visualize_interpolation(gen, checkpoint_1, checkpoint_2, config, out_dir, n_samples=64,
                            device=None, path_min=-0.1, path_max=1.1, n_points=100,
                            key_gen='state_gen', verbose=False, n_is_batches=10,
                            is_batch_size=500):
    """
    Computes stats for plotting path between checkpoint_1 and checkpoint_2.

    For each ``alpha`` in ``np.linspace(path_min, path_max, n_points)``, the
    generator weights are set to ``alpha * checkpoint_2 + (1 - alpha) * checkpoint_1``.
    ``alpha=0`` reproduces checkpoint_1, ``alpha=1`` reproduces checkpoint_2, and
    values outside ``[0, 1]`` extrapolate. At each point the function saves a grid
    of samples generated from one fixed noise batch, and computes the mean/std
    returned by ``mnist_inception_score`` on freshly drawn samples.

    Parameters
    ----------
    gen: Generator
        generator module; its architecture must match the checkpoints' state dicts
    checkpoint_1: pytorch checkpoint
        first checkpoint to plot path interpolation (alpha = 0)
    checkpoint_2: pytorch checkpoint
        second checkpoint to plot path interpolation (alpha = 1)
    config: Namespace
        configuration (hyper-parameters) for the generator/discriminator;
        ``config.nz`` is used as the noise dimension
    out_dir: str
        path to output directory (must already exist)
    n_samples: int
        number of samples per interpolation point (saved as an image grid)
    device: torch.device or None
        device to run on; defaults to CUDA when available, else CPU
    path_min, path_max: float
        first and last alpha value of the path
    n_points: int
        number of alpha values evaluated along the path
    key_gen: str
        key under which each checkpoint stores the generator state dict
    verbose: bool
        print the per-alpha IS score and the total running time
    n_is_batches: int
        number of noise batches drawn for each inception-score evaluation
    is_batch_size: int
        size of each of those batches (default total: 10 * 500 = 5000 samples)

    Returns
    -------
    None. Writes into ``out_dir``:
      - ``samples-<alpha>.png``: sample grid for every alpha
      - ``IS_scores_path.png``: plot of mean IS versus alpha
      - ``vis_results.pt``: dict with ``'samples'``, ``'IS_scores'`` (means),
        ``'IS_stds'`` and ``'alphas'``

    Notes
    -----
    ``gen`` is moved to ``device``. Its original weights are restored before
    returning, also when an exception is raised.
    """
    import copy  # local import: only needed to snapshot the generator weights

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move once; load_state_dict below copies values into the existing
    # (already on-device) parameters.
    gen = gen.to(device)
    # state_dict() aliases the live parameters, so deep-copy it (which clones
    # the tensors and keeps the state-dict metadata) to restore afterwards.
    original_state = copy.deepcopy(gen.state_dict())

    state_1 = checkpoint_1[key_gen]
    state_2 = checkpoint_2[key_gen]

    fixed_noise = torch.randn(n_samples, config.nz, 1, 1, device=device)
    nrow = max(1, int(np.sqrt(n_samples)))
    start_time = time.time()

    samples = []
    IS_scores = []
    IS_stds = []
    alpha_vals = []

    try:
        # Compute statistics we are interested in for different values of alpha.
        for alpha in np.linspace(path_min, path_max, n_points):
            alpha = float(alpha)  # plain Python scalar for tensor arithmetic / storage
            alpha_vals.append(alpha)
            gen.load_state_dict(
                _interpolate_state_dicts(gen.state_dict(), state_1, state_2, alpha))

            with torch.no_grad():
                fake = gen(fixed_noise)
                samples.append(fake.cpu())
                vutils.save_image(fake, os.path.join(out_dir, 'samples-%f.png' % alpha),
                                  nrow=nrow, normalize=True)

                # Fresh samples (not fixed_noise) for the IS evaluation.
                IS_fake = torch.cat([
                    gen(torch.randn(is_batch_size, config.nz, 1, 1, device=device))
                    for _ in range(n_is_batches)
                ])
                IS_mean, IS_std = mnist_inception_score(IS_fake, device)

            IS_scores.append(IS_mean)
            IS_stds.append(IS_std)

            if verbose:
                print("IS score: alpha=%.2f, mean=%.4f, std=%.4f" % (alpha, IS_mean, IS_std))
    finally:
        # Do not leave the caller's generator holding extrapolated weights.
        gen.load_state_dict(original_state)

    fig = plt.figure()
    plt.plot(alpha_vals, IS_scores)
    plt.xlabel('alpha')
    plt.ylabel('Inception score (mean)')
    fig.savefig(os.path.join(out_dir, 'IS_scores_path.png'))
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive

    results = {'samples': torch.cat(samples),
               'IS_scores': torch.tensor(IS_scores),
               'IS_stds': torch.tensor(IS_stds),
               'alphas': torch.tensor(alpha_vals)}

    torch.save(results, os.path.join(out_dir, 'vis_results.pt'))

    if verbose:
        print("Time to finish: %.2f minutes" % ((time.time() - start_time) / 60.))