in plot_path_tools.py [0:0]
def _mean_gradients(gen, dis, dataloader, config,
                    model_loss_gen, model_loss_dis, device, verbose):
    """Return the dataset-averaged generator and discriminator gradients.

    Each gradient is computed with ``create_graph=True`` so that it can be
    differentiated again (Jacobian-vector products). Consequently the
    autograd graphs of *all* batches stay alive until the returned tensors
    are released; memory grows with the size of ``dataloader``.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples (the average would be undefined).
    """
    grad_gen_epoch = [torch.zeros_like(p) for p in gen.parameters()]
    grad_dis_epoch = [torch.zeros_like(p) for p in dis.parameters()]
    n_data = 0
    for batch_idx, batch in enumerate(dataloader):
        if verbose:
            print(batch_idx)
        # NOTE(review): assumes batches are (data, label)-style sequences;
        # only the first element (the data) is used.
        x_true = batch[0]
        z = torch.randn(x_true.size(0), config.nz, 1, 1)
        x_true = x_true.to(device)
        z = z.to(device)
        batch_size = len(x_true)

        # TODO: Needs to be changed to be able to handle different kind of loss
        x_gen = gen(z)
        dis_loss, _, _ = model_loss_dis(x_true, x_gen, dis, device)
        gen_loss, _ = model_loss_gen(x_gen, dis, device)
        if config.model == 'wgan_gp':
            penalty = dis.get_penalty(x_true, x_gen).mean()
            # Out-of-place so the loss tensor returned by the loss fn is not mutated.
            dis_loss = dis_loss + config.gp_lambda * penalty

        grad_gen = autograd.grad(gen_loss, gen.parameters(), create_graph=True)
        grad_dis = autograd.grad(dis_loss, dis.parameters(), create_graph=True)
        # Weight each batch by its size so the final division gives a
        # per-sample mean even when the last batch is smaller.
        grad_gen_epoch = [acc + g * batch_size
                          for acc, g in zip(grad_gen_epoch, grad_gen)]
        grad_dis_epoch = [acc + g * batch_size
                          for acc, g in zip(grad_dis_epoch, grad_dis)]
        n_data += batch_size

    if n_data == 0:
        raise ValueError("compute_eigenvalues: dataloader yielded no samples")
    return ([g / n_data for g in grad_gen_epoch],
            [g / n_data for g in grad_dis_epoch])


def _top_eigenvalues(grad, params, n_eigs, imaginary, symmetric):
    """Return ``n_eigs`` eigenvalues of the Jacobian of ``grad`` w.r.t. ``params``.

    The Jacobian is only accessed through ``JacobianVectorProduct`` (a
    matrix-free operator), so an iterative ARPACK solver is used.
    ``which='LI'`` selects the eigenvalues with the largest imaginary part.
    Otherwise ``eigsh`` is used when the operator is symmetric (a Hessian)
    and ``eigs`` is used when it is not (the joint game Jacobian). In both
    cases the default ordering is by largest magnitude.
    """
    t0 = time.time()
    operator = JacobianVectorProduct(grad, params)
    if imaginary:
        eigs = linalg.eigs(operator, k=n_eigs, which='LI')[0]
    elif symmetric:
        eigs = linalg.eigsh(operator, k=n_eigs)[0]
    else:
        eigs = linalg.eigs(operator, k=n_eigs)[0]
    print("Time to compute Eig-values: %.2f" % (time.time() - t0))
    return eigs


def compute_eigenvalues(gen, dis, dataloader, config,
                        model_loss_gen, model_loss_dis,
                        device=None, n_eigs=20, verbose=False, imaginary=False):
    """
    Compute top eigenvalues of the generator Hessian, the discriminator
    Hessian and the joint game Jacobian, using gradients averaged over
    ``dataloader``.

    Parameters
    ----------
    gen: Generator
    dis: Discriminator
    dataloader: pytorch DataLoader
        real data loader; each batch is indexed with ``[0]`` to get the data
    config: Namespace
        configuration (hyper-parameters) for the generator/discriminator;
        reads ``nz``, ``model`` and, for ``'wgan_gp'``, ``gp_lambda``
    model_loss_gen: callable
        ``model_loss_gen(x_gen, dis, device)`` -> (gen_loss, _)
    model_loss_dis: callable
        ``model_loss_dis(x_true, x_gen, dis, device)`` -> (dis_loss, _, _)
    device: torch device, optional
        device the real and latent batches are moved to
    n_eigs: int
        number of eigenvalues computed for each operator
    verbose: bool
        print batch indices and the first five eigenvalues of each set
    imaginary: bool
        if True, request the eigenvalues with the largest imaginary part
        (``which='LI'``) for all three operators

    Returns
    -------
    gen_eigs, dis_eigs, game_eigs
        eigenvalue arrays returned by ``linalg.eigs`` / ``linalg.eigsh``
        (presumably ``scipy.sparse.linalg``; verify against file imports)

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples.
    """
    start_time = time.time()
    grad_gen_epoch, grad_dis_epoch = _mean_gradients(
        gen, dis, dataloader, config, model_loss_gen, model_loss_dis,
        device, verbose)

    # Per-player Jacobians of their own gradients are Hessians (symmetric).
    gen_eigs = _top_eigenvalues(grad_gen_epoch, list(gen.parameters()),
                                n_eigs, imaginary, symmetric=True)
    dis_eigs = _top_eigenvalues(grad_dis_epoch, list(dis.parameters()),
                                n_eigs, imaginary, symmetric=True)
    # The joint game Jacobian mixes both players' gradients and is in
    # general not symmetric, so the general solver is required.
    game_eigs = _top_eigenvalues(
        grad_gen_epoch + grad_dis_epoch,
        list(gen.parameters()) + list(dis.parameters()),
        n_eigs, imaginary, symmetric=False)

    if verbose:
        print(gen_eigs[:5])
        print(dis_eigs[:5])
        print(game_eigs[:5])
    print("Time to finish: %.2f minutes" % ((time.time() - start_time) / 60.))
    return gen_eigs, dis_eigs, game_eigs