def lp_path_norm()

in grok/metrics.py [0:0]


def lp_path_norm(model, device, p=2, input_size=(3, 32, 32)):
    """
    Path norm (Neyshabur 2015).

    Computes an l_p path norm of ``model``. Every trainable parameter of a
    deep copy of the model is replaced by ``|param| ** p``. An all-ones
    tensor is then fed through the transformed copy, and the ``p``-th root
    of the summed output is returned.

    NOTE(review): the forward pass equals the sum over input-output paths of
    the product of ``|w| ** p`` only when the network's ops pass
    non-negative values through linearly (e.g. linear/conv layers with
    ReLU). Parameters with ``requires_grad=False`` and buffers (e.g.
    BatchNorm running stats) are left as-is. Confirm this matches the
    architectures this metric is used on.

    Args:
        model: ``torch.nn.Module`` to measure. It is deep-copied, so the
            caller's model is not modified.
        device: device on which the all-ones input tensor is created.
            Presumably the same device the model lives on.
        p: order of the norm.
        input_size: exact shape of the all-ones input passed to the model.
            No batch dimension is added, so include one here if the model
            requires it.

    Returns:
        float: the l_p path norm.
    """
    tmp_model = copy.deepcopy(model)
    tmp_model.eval()
    # In-place ops on leaf parameters that require grad raise a RuntimeError
    # unless autograd is disabled. The forward pass needs no graph either.
    with torch.no_grad():
        for param in tmp_model.parameters():
            if param.requires_grad:
                param.abs_().pow_(p)
        data_ones = torch.ones(input_size, device=device)
        return (tmp_model(data_ones).sum() ** (1 / p)).item()