in lib/models/discriminator.py [0:0]
def get_penalty(self, x_true, x_gen, mode="linear"):
    """Compute a gradient-norm penalty on this module's output.

    The penalty is ``mean((||d forward(x_hat) / d x_hat||_2 - 1) ** 2)``,
    where the norm is taken per sample (over all non-batch dimensions) and
    ``x_hat`` is chosen by ``mode``. This is the two-sided penalty form used
    by WGAN-GP.

    Args:
        x_true: Batch of real samples. It is reshaped with
            ``view_as(x_gen)``, so it must have as many elements as ``x_gen``.
        x_gen: Batch of generated samples. Dimension 0 is treated as the
            batch dimension.
        mode: Where the gradient is evaluated:
            ``"linear"``: at per-sample random interpolations
            ``alpha * x_true + (1 - alpha) * x_gen`` with ``alpha ~ U[0, 1)``;
            ``"gen"``: at the generated samples;
            ``"data"``: at the real samples.

    Returns:
        A scalar tensor. The gradient graph is kept (``create_graph=True``),
        so the penalty can be back-propagated into this module's parameters.
        If ``x_true``/``x_gen`` themselves require grad, gradients of the
        penalty also flow into them.

    Raises:
        ValueError: If ``mode`` is not ``"linear"``, ``"gen"`` or ``"data"``.
    """
    x_true = x_true.view_as(x_gen)
    batch_size = len(x_true)

    if mode == "linear":
        # One mixing coefficient per sample, broadcast over the other dims.
        # Build it directly on the inputs' device (any backend, no host->device
        # copy) and in their floating dtype so e.g. half inputs are not
        # silently promoted to float32.
        alpha_shape = (batch_size,) + (1,) * (x_true.dim() - 1)
        alpha_dtype = x_true.dtype if x_true.is_floating_point() else None
        alpha = torch.rand(alpha_shape, device=x_true.device, dtype=alpha_dtype)
        x_penalty = alpha * x_true + (1 - alpha) * x_gen
    elif mode == "gen":
        x_penalty = x_gen.clone()
    elif mode == "data":
        x_penalty = x_true.clone()
    else:
        # Without this branch x_penalty would be unbound below.
        raise ValueError(
            f"unknown penalty mode {mode!r}; expected 'linear', 'gen' or 'data'"
        )

    # Ensure autograd tracks x_penalty so we can take d(output)/d(x_penalty).
    x_penalty.requires_grad_()
    p_penalty = self.forward(x_penalty)

    # ones_like already matches p_penalty's device and dtype.
    # create_graph=True keeps the penalty differentiable and implies
    # retain_graph=True.
    gradients = grad(
        p_penalty,
        x_penalty,
        grad_outputs=torch.ones_like(p_penalty),
        create_graph=True,
    )[0]

    # reshape (not view) so a non-contiguous gradient layout does not raise.
    grad_norms = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norms - 1) ** 2).mean()