def grad_norm()

in jukebox/utils/fp16.py


import torch
from torch._utils import _flatten_dense_tensors


def grad_norm(params, scale, flat=False):
    """Global L2 norm of all gradients in `params`, divided by the loss
    scale `scale` to undo loss scaling."""
    params = list(params)
    if flat:
        # Faster, but allocates flattened copies of the grads: take separate
        # norms over the fp16 and fp32 grads, then combine them.
        fp16_grads = [p.grad for p in params if p.grad is not None and p.data.dtype == torch.float16]
        fp16_norm = 0.0 if len(fp16_grads) == 0 else float(_flatten_dense_tensors(fp16_grads).norm(p=2, dtype=torch.float32))
        fp32_grads = [p.grad for p in params if p.grad is not None and p.data.dtype != torch.float16]
        fp32_norm = 0.0 if len(fp32_grads) == 0 else float(_flatten_dense_tensors(fp32_grads).norm(p=2))
        grad_norm = (fp16_norm**2 + fp32_norm**2)**0.5
    else:
        # Slightly slower but uses less memory: accumulate squared per-tensor
        # norms (computed in float32) and take the square root at the end.
        grad_norm = 0.0
        for p in params:
            if p.grad is not None:
                grad_norm += p.grad.norm(p=2, dtype=torch.float32)**2
        grad_norm = float(grad_norm**0.5)
    return grad_norm / scale
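
A minimal usage sketch (not from the source): it shows that the returned value is the unscaled global gradient norm during loss-scaled mixed-precision training, and that the flat and non-flat paths agree. The toy layers, the stand-in loss, and the `loss_scale` value below are hypothetical.

import torch
import torch.nn as nn

# Hypothetical toy setup with both fp16 and fp32 parameters,
# as in mixed-precision training.
fp16_layer = nn.Linear(8, 8).half()
fp32_layer = nn.Linear(8, 1)
params = list(fp16_layer.parameters()) + list(fp32_layer.parameters())

loss_scale = 1024.0
# Stand-in loss over the parameters themselves, just to produce gradients.
loss = sum(p.float().pow(2).sum() for p in params)
(loss * loss_scale).backward()  # gradients carry the loss scale

# grad_norm divides by `scale`, so both calls return the true (unscaled)
# global L2 norm; flat=True trades extra memory for speed.
print(grad_norm(params, loss_scale, flat=True))
print(grad_norm(params, loss_scale, flat=False))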