def get_penalty()

in lib/models/discriminator.py [0:0]


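get_penalty computes a WGAN-GP style gradient penalty for the critic: it builds penalty points x_penalty (a per-example random interpolation of real and generated samples in "linear" mode, or a copy of either batch in "gen"/"data" mode), runs them through self.forward, and returns the batch mean of (||grad_x D(x_penalty)||_2 - 1)^2. The body relies on torch and grad, which are presumably imported at module level in discriminator.py (e.g. from torch.autograd import grad).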
    def get_penalty(self, x_true, x_gen, mode="linear"):
        x_true = x_true.view_as(x_gen)
        if mode ==  "linear":
            alpha = torch.rand((len(x_true),)+(1,)*(x_true.dim()-1))
            if x_true.is_cuda:
                alpha = alpha.cuda(x_true.get_device())
            x_penalty = alpha*x_true + (1-alpha)*x_gen
        elif mode == "gen":
            x_penalty = x_gen.clone()
        elif mode == "data":
            x_penalty = x_true.clone()
        x_penalty.requires_grad_()
        p_penalty = self.forward(x_penalty)
        gradients = grad(p_penalty, x_penalty, grad_outputs=torch.ones_like(p_penalty).cuda(x_true.get_device()) if x_true.is_cuda else torch.ones_like(p_penalty), create_graph=True, retain_graph=True, only_inputs=True)[0]
        penalty = ((gradients.view(len(x_true), -1).norm(2, 1) - 1)**2).mean()

        return penalty
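
A minimal usage sketch of how the penalty might enter a critic update. It assumes the surrounding Discriminator class constructs without required arguments and that its forward accepts image-shaped batches; the tensor shapes, optimizer settings, and lambda_gp weight below are illustrative, not taken from the source.

import torch

disc = Discriminator()                  # class this method belongs to (assumed no-arg constructor)
optimizer = torch.optim.Adam(disc.parameters(), lr=1e-4, betas=(0.5, 0.9))

x_true = torch.randn(64, 3, 32, 32)    # placeholder batch of real samples
x_gen = torch.randn(64, 3, 32, 32)     # placeholder batch of generator outputs (already detached)

optimizer.zero_grad()
d_true = disc(x_true).mean()
d_gen = disc(x_gen).mean()
gp = disc.get_penalty(x_true, x_gen, mode="linear")

lambda_gp = 10.0                        # penalty weight; 10 is the usual WGAN-GP choice
loss = d_gen - d_true + lambda_gp * gp  # Wasserstein critic loss plus gradient penalty
loss.backward()
optimizer.step()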