def sharded_grad_norm()

in sparse_autoencoder/train.py [0:0]


import torch


def sharded_grad_norm(autoencoder, comms, exclude=None):
    """Return the global L2 norm of the gradients of a sharded autoencoder.

    Squared gradient norms of sharded parameters are summed across operation
    shards via `comms`; `pre_bias` is replicated on every shard, so its
    contribution is counted only once.
    """
    if exclude is None:
        exclude = []
    total_sq_norm = torch.zeros((), device="cuda", dtype=torch.float32)
    exclude = set(exclude)

    total_num_params = 0
    for param in autoencoder.parameters():
        if param in exclude:
            continue
        if param.grad is not None:
            sq_norm = (param.grad.float() ** 2).sum()
            if param is autoencoder.pre_bias:
                total_sq_norm += sq_norm  # pre_bias is the same across all shards
            else:
                # sum the shard-local squared norms across all operation shards
                total_sq_norm += comms.sh_sum(sq_norm)

            # the replicated pre_bias is counted once; every sharded parameter
            # holds numel() elements on each of the n_op_shards shards
            param_shards = comms.n_op_shards if param is not autoencoder.pre_bias else 1
            total_num_params += param.numel() * param_shards  # currently unused

    return total_sq_norm.sqrt()
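
For context, here is a minimal sketch of how the returned norm might drive gradient clipping. The SingleProcessComms stub and clip_grads helper below are hypothetical, not part of the source: the stub models the comms interface assumed above (an sh_sum cross-shard reduction and an n_op_shards count) degenerately for a single shard, the autoencoder is assumed to expose a pre_bias parameter as in train.py, and a CUDA device is assumed, as in the source.

class SingleProcessComms:
    # Hypothetical single-shard stand-in for the comms object used above.
    n_op_shards = 1

    def sh_sum(self, x):
        return x  # with one shard, the cross-shard sum is the value itself


def clip_grads(autoencoder, comms, max_norm=1.0):
    # Illustrative helper (not from the source): scale all gradients so the
    # global norm does not exceed max_norm. Because the norm is reduced across
    # shards, every shard computes and applies the same coefficient.
    norm = sharded_grad_norm(autoencoder, comms)
    clip_coef = (max_norm / (norm + 1e-6)).clamp(max=1.0)
    for param in autoencoder.parameters():
        if param.grad is not None:
            # cast the coefficient to the gradient's dtype before the
            # in-place multiply to avoid mixed-dtype in-place errors
            param.grad.mul_(clip_coef.to(param.grad.dtype))
    return norm

In a real multi-shard run, comms.sh_sum would presumably wrap a collective such as an all-reduce over the operation-shard process group, which is what makes the clipping coefficient identical on every shard.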