in plot_path_tools.py [0:0]
def visualize_interpolation(gen, checkpoint_1, checkpoint_2, config, out_dir, n_samples=64,
                            device=None, path_min=-0.1, path_max=1.1, n_points=100,
                            key_gen='state_gen', verbose=False,
                            n_is_batches=10, is_batch_size=500):
    """
    Visualize the linear interpolation path between two generator checkpoints.

    For each interpolation coefficient ``alpha`` the generator weights are set
    to ``alpha * checkpoint_2 + (1 - alpha) * checkpoint_1``; a grid of samples
    is saved and an Inception Score (IS) is computed with
    ``mnist_inception_score``.

    Parameters
    ----------
    gen: Generator
        generator whose weights are overwritten in place for every alpha
        (it is left holding the weights of the last interpolation point)
    checkpoint_1: pytorch checkpoint
        first checkpoint to plot path interpolation (reached at alpha = 0)
    checkpoint_2: pytorch checkpoint
        second checkpoint to plot path interpolation (reached at alpha = 1)
    config: Namespace
        configuration (hyper-parameters) for the generator/discriminator;
        only ``config.nz`` (size of the noise vector) is read here
    out_dir: str
        path to output directory (created if it does not exist)
    n_samples: int
        number of samples per interpolation point
    device: torch.device, optional
        device to run on; defaults to CUDA when available, else CPU
    path_min: float
        smallest value of alpha
    path_max: float
        largest value of alpha
    n_points: int
        number of alpha values, evenly spaced in [path_min, path_max]
    key_gen: str
        key under which the generator state dict is stored in the checkpoints
    verbose: bool
        if True, print the IS of every point and the total running time
    n_is_batches: int
        number of noise batches generated for the IS evaluation (must be >= 1)
    is_batch_size: int
        number of samples per IS batch

    Outputs (written to ``out_dir``)
    --------------------------------
    samples-<alpha>.png
        image grid of the samples generated at each alpha
    IS_scores_path.png
        plot of the IS mean as a function of alpha
    vis_results.pt
        dict with 'samples' (all generated samples, concatenated) and
        'IS_scores' (tensor of IS means, one per alpha)
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fail early (and helpfully) instead of at the first save_image call.
    os.makedirs(out_dir, exist_ok=True)

    # The same noise is reused for every alpha so the image grids are comparable.
    fixed_noise = torch.randn(n_samples, config.nz, 1, 1, device=device)
    nrow = int(np.sqrt(n_samples))

    weights_1 = checkpoint_1[key_gen]
    weights_2 = checkpoint_2[key_gen]

    # Moving the module is loop-invariant: load_state_dict copies values into
    # the existing parameters and therefore never changes their device.
    gen = gen.to(device)

    samples = []
    IS_scores = []
    start_time = time.time()

    # Plain Python floats avoid numpy-scalar/tensor dispatch subtleties and
    # format identically in the file names below.
    alpha_vals = np.linspace(path_min, path_max, n_points).tolist()

    # Compute statistics we are interested in for different values of alpha.
    for alpha in alpha_vals:
        state_dict_gen = gen.state_dict()
        for name in weights_1:
            state_dict_gen[name] = alpha * weights_2[name] + (1 - alpha) * weights_1[name]
        gen.load_state_dict(state_dict_gen)

        # Everything below is evaluation only: never build an autograd graph,
        # in particular not for the (large) batch of IS samples.
        with torch.no_grad():
            fake = gen(fixed_noise)
            samples.append(fake.cpu())
            vutils.save_image(fake, '%s/samples-%f.png' % (out_dir, alpha), nrow=nrow, normalize=True)

            # generate samples for IS evaluation
            IS_fake = torch.cat([
                gen(torch.randn(is_batch_size, config.nz, 1, 1, device=device))
                for _ in range(n_is_batches)
            ])
            IS_mean, IS_std = mnist_inception_score(IS_fake, device)

        IS_scores.append(IS_mean)
        if verbose:
            print("IS score: alpha=%.2f, mean=%.4f, std=%.4f" % (alpha, IS_mean, IS_std))

    fig = plt.figure()
    plt.plot(alpha_vals, IS_scores)
    fig.savefig(os.path.join(out_dir, 'IS_scores_path.png'))
    # Release the figure; pyplot otherwise keeps it alive across calls.
    plt.close(fig)

    results = {'samples': torch.cat(samples),
               'IS_scores': torch.tensor(IS_scores)}
    torch.save(results, '%s/vis_results.pt' % (out_dir))

    if verbose:
        print("Time to finish: %.2f minutes" % ((time.time() - start_time) / 60.))