in muse/modeling_transformer_v2.py [0:0]
def unfused_rms_norm(input, residual, weight, eps):
    """RMS-normalize ``input`` over its last axis in plain PyTorch ops.

    If ``residual`` is given it is added to ``input`` first; the summed
    tensor is returned alongside the result so callers can reuse it as
    the next residual stream.

    Args:
        input: tensor to normalize (normalization is over dim -1).
        residual: optional tensor added to ``input`` before normalizing.
        weight: optional scale multiplied into the normalized output;
            presumably shaped like the last axis of ``input`` — TODO confirm.
        eps: small constant added to the mean of squares before ``rsqrt``.

    Returns:
        A tuple ``(normalized, prenorm_residual)`` where
        ``prenorm_residual`` is ``input`` (plus ``residual`` if provided)
        before normalization.
    """
    hidden = input if residual is None else input + residual
    prenorm_residual = hidden

    original_dtype = hidden.dtype
    # The mean of squares is always accumulated in float32.
    mean_square = hidden.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden * torch.rsqrt(mean_square + eps)

    if weight is None:
        # No scale: restore the caller's original dtype.
        return normed.to(original_dtype), prenorm_residual

    # Only a half-precision weight (fp16/bf16) triggers a downcast; otherwise
    # the result keeps whatever dtype the multiplication promoted to.
    if weight.dtype in (torch.float16, torch.bfloat16):
        normed = normed.to(weight.dtype)
    return normed * weight, prenorm_residual