fn select_kernel()

in crates/ratchet-core/src/op.rs [278:362]


    /// Returns the kernel that `compile_gpu` should use for this operation.
    ///
    /// `compile_gpu` calls this first and asks the returned kernel for its
    /// kernel element, metadata, dispatch workload, storage bind group
    /// layout and kernel key; it also passes the kernel on when requesting
    /// the compute module from the device.
    fn select_kernel(&self) -> Self::KernelEnum;

    /// Builds the [`CompiledOp`] for this operation writing into `dst`.
    ///
    /// The kernel returned by [`Self::select_kernel`] drives every step: its
    /// metadata is written into `uniform`, its dispatch workload is computed
    /// for `dst`, and the bind group layouts, pipeline layout, compute module
    /// and compute pipeline are requested from `device` through its
    /// `get_or_create_*` methods.
    ///
    /// # Arguments
    ///
    /// * `dst` - Output tensor the kernel is specialised for.
    /// * `uniform` - Uniform buffer the kernel metadata is written into; the
    ///   offset returned by that write is stored in the compiled op.
    /// * `device` - Device used to obtain layouts, modules, pipelines and
    ///   bind groups.
    /// * `can_inplace` - Forwarded to the storage bind group layout, the
    ///   kernel key, the compute module and the storage bind groups.
    ///   NOTE(review): presumably signals that the output may reuse a source
    ///   buffer; confirm against the callers.
    /// * `debug` - Only read when the `debug` cargo feature is enabled. When
    ///   `true`, a buffer of `dst.num_bytes()` bytes labelled "debug buffer"
    ///   is created and handed to the compiled op.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while building the kernel metadata,
    /// writing it to `uniform`, calculating the dispatch, creating the
    /// layouts or the pipeline, or creating the storage bind groups.
    ///
    /// # Panics
    ///
    /// Panics if `dst.device().try_gpu()` fails.
    fn compile_gpu(
        &self,
        dst: &Tensor,
        uniform: &mut CpuUniform,
        device: &WgpuDevice,
        can_inplace: bool,
        debug: bool,
    ) -> Result<CompiledOp, OperationError> {
        let kernel = self.select_kernel();

        let kernel_element = kernel.kernel_element(dst);
        let metadata = kernel.metadata(dst, &kernel_element)?;
        let offset = metadata.write(uniform)?;

        let workload = kernel.calculate_dispatch(dst)?;

        let storage_layout = device
            .get_or_create_bind_group_layout(&kernel.storage_bind_group_layout(can_inplace)?)?;

        let uniform_layout =
            device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())?;
        let pipeline_layout = device.get_or_create_pipeline_layout(&PipelineLayoutDescriptor {
            entries: rvec![storage_layout, uniform_layout],
        })?;

        // Collected once and shared by the kernel key and the storage bind
        // groups below instead of being rebuilt for each use.
        let srcs = self.srcs();

        let key = kernel.kernel_key(
            &workload.workgroup_size,
            can_inplace,
            &srcs,
            dst,
            &kernel_element,
        );
        log::debug!("Kernel key: {}", key);

        // `key` is not needed past this point, so it is moved rather than cloned.
        let kernel_src_desc = KernelModuleDesc { key };

        let kernel_module = device.get_or_create_compute_module(
            &kernel_src_desc,
            &kernel,
            can_inplace,
            dst,
            &workload.workgroup_size,
            dst.device()
                .try_gpu()
                .expect("compile_gpu requires `dst` to be on a GPU device"),
        );

        let pipeline_descriptor = ComputePipelineDescriptor {
            pipeline_layout,
            // Cloned because the descriptor's key is moved into the compiled op below.
            kernel_key: kernel_src_desc.key.clone(),
            kernel_module,
        };
        let pipeline_handle = device.get_or_create_compute_pipeline(&pipeline_descriptor)?;

        //TODO: Not sure i like this call here
        let storage_bind_groups = CompiledOp::create_storage_bind_groups(
            &srcs,
            dst,
            rvec![storage_layout],
            device,
            can_inplace,
        )?;

        // The debug buffer only exists in builds with the `debug` feature; it is
        // sized to hold the full contents of `dst`.
        #[cfg(feature = "debug")]
        let debug_buffer = if debug {
            Some(Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("debug buffer"),
                size: dst.num_bytes() as _,
                usage: wgpu::BufferUsages::standard(),
                mapped_at_creation: false,
            })))
        } else {
            None
        };

        Ok(CompiledOp::new(
            pipeline_handle,
            workload.workgroup_count,
            storage_bind_groups,
            offset as _,
            kernel_src_desc.key,
            #[cfg(feature = "debug")]
            debug_buffer,
        ))
    }