in src/lib.rs [211:224]
fn cuda_fwd(
    &self,
    x: &candle::CudaStorage,
    x_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    // Dispatch the CUDA kernel on the input dtype; only the dtypes with a
    // fused kernel (f16, bf16, f32) are supported, anything else bails.
    match x.dtype() {
        DType::F16 => self.fwd::<f16>(x, x_l, None, None),
        DType::BF16 => self.fwd::<bf16>(x, x_l, None, None),
        DType::F32 => self.fwd::<f32>(x, x_l, None, None),
        dt => {
            candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})")
        }
    }
}
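
For context, a minimal sketch of how an op whose `cuda_fwd` looks like the above is typically driven from user code, assuming it implements candle's `CustomOp1` trait. Only `CustomOp1` and `Tensor::apply_op1` are candle API; the wrapper function is hypothetical, and an op that also takes weight/bias tensors would go through `apply_op2`/`apply_op3` instead.

use candle::{CustomOp1, Result, Tensor};

// Hypothetical helper: run any single-tensor custom op on `x`.
// `apply_op1` routes to cpu_fwd / cuda_fwd / metal_fwd based on the device
// backing `x`; on CUDA, the dtype match above then selects the kernel.
fn run_custom_op(x: &Tensor, op: impl CustomOp1 + Send + Sync + 'static) -> Result<Tensor> {
    x.apply_op1(op)
}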