in crates/ratchet-core/src/op.rs [278:362]
/// Returns the kernel (a `Self::KernelEnum` value) to use for this operation.
///
/// `compile_gpu` calls this first and derives everything else from the
/// returned value: the kernel element, the metadata written to the uniform,
/// the dispatch workload, the storage bind group layout and the kernel key.
fn select_kernel(&self) -> Self::KernelEnum;
/// Compiles this operation into a [`CompiledOp`] that writes into `dst`.
///
/// The steps run in this order:
/// 1. Pick the kernel with `select_kernel` and resolve its kernel element.
/// 2. Write the kernel metadata into `uniform`, keeping the returned offset.
/// 3. Calculate the dispatch workload for `dst`.
/// 4. Obtain the storage and uniform bind group layouts and the pipeline
///    layout through the device's `get_or_create_*` methods.
/// 5. Build the kernel key, then obtain the compute module and the compute
///    pipeline for it.
/// 6. Create the storage bind groups for the sources and `dst`.
///
/// `can_inplace` is forwarded unchanged to the layout, key, module and bind
/// group calls. `debug` is only read when the `debug` cargo feature is
/// enabled; when it is `true`, a buffer of `dst.num_bytes()` bytes labelled
/// `"debug buffer"` is created and handed to the `CompiledOp`.
///
/// # Errors
///
/// Propagates the error of any fallible step above (every call followed by
/// `?`).
///
/// # Panics
///
/// Panics if `dst.device().try_gpu()` does not yield a GPU device, because
/// the result is unwrapped.
fn compile_gpu(
    &self,
    dst: &Tensor,
    uniform: &mut CpuUniform,
    device: &WgpuDevice,
    can_inplace: bool,
    debug: bool,
) -> Result<CompiledOp, OperationError> {
    let selected = self.select_kernel();
    let element = selected.kernel_element(dst);

    // The offset returned by the uniform write is passed to the `CompiledOp`.
    let meta = selected.metadata(dst, &element)?;
    let uniform_offset = meta.write(uniform)?;

    let dispatch = selected.calculate_dispatch(dst)?;

    // Bind group layouts: one for storage, one for the uniform.
    let storage_desc = selected.storage_bind_group_layout(can_inplace)?;
    let storage_layout = device.get_or_create_bind_group_layout(&storage_desc)?;
    let uniform_layout =
        device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())?;
    let layout = device.get_or_create_pipeline_layout(&PipelineLayoutDescriptor {
        entries: rvec![storage_layout, uniform_layout],
    })?;

    // The kernel key names both the compute module and the pipeline.
    let cache_key = selected.kernel_key(
        &dispatch.workgroup_size,
        can_inplace,
        &self.srcs(),
        dst,
        &element,
    );
    log::debug!("Kernel key: {}", cache_key);

    let module_desc = KernelModuleDesc {
        key: cache_key.clone(),
    };
    let module = device.get_or_create_compute_module(
        &module_desc,
        &selected,
        can_inplace,
        dst,
        &dispatch.workgroup_size,
        dst.device().try_gpu().unwrap(),
    );

    let pipeline = device.get_or_create_compute_pipeline(&ComputePipelineDescriptor {
        pipeline_layout: layout,
        kernel_key: module_desc.key.clone(),
        kernel_module: module,
    })?;

    //TODO: Not sure i like this call here
    let bind_groups = CompiledOp::create_storage_bind_groups(
        &self.srcs(),
        dst,
        rvec![storage_layout],
        device,
        can_inplace,
    )?;

    // Only compiled in with the `debug` feature; `None` unless `debug` is set.
    #[cfg(feature = "debug")]
    let debug_buffer = debug.then(|| {
        Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("debug buffer"),
            size: dst.num_bytes() as _,
            usage: wgpu::BufferUsages::standard(),
            mapped_at_creation: false,
        }))
    });

    Ok(CompiledOp::new(
        pipeline,
        dispatch.workgroup_count,
        bind_groups,
        uniform_offset as _,
        module_desc.key,
        #[cfg(feature = "debug")]
        debug_buffer,
    ))
}