in candle-core/src/metal_backend/mod.rs [510:625]
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
let command_buffer = device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::U8, DType::F16) => "cast_u8_f16",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::I64) => "cast_f32_i64",
(DType::F32, DType::U32) => "cast_f32_u32",
(DType::F32, DType::U8) => "cast_f32_u8",
(DType::I64, DType::BF16) => "cast_i64_bf16",
(DType::I64, DType::F16) => "cast_i64_f16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::I64, DType::U32) => "cast_i64_u32",
(DType::I64, DType::U8) => "cast_i64_u8",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::F16, DType::I64) => "cast_f16_i64",
(DType::F16, DType::U32) => "cast_f16_u32",
(DType::F16, DType::U8) => "cast_f16_u8",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
(DType::BF16, DType::I64) => "cast_bf16_i64",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
};
candle_metal_kernels::call_cast_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let kernel_name = match (self.dtype, dtype) {
(DType::BF16, DType::F16) => "cast_bf16_f16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(DType::BF16, DType::I64) => "cast_bf16_i64_strided",
(DType::BF16, DType::U32) => "cast_bf16_u32_strided",
(DType::BF16, DType::U8) => "cast_bf16_u8_strided",
(DType::F16, DType::BF16) => "cast_f16_bf16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::F16, DType::I64) => "cast_f16_i64_strided",
(DType::F16, DType::U32) => "cast_f16_u32_strided",
(DType::F16, DType::U8) => "cast_f16_u8_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F32, DType::I64) => "cast_f32_i64_strided",
(DType::F32, DType::U32) => "cast_f32_u32_strided",
(DType::F32, DType::U8) => "cast_f32_u8_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::I64, DType::BF16) => "cast_i64_bf16_strided",
(DType::I64, DType::F16) => "cast_i64_f16_strided",
(DType::I64, DType::U32) => "cast_i64_u32_strided",
(DType::I64, DType::U8) => "cast_i64_u8_strided",
(DType::U32, DType::BF16) => "cast_u32_bf16_strided",
(DType::U32, DType::F16) => "cast_u32_f16_strided",
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::U32, DType::I64) => "cast_u32_i64_strided",
(DType::U32, DType::U8) => "cast_u32_u8_strided",
(DType::U8, DType::BF16) => "cast_u8_bf16_strided",
(DType::U8, DType::F16) => "cast_u8_f16_strided",
(DType::U8, DType::F32) => "cast_u8_f32_strided",
(DType::U8, DType::I64) => "cast_u8_i64_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(left, right) => {
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
}
};
candle_metal_kernels::call_cast_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
src,
layout.stride(),
&buffer,
)
.map_err(MetalError::from)?;
}
command_buffer.set_label("to_dtype");
Ok(Self::new(buffer, device.clone(), el_count, dtype))
}