fn unary_impl()

in candle-core/src/metal_backend/mod.rs [627:869]


    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        let device = self.device();
        let dtype = self.dtype;
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
        let command_buffer = device.command_buffer()?;
        command_buffer.set_label(B::KERNEL);
        let src = buffer_o(&self.buffer, layout, self.dtype);

        match (el_count % 2, dtype, layout.is_contiguous()) {
            (0, DType::BF16 | DType::F16, true) => {
                use candle_metal_kernels::unary::contiguous_tiled;
                let kernel_name = match (B::KERNEL, dtype) {
                    ("uabs", DType::F16) => contiguous_tiled::abs::HALF,
                    ("uabs", DType::F32) => contiguous_tiled::abs::FLOAT,
                    ("uabs", DType::BF16) => contiguous_tiled::abs::BFLOAT,
                    ("uceil", DType::F16) => contiguous_tiled::ceil::HALF,
                    ("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT,
                    ("uceil", DType::BF16) => contiguous_tiled::ceil::BFLOAT,
                    ("ucos", DType::F16) => contiguous_tiled::cos::HALF,
                    ("ucos", DType::F32) => contiguous_tiled::cos::FLOAT,
                    ("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT,
                    ("uerf", DType::F16) => contiguous_tiled::erf::HALF,
                    ("uerf", DType::F32) => contiguous_tiled::erf::FLOAT,
                    ("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT,
                    ("uexp", DType::F16) => contiguous_tiled::exp::HALF,
                    ("uexp", DType::F32) => contiguous_tiled::exp::FLOAT,
                    ("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT,
                    ("ufloor", DType::F16) => contiguous_tiled::floor::HALF,
                    ("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT,
                    ("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT,
                    ("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF,
                    ("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT,
                    ("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT,
                    ("ugelu", DType::F16) => contiguous_tiled::gelu::HALF,
                    ("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT,
                    ("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT,
                    ("ulog", DType::F16) => contiguous_tiled::log::HALF,
                    ("ulog", DType::F32) => contiguous_tiled::log::FLOAT,
                    ("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT,
                    ("uneg", DType::F16) => contiguous_tiled::neg::HALF,
                    ("uneg", DType::F32) => contiguous_tiled::neg::FLOAT,
                    ("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT,
                    ("urecip", DType::F16) => contiguous_tiled::recip::HALF,
                    ("urecip", DType::F32) => contiguous_tiled::recip::FLOAT,
                    ("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT,
                    ("urelu", DType::F16) => contiguous_tiled::relu::HALF,
                    ("urelu", DType::F32) => contiguous_tiled::relu::FLOAT,
                    ("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT,
                    ("uround", DType::F16) => contiguous_tiled::round::HALF,
                    ("uround", DType::F32) => contiguous_tiled::round::FLOAT,
                    ("uround", DType::BF16) => contiguous_tiled::round::BFLOAT,
                    ("usilu", DType::F16) => contiguous_tiled::silu::HALF,
                    ("usilu", DType::F32) => contiguous_tiled::silu::FLOAT,
                    ("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT,
                    ("usin", DType::F16) => contiguous_tiled::sin::HALF,
                    ("usin", DType::F32) => contiguous_tiled::sin::FLOAT,
                    ("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT,
                    ("usqr", DType::F16) => contiguous_tiled::sqr::HALF,
                    ("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT,
                    ("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT,
                    ("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF,
                    ("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT,
                    ("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT,
                    ("utanh", DType::F16) => contiguous_tiled::tanh::HALF,
                    ("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT,
                    ("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT,
                    ("usign", DType::F16) => contiguous_tiled::sign::HALF,
                    ("usign", DType::F32) => contiguous_tiled::sign::FLOAT,
                    ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT,
                    ("usign", DType::I64) => contiguous_tiled::sign::I64,
                    (name, dtype) => {
                        crate::bail!(
                            "Metal contiguous_tiled unary {name} {dtype:?} not implemented"
                        )
                    }
                };
                candle_metal_kernels::call_unary_contiguous_tiled(
                    &device.device,
                    &command_buffer,
                    &device.kernels,
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, true) => {
                use candle_metal_kernels::unary::contiguous;
                let kernel_name = match (B::KERNEL, dtype) {
                    ("uabs", DType::F16) => contiguous::abs::HALF,
                    ("uabs", DType::F32) => contiguous::abs::FLOAT,
                    ("uabs", DType::BF16) => contiguous::abs::BFLOAT,
                    ("uceil", DType::F16) => contiguous::ceil::HALF,
                    ("uceil", DType::F32) => contiguous::ceil::FLOAT,
                    ("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
                    ("ucos", DType::F16) => contiguous::cos::HALF,
                    ("ucos", DType::F32) => contiguous::cos::FLOAT,
                    ("ucos", DType::BF16) => contiguous::cos::BFLOAT,
                    ("uerf", DType::F16) => contiguous::erf::HALF,
                    ("uerf", DType::F32) => contiguous::erf::FLOAT,
                    ("uerf", DType::BF16) => contiguous::erf::BFLOAT,
                    ("uexp", DType::F16) => contiguous::exp::HALF,
                    ("uexp", DType::F32) => contiguous::exp::FLOAT,
                    ("uexp", DType::BF16) => contiguous::exp::BFLOAT,
                    ("ufloor", DType::F16) => contiguous::floor::HALF,
                    ("ufloor", DType::F32) => contiguous::floor::FLOAT,
                    ("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
                    ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
                    ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
                    ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
                    ("ugelu", DType::F16) => contiguous::gelu::HALF,
                    ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
                    ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
                    ("ulog", DType::F16) => contiguous::log::HALF,
                    ("ulog", DType::F32) => contiguous::log::FLOAT,
                    ("ulog", DType::BF16) => contiguous::log::BFLOAT,
                    ("uneg", DType::F16) => contiguous::neg::HALF,
                    ("uneg", DType::F32) => contiguous::neg::FLOAT,
                    ("uneg", DType::BF16) => contiguous::neg::BFLOAT,
                    ("urecip", DType::F16) => contiguous::recip::HALF,
                    ("urecip", DType::F32) => contiguous::recip::FLOAT,
                    ("urecip", DType::BF16) => contiguous::recip::BFLOAT,
                    ("urelu", DType::F16) => contiguous::relu::HALF,
                    ("urelu", DType::F32) => contiguous::relu::FLOAT,
                    ("urelu", DType::BF16) => contiguous::relu::BFLOAT,
                    ("uround", DType::F16) => contiguous::round::HALF,
                    ("uround", DType::F32) => contiguous::round::FLOAT,
                    ("uround", DType::BF16) => contiguous::round::BFLOAT,
                    ("usilu", DType::F16) => contiguous::silu::HALF,
                    ("usilu", DType::F32) => contiguous::silu::FLOAT,
                    ("usilu", DType::BF16) => contiguous::silu::BFLOAT,
                    ("usin", DType::F16) => contiguous::sin::HALF,
                    ("usin", DType::F32) => contiguous::sin::FLOAT,
                    ("usin", DType::BF16) => contiguous::sin::BFLOAT,
                    ("usqr", DType::F16) => contiguous::sqr::HALF,
                    ("usqr", DType::F32) => contiguous::sqr::FLOAT,
                    ("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
                    ("usqrt", DType::F16) => contiguous::sqrt::HALF,
                    ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
                    ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
                    ("utanh", DType::F16) => contiguous::tanh::HALF,
                    ("utanh", DType::F32) => contiguous::tanh::FLOAT,
                    ("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
                    ("usign", DType::F16) => contiguous::sign::HALF,
                    ("usign", DType::F32) => contiguous::sign::FLOAT,
                    ("usign", DType::BF16) => contiguous::sign::BFLOAT,
                    ("usign", DType::I64) => contiguous::sign::I64,
                    (name, dtype) => {
                        crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
                    }
                };
                candle_metal_kernels::call_unary_contiguous(
                    &device.device,
                    &command_buffer,
                    &device.kernels,
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, false) => {
                use candle_metal_kernels::unary::strided;
                let kernel_name = match (B::KERNEL, dtype) {
                    ("ucos", DType::F32) => strided::cos::FLOAT,
                    ("usin", DType::F32) => strided::sin::FLOAT,
                    ("usqr", DType::F32) => strided::sqr::FLOAT,
                    ("usqrt", DType::F32) => strided::sqrt::FLOAT,
                    ("uneg", DType::F32) => strided::neg::FLOAT,
                    ("uexp", DType::F32) => strided::exp::FLOAT,
                    ("ulog", DType::F32) => strided::log::FLOAT,
                    ("ugelu", DType::F32) => strided::gelu::FLOAT,
                    ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
                    ("uerf", DType::F32) => strided::erf::FLOAT,
                    ("usilu", DType::F32) => strided::silu::FLOAT,
                    ("uabs", DType::F32) => strided::abs::FLOAT,
                    ("uceil", DType::F32) => strided::ceil::FLOAT,
                    ("ufloor", DType::F32) => strided::floor::FLOAT,
                    ("urelu", DType::F32) => strided::relu::FLOAT,
                    ("uround", DType::F32) => strided::round::FLOAT,
                    ("utanh", DType::F32) => strided::tanh::FLOAT,

                    ("ucos", DType::F16) => strided::cos::HALF,
                    ("usin", DType::F16) => strided::sin::HALF,
                    ("usqr", DType::F16) => strided::sqr::HALF,
                    ("usqrt", DType::F16) => strided::sqrt::HALF,
                    ("uneg", DType::F16) => strided::neg::HALF,
                    ("uexp", DType::F16) => strided::exp::HALF,
                    ("ulog", DType::F16) => strided::log::HALF,
                    ("ugelu", DType::F16) => strided::gelu::HALF,
                    ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
                    ("uerf", DType::F16) => strided::erf::HALF,
                    ("usilu", DType::F16) => strided::silu::HALF,
                    ("uabs", DType::F16) => strided::abs::HALF,
                    ("uceil", DType::F16) => strided::ceil::HALF,
                    ("ufloor", DType::F16) => strided::floor::HALF,
                    ("urelu", DType::F16) => strided::relu::HALF,
                    ("uround", DType::F16) => strided::round::HALF,
                    ("utanh", DType::F16) => strided::tanh::HALF,

                    ("ucos", DType::BF16) => strided::cos::BFLOAT,
                    ("usin", DType::BF16) => strided::sin::BFLOAT,
                    ("usqr", DType::BF16) => strided::sqr::BFLOAT,
                    ("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
                    ("uneg", DType::BF16) => strided::neg::BFLOAT,
                    ("uexp", DType::BF16) => strided::exp::BFLOAT,
                    ("ulog", DType::BF16) => strided::log::BFLOAT,
                    ("ugelu", DType::BF16) => strided::gelu::BFLOAT,
                    ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
                    ("uerf", DType::BF16) => strided::erf::BFLOAT,
                    ("usilu", DType::BF16) => strided::silu::BFLOAT,
                    ("uabs", DType::BF16) => strided::abs::BFLOAT,
                    ("uceil", DType::BF16) => strided::ceil::BFLOAT,
                    ("ufloor", DType::BF16) => strided::floor::BFLOAT,
                    ("urelu", DType::BF16) => strided::relu::BFLOAT,
                    ("uround", DType::BF16) => strided::round::BFLOAT,
                    ("utanh", DType::BF16) => strided::tanh::BFLOAT,

                    (name, dtype) => {
                        crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
                    }
                };
                let dst = BufferOffset::zero_offset(&buffer);
                candle_metal_kernels::call_unary_strided(
                    &device.device,
                    &command_buffer,
                    &device.kernels,
                    kernel_name,
                    layout.dims(),
                    src,
                    layout.stride(),
                    dst,
                )
                .map_err(MetalError::from)?;
            }
        }

        Ok(Self::new(buffer, device.clone(), el_count, dtype))
    }