def unfused_rms_norm()

in muse/modeling_transformer_v2.py [0:0]


def unfused_rms_norm(input, residual, weight, eps):
    """Apply RMS normalization over the last dimension using plain torch ops.

    Args:
        input: Tensor to normalize.
        residual: Optional tensor added to ``input`` before normalizing;
            skipped when ``None``.
        weight: Optional scale multiplied into the normalized tensor;
            skipped when ``None``.
        eps: Value added to the mean of squares before the reciprocal
            square root is taken.

    Returns:
        A ``(normalized, prenorm_residual)`` tuple. ``prenorm_residual`` is
        ``input`` after the optional residual addition and before any
        normalization (it is ``input`` itself when ``residual`` is ``None``).
    """
    # Fold in the residual first; this sum is also handed back to the caller.
    summed = input if residual is None else input + residual
    original_dtype = summed.dtype

    # The mean of squares is always computed on a float32 copy.
    mean_square = summed.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normalized = summed * torch.rsqrt(mean_square + eps)

    if weight is None:
        # No scale to apply: hand the result back in the pre-norm dtype.
        return normalized.to(original_dtype), summed

    # Only a half-precision weight forces a cast before scaling; for any
    # other weight dtype the product is left in whatever dtype the
    # multiplication yields and is not cast back to the input dtype.
    if weight.dtype in (torch.float16, torch.bfloat16):
        normalized = normalized.to(weight.dtype)

    return normalized * weight, summed