in sparse_autoencoder/train.py
import torch


def sharded_grad_norm(autoencoder, comms, exclude=None):
    """Compute the global L2 gradient norm of an op-sharded autoencoder.

    Sharded parameters have their squared grad norms summed across shards via
    comms.sh_sum; pre_bias is replicated on every shard, so it is counted once.
    """
    if exclude is None:
        exclude = []
    total_sq_norm = torch.zeros((), device="cuda", dtype=torch.float32)
    exclude = set(exclude)

    total_num_params = 0
    for param in autoencoder.parameters():
        if param in exclude:
            continue
        if param.grad is not None:
            sq_norm = (param.grad.float() ** 2).sum()
            if param is autoencoder.pre_bias:
                total_sq_norm += sq_norm  # pre_bias is the same across all shards
            else:
                total_sq_norm += comms.sh_sum(sq_norm)

            # pre_bias is replicated, so it counts once; every other parameter
            # is sharded, so the full count is numel() * n_op_shards.
            param_shards = 1 if param is autoencoder.pre_bias else comms.n_op_shards
            total_num_params += param.numel() * param_shards

    return total_sq_norm.sqrt()
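
For context, here is a minimal sketch of the comms object this function assumes: sh_sum all-reduces a scalar across the op-shard process group, and n_op_shards is that group's size. The Comms class and the clip_grads helper below are hypothetical illustrations of how sharded_grad_norm might plug into gradient clipping, not code from this repository.

import torch
import torch.distributed as dist


class Comms:
    """Hypothetical wrapper over an op-shard process group (an assumption,
    not the repository's actual Comms implementation)."""

    def __init__(self, op_shard_group):
        self.group = op_shard_group
        self.n_op_shards = dist.get_world_size(group=op_shard_group)

    def sh_sum(self, scalar: torch.Tensor) -> torch.Tensor:
        # Sum a scalar tensor across all op-shards, in place, and return it.
        dist.all_reduce(scalar, op=dist.ReduceOp.SUM, group=self.group)
        return scalar


def clip_grads(autoencoder, comms, max_norm=1.0):
    # Hypothetical usage: scale each local gradient so the *global* sharded
    # norm is at most max_norm. Every rank computes the same grad_norm, so
    # every rank applies the same clip coefficient.
    grad_norm = sharded_grad_norm(autoencoder, comms)
    clip_coef = max_norm / (grad_norm + 1e-6)
    if clip_coef < 1.0:
        for param in autoencoder.parameters():
            if param.grad is not None:
                param.grad.mul_(clip_coef)
    return grad_norm

Because sh_sum makes total_sq_norm identical on every shard, the clipping decision is consistent across ranks without any further synchronization.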