fn const_set()

in candle-core/src/metal_backend/mod.rs [416:508]


    /// Fills every element addressed by `l` with the scalar `s`.
    ///
    /// The scalar variant must match the storage dtype; a mismatch is reported as an
    /// error (no conversion is attempted). The Metal kernel family is chosen from the
    /// layout and dtype:
    /// - `contiguous_tiled` for contiguous f16/bf16 layouts with an even element count,
    /// - `contiguous` for any other contiguous layout,
    /// - `strided` for non-contiguous layouts (dims and strides are handed to the kernel).
    ///
    /// f64 has no const-set kernel on this backend and returns an error for non-empty
    /// layouts. An empty layout is a no-op and returns `Ok(())` without encoding anything.
    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
        use crate::scalar::Scalar;
        // Dtype-monomorphic worker; the caller below guarantees `s` matches `self_.dtype`.
        fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
            self_: &mut MetalStorage,
            s: S,
            l: &Layout,
        ) -> Result<()> {
            let device = self_.device();
            let dtype = self_.dtype;
            let shape = l.shape();
            let el_count = shape.elem_count();
            // Nothing to write: return before acquiring a command buffer. The kernel
            // launchers below size their dispatch from the element count / dims, and a
            // zero-sized dispatch is not handled there.
            // NOTE(review): confirm against `candle_metal_kernels::utils::linear_split`.
            if el_count == 0 {
                return Ok(());
            }
            let command_buffer = device.command_buffer()?;
            command_buffer.set_label("const-set");
            // Presumably `buffer_o` pairs the buffer with the layout's start offset.
            let dst = buffer_o(&self_.buffer, l, dtype);

            match (el_count % 2, dtype, l.is_contiguous()) {
                // Half-precision contiguous data with an even length can use the tiled
                // kernel, which processes elements in tiles.
                (0, DType::BF16 | DType::F16, true) => {
                    use candle_metal_kernels::unary::contiguous_tiled;
                    let kernel_name = match dtype {
                        DType::F16 => contiguous_tiled::const_set::HALF,
                        DType::BF16 => contiguous_tiled::const_set::BFLOAT,
                        _ => crate::bail!("internal bug in const_set"),
                    };
                    candle_metal_kernels::call_const_set_contiguous_tiled(
                        &device.device,
                        &command_buffer,
                        &device.kernels,
                        kernel_name,
                        el_count,
                        s,
                        dst,
                    )
                    .map_err(MetalError::from)?;
                }
                // Any other contiguous layout: one kernel invocation over `el_count` elements.
                (_, _, true) => {
                    use candle_metal_kernels::unary::contiguous;
                    let kernel_name = match dtype {
                        DType::F16 => contiguous::const_set::HALF,
                        DType::BF16 => contiguous::const_set::BFLOAT,
                        DType::F32 => contiguous::const_set::FLOAT,
                        DType::I64 => contiguous::const_set::I64,
                        DType::U32 => contiguous::const_set::U32,
                        DType::U8 => contiguous::const_set::U8,
                        DType::F64 => crate::bail!("unsupported const-set f64"),
                    };
                    candle_metal_kernels::call_const_set_contiguous(
                        &device.device,
                        &command_buffer,
                        &device.kernels,
                        kernel_name,
                        el_count,
                        s,
                        dst,
                    )
                    .map_err(MetalError::from)?;
                }
                // Non-contiguous layout: the kernel walks dims/strides itself.
                (_, _, false) => {
                    use candle_metal_kernels::unary::strided;
                    let kernel_name = match dtype {
                        DType::F16 => strided::const_set::HALF,
                        DType::BF16 => strided::const_set::BFLOAT,
                        DType::F32 => strided::const_set::FLOAT,
                        DType::I64 => strided::const_set::I64,
                        DType::U32 => strided::const_set::U32,
                        DType::U8 => strided::const_set::U8,
                        DType::F64 => crate::bail!("unsupported const-set f64"),
                    };
                    candle_metal_kernels::call_const_set_strided(
                        &device.device,
                        &command_buffer,
                        &device.kernels,
                        kernel_name,
                        l.dims(),
                        s,
                        l.stride(),
                        dst,
                    )
                    .map_err(MetalError::from)?;
                }
            }
            Ok(())
        }
        // Unwrap the scalar only when its variant matches the storage dtype.
        match (self.dtype, s) {
            (DType::U8, Scalar::U8(s)) => set(self, s, l),
            (DType::U32, Scalar::U32(s)) => set(self, s, l),
            (DType::I64, Scalar::I64(s)) => set(self, s, l),
            (DType::F16, Scalar::F16(s)) => set(self, s, l),
            (DType::BF16, Scalar::BF16(s)) => set(self, s, l),
            (DType::F32, Scalar::F32(s)) => set(self, s, l),
            (DType::F64, Scalar::F64(s)) => set(self, s, l),
            _ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
        }
    }