in jukebox/utils/fp16.py [0:0]
import torch
from torch._utils import _flatten_dense_tensors

def grad_norm(params, scale, flat=False):
    # Global L2 norm of all gradients, divided by the loss scale to undo scaling.
    params = list(params)
    if flat:
        # Faster but more memory: flatten grads into one buffer per dtype.
        fp16_grads = [p.grad for p in params if p.grad is not None and p.data.dtype == torch.float16]
        fp16_norm = 0.0 if len(fp16_grads) == 0 else float(_flatten_dense_tensors(fp16_grads).norm(p=2, dtype=torch.float32))
        fp32_grads = [p.grad for p in params if p.grad is not None and p.data.dtype != torch.float16]
        fp32_norm = 0.0 if len(fp32_grads) == 0 else float(_flatten_dense_tensors(fp32_grads).norm(p=2))
        grad_norm = (fp16_norm**2 + fp32_norm**2)**0.5
    else:
        # Slightly slower but less memory: accumulate squared norms per parameter.
        grad_norm = 0.0
        for p in params:
            if p.grad is not None:
                grad_norm += p.grad.norm(p=2, dtype=torch.float32)**2
        grad_norm = float(grad_norm**0.5)
    return grad_norm / scale
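A minimal sketch of how a caller might use grad_norm in a loss-scaled fp16 training step: check the unscaled gradient norm for overflow, then unscale (and optionally clip) the gradients before stepping. The names model, opt, loss, scale, and max_grad_norm are illustrative assumptions, not part of this file.

# Illustrative only: a loss-scaled fp16 step built around grad_norm above.
# `model`, `opt`, `loss`, `scale`, `max_grad_norm` are assumed names.
import math

def scaled_step(model, opt, loss, scale, max_grad_norm):
    (loss * scale).backward()                  # grads now hold scale * true_grad
    gn = grad_norm(model.parameters(), scale)  # true (unscaled) global grad norm
    if not math.isfinite(gn):
        opt.zero_grad()                        # fp16 overflow: skip this update
        return gn
    mult = 1.0 / scale                         # undo the loss scale
    if gn > max_grad_norm:
        mult *= max_grad_norm / gn             # clip the global norm
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(mult)
    opt.step()
    opt.zero_grad()
    return gn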