in candle-nn/src/ops.rs [138:229]
/// Runs the sigmoid forward pass on Metal-backed storage.
///
/// Allocates an output buffer holding `layout.shape().elem_count()` elements of the
/// input dtype, then encodes exactly one unary sigmoid kernel on the device's
/// command buffer. The kernel flavour is chosen from the layout and dtype:
///
/// * `contiguous_tiled` — contiguous layout, even element count, `F16` or `BF16`;
/// * `contiguous` — every other contiguous layout;
/// * `strided` — non-contiguous layouts (dims and strides are passed to the kernel).
///
/// Returns the new storage together with a clone of the input shape.
///
/// # Errors
///
/// Fails if the dtype is not one of `F16`, `F32` or `BF16`, or if buffer allocation,
/// command-buffer retrieval, or the kernel call reports an error.
fn metal_fwd(
    &self,
    storage: &candle::MetalStorage,
    layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
    use candle::backend::BackendStorage;
    use candle::MetalError;

    let device = storage.device();
    let dtype = storage.dtype();
    let shape = layout.shape();
    let el_count = shape.elem_count();

    let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
    let command_buffer = device.command_buffer()?;
    command_buffer.set_label("sigmoid");

    // The layout's start offset is expressed in elements; the kernels take bytes.
    let src = candle_metal_kernels::BufferOffset {
        buffer: storage.buffer(),
        offset_in_bytes: layout.start_offset() * dtype.size_in_bytes(),
    };

    match (el_count % 2, dtype, layout.is_contiguous()) {
        // Tiled kernel: only selected for half-precision dtypes with an even element
        // count (presumably because the kernel consumes elements in pairs — confirm
        // against the kernel source).
        (0, DType::BF16 | DType::F16, true) => {
            use candle_metal_kernels::unary::contiguous_tiled;
            let kernel_name = match dtype {
                DType::F16 => contiguous_tiled::sigmoid::HALF,
                DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
                // Not reachable given the outer pattern; kept only so the match is
                // exhaustive over `DType`.
                dtype => {
                    candle::bail!(
                        "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
                    )
                }
            };
            candle_metal_kernels::call_unary_contiguous_tiled(
                device.metal_device(),
                &command_buffer,
                device.kernels(),
                kernel_name,
                el_count,
                src,
                &buffer,
            )
            .map_err(MetalError::from)?;
        }
        // Any other contiguous input, including every F32 tensor.
        (_, _, true) => {
            use candle_metal_kernels::unary::contiguous;
            let kernel_name = match dtype {
                DType::F16 => contiguous::sigmoid::HALF,
                DType::F32 => contiguous::sigmoid::FLOAT,
                DType::BF16 => contiguous::sigmoid::BFLOAT,
                dtype => {
                    candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                }
            };
            candle_metal_kernels::call_unary_contiguous(
                device.metal_device(),
                &command_buffer,
                device.kernels(),
                kernel_name,
                el_count,
                src,
                &buffer,
            )
            .map_err(MetalError::from)?;
        }
        // Non-contiguous input: the kernel is given the dims and strides explicitly.
        (_, _, false) => {
            use candle_metal_kernels::unary::strided;
            let kernel_name = match dtype {
                DType::F16 => strided::sigmoid::HALF,
                DType::F32 => strided::sigmoid::FLOAT,
                DType::BF16 => strided::sigmoid::BFLOAT,
                dtype => {
                    candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                }
            };
            // The output is written from the start of the freshly allocated buffer.
            let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
            candle_metal_kernels::call_unary_strided(
                device.metal_device(),
                &command_buffer,
                device.kernels(),
                kernel_name,
                layout.dims(),
                src,
                layout.stride(),
                dst,
            )
            .map_err(MetalError::from)?;
        }
    }

    let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
    Ok((new_storage, shape.clone()))
}