in src/train.py [0:0]
import math

import torch


def clip_grad_norms(param_groups, max_norm=math.inf):
    """Clip the gradient norm of each param group to max_norm.

    Returns a tuple (grad_norms, grad_norms_clipped) of the gradient
    norms before and after clipping, one entry per param group. If
    max_norm <= 0, no clipping is performed but the norms are still
    computed.
    """
    grad_norms = [
        torch.nn.utils.clip_grad_norm_(
            group['params'],
            max_norm if max_norm > 0 else math.inf,  # inf disables clipping but the norm is still computed
            norm_type=2,
        )
        for group in param_groups
    ]
    grad_norms_clipped = [min(g_norm, max_norm) for g_norm in grad_norms] if max_norm > 0 else grad_norms
    return grad_norms, grad_norms_clipped
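
A minimal usage sketch of the call site, assuming a standard PyTorch training step; the names `model`, `optimizer`, `batch`, and `loss` are hypothetical and not taken from src/train.py:

# Hypothetical training-step excerpt showing how clip_grad_norms is called.
# Passing optimizer.param_groups clips each group's gradients in place
# between loss.backward() and optimizer.step().
optimizer.zero_grad()
loss = model(batch).mean()
loss.backward()
grad_norms, grad_norms_clipped = clip_grad_norms(optimizer.param_groups, max_norm=1.0)
optimizer.step()

Because clipping happens before optimizer.step(), the returned pre-clip norms can be logged to monitor how often (and by how much) gradients exceed max_norm.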