in grok/metrics.py [0:0]
def lp_path_norm(model, device, p=2, input_size=(3, 32, 32)):
    """
    Path norm (Neyshabur 2015).

    Computes ``(sum over input->output paths of prod |w|**p) ** (1/p)`` with
    the usual trick:
    1. Deep-copy the model.
    2. Replace every trainable parameter ``w`` by ``|w|**p``.
    3. Run an all-ones input through the copy.
    4. Take the p-th root of the summed output.

    NOTE(review): the identity is exact for networks built from linear /
    convolutional layers and positively homogeneous activations such as ReLU.
    Other layer types make the result a proxy only.

    Args:
        model: ``torch.nn.Module`` to measure. It is deep-copied, so the
            caller's parameters are not modified.
        device: device on which the all-ones probe input is allocated. It
            must match the device the model's parameters live on.
        p: exponent of the path norm (default 2).
        input_size: shape passed verbatim to ``torch.ones`` for the probe
            input. No batch dimension is prepended, so include one if the
            model requires it.

    Returns:
        The path norm as a Python float.
    """
    # In-place ops on leaf parameters that require grad raise a RuntimeError
    # unless autograd is disabled. No graph is needed for this metric anyway.
    with torch.no_grad():
        tmp_model = copy.deepcopy(model)
        # Switch dropout / batch-norm layers to inference behaviour for the probe pass.
        tmp_model.eval()
        for param in tmp_model.parameters():
            # Only trainable parameters are transformed. Parameters with
            # requires_grad=False keep their original values.
            if param.requires_grad:
                param.abs_().pow_(p)
        data_ones = torch.ones(input_size, device=device)
        return (tmp_model(data_ones).sum() ** (1 / p)).item()