fn cuda_fwd()

in candle-layer-norm/src/lib.rs [242:257]


    // CUDA forward pass of the fused layer-norm custom op with a residual input.
    // Dispatches on the dtype of `x` to the templated kernel launcher `fwd`,
    // forwarding the residual storage and layout alongside the input.
    fn cuda_fwd(
        &self,
        x: &candle::CudaStorage,
        x_l: &Layout,
        r: &candle::CudaStorage,
        r_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        match x.dtype() {
            DType::F16 => self.fwd::<f16>(x, x_l, Some(r), Some(r_l)),
            DType::BF16 => self.fwd::<bf16>(x, x_l, Some(r), Some(r_l)),
            DType::F32 => self.fwd::<f32>(x, x_l, Some(r), Some(r_l)),
            dt => {
                candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})")
            }
        }
    }
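
For context, a minimal, hedged sketch of how a `CustomOp2` like this one is typically reached from the tensor API: `Tensor::apply_op2` dispatches to the op's `cuda_fwd` when both inputs live on a CUDA device. The wrapper name `fused_layer_norm_with_residual` below is a hypothetical helper for illustration only; the crate's actual public entry points may differ.

    use candle::{CustomOp2, Result, Tensor};

    // Hypothetical wrapper for illustration: `op` is any CustomOp2 implementation,
    // such as the fused layer-norm op whose cuda_fwd is shown above.
    // For CUDA tensors, apply_op2 routes to the op's cuda_fwd with the storages
    // and layouts of `x` and `r`.
    fn fused_layer_norm_with_residual<C>(x: &Tensor, r: &Tensor, op: C) -> Result<Tensor>
    where
        C: CustomOp2 + Send + Sync + 'static,
    {
        x.apply_op2(r, op)
    }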