in candle-core/src/metal_backend/mod.rs [1115:1195]
/// 2D convolution on the Metal backend, implemented as im2col followed by a batched matmul.
///
/// Steps:
/// 1. Unfold the input (`self`, described by `layout`) into a column buffer using the
///    dtype-specific `im2col_*` Metal kernel. Each output position gets one row of
///    `c_in * k_h * k_w` values.
/// 2. Multiply the column buffer, viewed as `(b, h_out * w_out, k)`, by the transposed
///    kernel `(k, c_out)`, broadcast over the batch dimension.
/// 3. Reorder the `(b, h_out, w_out, c_out)` product into a contiguous
///    `(b, c_out, h_out, w_out)` result.
///
/// NOTE(review): assumes `layout` is 4D `(b, c_in, h, w)` and the kernel holds
/// `c_out * c_in * k_h * k_w` elements in that order — confirm against callers.
///
/// # Errors
/// Fails for dtypes without an im2col kernel (anything other than F32/F16/BF16/U8/U32).
/// Also propagates buffer-allocation, kernel-launch, matmul and copy errors.
fn conv2d(
    &self,
    layout: &Layout,
    kernel: &Self,
    kernel_l: &Layout,
    params: &ParamsConv2D,
) -> Result<Self> {
    let device = self.device().clone();
    let dims = layout.shape().dims();
    let stride = params.stride;
    let dilation = params.dilation;
    let padding = params.padding;
    let h_k = params.k_h;
    let w_k = params.k_w;
    let h = dims[2];
    let w = dims[3];
    // Output spatial size. The column buffer allocated below and the matmul `m` must
    // agree on it, so it is computed exactly once.
    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
    let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
    let dst = self
        .device
        .new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
    let command_buffer = self.device.command_buffer()?;
    let name = match self.dtype {
        DType::F32 => "im2col_f32",
        DType::F16 => "im2col_f16",
        DType::BF16 => "im2col_bf16",
        DType::U8 => "im2col_u8",
        DType::U32 => "im2col_u32",
        dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
    };
    let src = buffer_o(&self.buffer, layout, self.dtype);
    candle_metal_kernels::call_im2col_strided(
        &self.device.device,
        &command_buffer,
        &self.device.kernels,
        name,
        dims,
        layout.stride(),
        (h_k, w_k, stride, padding, dilation),
        src,
        &dst,
    )
    .map_err(MetalError::from)?;
    let col = Self {
        buffer: dst,
        device,
        count: dst_el,
        dtype: self.dtype,
    };
    let b = params.b_size;
    let n = params.c_out;
    let k = params.k_h * params.k_w * params.c_in;
    let m = h_out * w_out;
    let col_l = Layout::contiguous((b, m, k));
    let res = if kernel_l.is_contiguous() {
        // The kernel storage can be read in place, starting at its own offset.
        let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
            .transpose(1, 2)?
            .broadcast_as((b, k, n))?;
        col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
    } else {
        // Copy the strided kernel into a fresh contiguous buffer and multiply against that
        // copy. The original storage's memory order does not match the contiguous layout
        // described below. The copy is written at offset 0, hence the plain contiguous layout.
        let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
        kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
        let kernel_l = Layout::contiguous((1, n, k))
            .transpose(1, 2)?
            .broadcast_as((b, k, n))?;
        col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
    };
    // The matmul result is (b, h_out, w_out, n) in memory. Viewing it as (b, n, h_out, w_out)
    // and copying produces a contiguous NCHW output.
    let res_l = Layout::contiguous((b, h_out, w_out, n))
        .transpose(1, 2)?
        .transpose(1, 3)?;
    let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
    res.copy_strided_src(&mut res_t, 0, &res_l)?;
    Ok(res_t)
}