in candle-core/src/metal_backend/mod.rs [266:402]
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
let mut dst_el: usize = 1;
for (dim_idx, &d) in src_dims.iter().enumerate() {
if !sum_dims.contains(&dim_idx) {
dst_el *= d;
dims.push(d);
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
let reduction_shape = Shape::from(dims.clone());
if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
(k, dtype) => {
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
}
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
src_dims,
dst_el,
src,
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, dst_el, dtype));
}
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
&dims,
&stride,
dst_el,
src,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, device, dst_el, dtype))
}