in code/colored_mnist/main.py [0:0]
def penalty(logits, y):
    """Return the squared gradient of the loss w.r.t. a dummy scale of 1.0.

    The logits are multiplied by a scalar ``scale`` fixed at 1.0, the loss
    ``mean_nll(logits * scale, y)`` is evaluated, and the gradient of that
    loss with respect to ``scale`` is squared and summed.

    Args:
        logits: Tensor of model outputs; passed to ``mean_nll`` after being
            multiplied by the dummy scale.
        y: Targets forwarded unchanged to ``mean_nll``.
            NOTE(review): expected shape/dtype is whatever ``mean_nll``
            requires -- confirm against its definition.

    Returns:
        A zero-dimensional tensor holding the squared gradient. It remains
        attached to the autograd graph, so it can be added to a training
        objective and backpropagated through.
    """
    # Create the dummy scale on the same device as the logits instead of
    # unconditionally calling .cuda(), so the function also works on CPU and
    # on non-default GPUs.
    scale = torch.tensor(1., device=logits.device, requires_grad=True)
    loss = mean_nll(logits * scale, y)
    # create_graph=True keeps the gradient itself differentiable; without it
    # the returned penalty could not contribute gradients to the model.
    grad = autograd.grad(loss, [scale], create_graph=True)[0]
    return torch.sum(grad ** 2)