# plot_path_tools.py
import time

import numpy as np
import torch


def compute_path_stats(gen, dis, checkpoint_1, checkpoint_2, dataloader, config,
model_loss_gen, model_loss_dis,
device=None, path_min=-0.1, path_max=1.1, n_points=100,
key_gen='state_gen', key_dis='state_dis', verbose=False):
"""
Computes stats for plotting path between checkpoint_1 and checkpoint_2.
Parameters
----------
gen: Generator
dis: Discriminator
checkpoint_1: pytorch checkpoint
first checkpoint to plot path interpolation
checkpoint_2: pytorch checkpoint
second checkpoint to plot path interpolation
dataloader: pytorch DataLoader
real data loader (mnist)
config: Namespace
configuration (hyper-parameters) for the generator/discriminator
model_loss_dis, model_loss_gen: function
returns generator and discriminator losses given the discriminator output
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Compute params_diff, the flattened vector pointing from checkpoint_2 to
    # checkpoint_1 in parameter space; it is used later for the cosine
    # similarity and dot product with the gradient.
params_diff = []
for name, p in gen.named_parameters():
d = (checkpoint_1[key_gen][name] - checkpoint_2[key_gen][name])
params_diff.append(d.flatten())
for name, p in dis.named_parameters():
d = (checkpoint_1[key_dis][name] - checkpoint_2[key_dis][name])
params_diff.append(d.flatten())
params_diff = torch.cat(params_diff)
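    # NOTE: this concatenation order (all generator parameters, then all
    # discriminator parameters) must match the order used to build grad_all
    # below, since the two flattened vectors are compared coordinate-wise.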
# The different statistics we want to compute are saved in a dict.
hist = {'alpha': [], 'cos_sim': [], 'dot_prod': [], 'gen_loss': [], 'dis_loss': [],
'penalty': [], 'grad_gen_norm': [], 'grad_dis_norm': [], 'grad_total_norm': []}
start_time = time.time()
# Compute statistics we are interested in for different values of alpha.
for alpha in np.linspace(path_min, path_max, n_points):
############### Computing and loading interpolation ##############
        # Compute the linear interpolation between checkpoint_1 and
        # checkpoint_2 with coefficient alpha and load it into the models:
        # alpha = 0 recovers checkpoint_1 and alpha = 1 recovers checkpoint_2.
state_dict_gen = gen.state_dict()
for p in checkpoint_1[key_gen]:
state_dict_gen[p] = alpha * checkpoint_2[key_gen][p] + (1 - alpha) * checkpoint_1[key_gen][p]
gen.load_state_dict(state_dict_gen)
state_dict_dis = dis.state_dict()
for p in checkpoint_1[key_dis]:
state_dict_dis[p] = alpha * checkpoint_2[key_dis][p] + (1 - alpha) * checkpoint_1[key_dis][p]
dis.load_state_dict(state_dict_dis)
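        # Note: iterating over the checkpoint state dicts interpolates every
        # entry, so buffers (e.g. BatchNorm running statistics) are
        # interpolated along with the learnable parameters.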
gen = gen.to(device)
dis = dis.to(device)
#################################################################
######### Compute Loss and Gradient over Full-Batch ##########
gen_loss_epoch = 0
dis_loss_epoch = 0
penalty_epoch = 0
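        # Running sums (flattened, one entry per parameter) used to build the
        # full-batch gradients of both players.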
grad_gen_epoch = {}
for name, param in gen.named_parameters():
grad_gen_epoch[name] = torch.zeros_like(param).flatten()
grad_dis_epoch = {}
for name, param in dis.named_parameters():
grad_dis_epoch[name] = torch.zeros_like(param).flatten()
n_data = 0
t0 = time.time()
for i, x_true in enumerate(dataloader):
x_true = x_true[0]
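            # Fresh DCGAN-style latent noise of shape (batch_size, nz, 1, 1);
            # resampled every batch, so the generator-side statistics are
            # Monte Carlo estimates over z as well as averages over the data.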
z = torch.randn(x_true.size(0), config.nz, 1, 1)
x_true = x_true.to(device)
z = z.to(device)
for p in gen.parameters():
if p.grad is not None:
p.grad.zero_()
for p in dis.parameters():
if p.grad is not None:
p.grad.zero_()
################# Compute Loss #########################
            # TODO: needs to be changed to handle different kinds of losses
x_gen = gen(z)
dis_loss, _, _ = model_loss_dis(x_true, x_gen.detach(), dis, device)
gen_loss, _ = model_loss_gen(x_gen, dis, device)
if config.model == 'wgan_gp':
penalty = dis.get_penalty(x_true.detach(), x_gen.detach()).mean()
dis_loss += config.gp_lambda * penalty
else:
penalty = torch.zeros(1)
#################################################
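            # Freeze the discriminator so that gen_loss.backward() populates
            # only the generator's .grad buffers (requires_grad=False on a leaf
            # tensor suppresses gradient accumulation at backward time), then
            # symmetrically freeze the generator before dis_loss.backward().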
for p in dis.parameters():
p.requires_grad = False
gen_loss.backward(retain_graph=True)
for p in dis.parameters():
p.requires_grad = True
for p in gen.parameters():
p.requires_grad = False
dis_loss.backward()
for p in gen.parameters():
p.requires_grad = True
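            # Accumulate the per-batch mean gradients and losses weighted by
            # batch size; uneven final batches are then averaged correctly.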
for name, param in gen.named_parameters():
grad_gen_epoch[name] += param.grad.flatten() * len(x_true)
for name, param in dis.named_parameters():
grad_dis_epoch[name] += param.grad.flatten() * len(x_true)
gen_loss_epoch += gen_loss.item() * len(x_true)
dis_loss_epoch += dis_loss.item() * len(x_true)
penalty_epoch += penalty.item() * len(x_true)
n_data += len(x_true)
########################################################
gen_loss_epoch /= n_data
dis_loss_epoch /= n_data
penalty_epoch /= n_data
grad_gen = []
for name, _ in gen.named_parameters():
grad_gen.append(grad_gen_epoch[name])
grad_dis = []
for name, param in dis.named_parameters():
param_flat = param.flatten()
grad_param = grad_dis_epoch[name]
if config.model == 'wgan':
                # Zero out gradient components that violate the WGAN
                # weight-clipping constraint.
zero_mask = (torch.abs(param_flat) == config.clip) &\
(torch.sign(grad_param) == torch.sign(param_flat))
grad_param[zero_mask] = 0.0
grad_dis.append(grad_param)
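        # The gradients were summed with batch-size weights above, so dividing
        # by n_data yields the full-batch (dataset-average) gradient.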
grad_gen = torch.cat(grad_gen) / n_data
grad_dis = torch.cat(grad_dis) / n_data
grad_all = torch.cat([grad_gen, grad_dis])
####### Compute statistics we are interested in ##########
        # Squared Euclidean norms of the generator/discriminator gradients
        norm_grad_gen = (grad_gen**2).sum().item()
        norm_grad_dis = (grad_dis**2).sum().item()
        # Dot product of the gradient with the normalized path direction
        # (an unnormalized cosine similarity)
        dot_prod = (grad_all * params_diff).sum() / torch.sqrt((params_diff**2).sum())
        # Cosine similarity between the gradient and the path direction
        cos_sim = dot_prod / torch.sqrt((grad_all**2).sum())
dot_prod = dot_prod.item()
cos_sim = cos_sim.item()
##########################################################
        if verbose:
            print("Alpha: %.2f, Cos sim: %.2f, Generator loss: %.2e, "
                  "Discriminator loss: %.2e, Penalty: %.2f, "
                  "Gen grad norm: %.2e, Dis grad norm: %.2e, Time: %.2fsec"
                  % (alpha, cos_sim, gen_loss_epoch, dis_loss_epoch,
                     penalty_epoch, norm_grad_gen, norm_grad_dis,
                     time.time() - t0))
hist['alpha'].append(alpha)
hist['cos_sim'].append(cos_sim)
hist['dot_prod'].append(dot_prod)
hist['gen_loss'].append(gen_loss_epoch)
hist['dis_loss'].append(dis_loss_epoch)
hist['penalty'].append(penalty_epoch)
hist['grad_gen_norm'].append(norm_grad_gen)
hist['grad_dis_norm'].append(norm_grad_dis)
hist['grad_total_norm'].append(norm_grad_dis + norm_grad_gen)
if verbose:
print("Time to finish: %.2f minutes" % ((time.time() - start_time) / 60.))
return hist
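

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The toy
    # models, loss helpers and config below are hypothetical stand-ins that
    # only mirror the interfaces compute_path_stats expects: losses returning
    # (loss, ...) tuples, a Namespace with `nz`/`model`, and checkpoints
    # holding 'state_gen'/'state_dis' state dicts.
    import argparse

    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    class ToyGen(torch.nn.Module):
        def __init__(self, nz):
            super().__init__()
            self.net = torch.nn.Linear(nz, 4)

        def forward(self, z):
            # Flatten the (batch_size, nz, 1, 1) latent for the linear layer.
            return self.net(z.view(z.size(0), -1))

    class ToyDis(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(4, 1)

        def forward(self, x):
            return self.net(x)

    def toy_loss_dis(x_true, x_gen, dis, device):
        # Non-saturating GAN discriminator loss (an illustrative choice).
        loss_true = F.softplus(-dis(x_true)).mean()
        loss_gen = F.softplus(dis(x_gen)).mean()
        return loss_true + loss_gen, loss_true, loss_gen

    def toy_loss_gen(x_gen, dis, device):
        loss = F.softplus(-dis(x_gen)).mean()
        return loss, dis(x_gen).mean()

    config = argparse.Namespace(nz=8, model='nsgan')
    gen, dis = ToyGen(config.nz), ToyDis()
    # Clone the state dicts so the checkpoints do not alias the live
    # parameters that load_state_dict() overwrites inside the loop.
    ckpt_1 = {'state_gen': {k: v.clone() for k, v in gen.state_dict().items()},
              'state_dis': {k: v.clone() for k, v in dis.state_dict().items()}}
    # Second endpoint: a randomly perturbed copy of the first checkpoint.
    ckpt_2 = {key: {k: v + 0.1 * torch.randn_like(v) for k, v in state.items()}
              for key, state in ckpt_1.items()}
    loader = DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=16)
    # Run on CPU so the checkpoint tensors and the gradients live on the
    # same device as params_diff.
    hist = compute_path_stats(gen, dis, ckpt_1, ckpt_2, loader, config,
                              toy_loss_gen, toy_loss_dis,
                              device=torch.device("cpu"),
                              n_points=5, verbose=True)
    print("cos_sim along the path:", ["%.3f" % c for c in hist['cos_sim']])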