fn conv_transpose1d()

in candle-core/src/metal_backend/mod.rs [1004:1113]


    /// Runs a 1D transposed convolution of `self` (read through `layout`) with the
    /// kernel `k` (read through `k_layout`) and returns a freshly allocated storage
    /// holding `params.c_out * params.l_out() * params.b_size` elements of `self.dtype`.
    ///
    /// Two strategies are available:
    /// - a matmul followed by the `col2im1d` kernel, used when the kernel layout is
    ///   contiguous, `dilation == 1`, `padding == 0`, `output_padding == 0` and a
    ///   `col2im1d_*` kernel is known for `self.dtype`;
    /// - the generic `conv_transpose1d_*` kernel in every other case.
    ///
    /// # Errors
    ///
    /// Fails when the input and kernel disagree on `c_in` (fast path), when the dtype
    /// has no generic `conv_transpose1d_*` kernel, or when buffer allocation, the
    /// matmul, or a kernel dispatch reports an error.
    fn conv_transpose1d(
        &self,
        layout: &Layout,
        k: &Self,
        k_layout: &Layout,
        params: &ParamsConvTranspose1D,
    ) -> Result<Self> {
        const USE_COL2IM_CONV1D_TR: bool = true;

        // Only these dtypes have a col2im1d kernel name. Any other dtype must go through
        // the generic kernel below (which also covers f16/bf16) rather than failing on
        // the fast path.
        let col2im_name = match self.dtype {
            DType::F32 => Some("col2im1d_f32"),
            DType::U32 => Some("col2im1d_u32"),
            DType::U8 => Some("col2im1d_u8"),
            _ => None,
        };
        let can_use_col2im = k_layout.is_contiguous()
            && params.dilation == 1
            && params.padding == 0
            && params.output_padding == 0;
        let l_out = params.l_out();
        let dst_el = params.c_out * l_out * params.b_size;

        let buffer = match col2im_name {
            Some(name) if USE_COL2IM_CONV1D_TR && can_use_col2im => {
                let (b_size, c_in, l_in) = layout.shape().dims3()?;
                let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
                if c_in != c_in2 {
                    crate::bail!(
                        "convtr1d: shape mismatch on c_in {:?} {:?}",
                        layout.shape(),
                        k_layout.shape()
                    )
                }
                let buffer = self
                    .device
                    .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;

                let col = {
                    // This merges the last two dimensions of the kernel together. The
                    // zero stride on the first dimension reuses the same kernel data for
                    // every batch entry.
                    let kernel_l_mm = Layout::new(
                        (b_size, c_in, k_size * c_out).into(),
                        vec![0, k_size * c_out, 1],
                        k_layout.start_offset(),
                    );
                    self.matmul(
                        k,
                        (b_size, l_in, c_out * k_size, c_in),
                        &layout.transpose(1, 2)?,
                        &kernel_l_mm,
                    )?
                };
                // It is important for the command buffer to be obtained *after* the matmul
                // kernel has run, otherwise we might use a command-buffer that has been
                // committed already resulting in the following error.
                // _status < MTLCommandBufferStatusCommitted >
                // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
                let command_buffer = self.device.command_buffer()?;
                candle_metal_kernels::call_col2im1d(
                    &self.device.device,
                    &command_buffer,
                    &self.device.kernels,
                    name,
                    &[b_size, l_in, c_out, k_size],
                    params.k_size,
                    params.stride,
                    BufferOffset::zero_offset(&col.buffer),
                    &buffer,
                )
                .map_err(MetalError::from)?;
                buffer
            }
            _ => {
                let buffer = self
                    .device
                    .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;

                let command_buffer = self.device.command_buffer()?;
                let name = match self.dtype {
                    DType::F32 => "conv_transpose1d_f32",
                    DType::F16 => "conv_transpose1d_f16",
                    DType::BF16 => "conv_transpose1d_bf16",
                    DType::U32 => "conv_transpose1d_u32",
                    DType::U8 => "conv_transpose1d_u8",
                    dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
                };
                // Start offsets are expressed in elements on the layouts; the kernel call
                // is given byte offsets, hence the multiplication by the element size.
                candle_metal_kernels::call_conv_transpose1d(
                    &self.device.device,
                    &command_buffer,
                    &self.device.kernels,
                    name,
                    params.dilation,
                    params.stride,
                    params.padding,
                    params.output_padding,
                    params.c_out,
                    l_out,
                    params.b_size,
                    layout.dims(),
                    layout.stride(),
                    k_layout.dims(),
                    k_layout.stride(),
                    &self.buffer,
                    layout.start_offset() * self.dtype.size_in_bytes(),
                    &k.buffer,
                    k_layout.start_offset() * k.dtype.size_in_bytes(),
                    &buffer,
                )
                .map_err(MetalError::from)?;
                buffer
            }
        };
        Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
    }