in crypten/common/functions/power.py [0:0]
def norm(self, p="fro", dim=None, keepdim=False):
    """Computes the p-norm of the input tensor (or along a dimension).

    Args:
        self: tensor-like input. Assumes it provides ``abs``, ``square``,
            ``sum``, ``sqrt``, ``max`` and ``pos_pow`` (presumably a
            CrypTensor -- confirm against callers).
        p: ``"fro"`` (treated as the 2-norm), an ``int``/``float`` >= 1
            (``float("inf")`` selects the max-abs norm), or ``"nuc"``
            (not implemented).
        dim: dimension to reduce over; ``None`` reduces over all elements.
        keepdim: forwarded to the reduction when ``dim`` is given.

    Returns:
        The norm, as returned by the underlying reduction calls.

    Raises:
        AssertionError: if numeric ``p`` is not >= 1 (including NaN).
        NotImplementedError: if ``p == "nuc"``.
        ValueError: for any other unsupported ``p``.
    """

    def _sum(x):
        # Full reduction when no dim is given, otherwise reduce along `dim`.
        if dim is None:
            return x.sum()
        return x.sum(dim, keepdim=keepdim)

    if p == "fro":
        p = 2
    if isinstance(p, (int, float)):
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so existing callers see the same exception type.
        if not p >= 1:
            raise AssertionError("p-norm requires p >= 1")
        if p == 1:
            return _sum(self.abs())
        if p == 2:
            # square() avoids an abs() call since the result is non-negative.
            return _sum(self.square()).sqrt()
        if p == float("inf"):
            if dim is None:
                return self.abs().max()
            # max along a dim returns (values, indices); keep only values.
            return self.abs().max(dim=dim, keepdim=keepdim)[0]
        # General case: (sum |x|^p)^(1/p); pos_pow is used since |x| >= 0.
        return _sum(self.abs().pos_pow(p)).pos_pow(1 / p)
    if p == "nuc":
        raise NotImplementedError("Nuclear norm is not implemented")
    raise ValueError(f"Improper value p ({p}) for p-norm")