in crates/ratchet-core/src/ops/cache.rs [51:105]
/// Generates the WGSL source for the cache kernel.
///
/// The emitted shader runs one invocation per element of `dst`. For each element it
/// inspects the coordinate along `metadata.dim`:
/// * below `metadata.cum0`: the value is copied from the cache buffer `C` into `D`;
/// * below `metadata.cum1`: the value is read from the source buffer `S` (with the
///   coordinate shifted back by `cum0`), stored into `C` at the unshifted coordinate,
///   and also written to `D`;
/// * otherwise: the invocation writes nothing.
///
/// `inplace` is forwarded unchanged to `register_bindings`.
///
/// NOTE(review): `workgroup_size` is handed to the builder, but the shader body
/// hard-codes 64 invocations per workgroup when linearising the output offset.
/// Confirm that callers always dispatch with a 64-wide workgroup.
///
/// # Errors
/// Propagates failures from `try_gpu`, `register_bindings`, `metadata` and
/// `WgslKernelBuilder::build`.
fn render<P: WgslPrimitive>(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    let gpu = dst.device().try_gpu()?;
    let builtins = rvec![
        BuiltIn::WorkgroupId,
        BuiltIn::LocalInvocationIndex,
        BuiltIn::NumWorkgroups,
    ];
    let features = gpu.compute_features().clone();
    let mut builder = WgslKernelBuilder::new(workgroup_size.clone(), builtins, features);

    // Binding declarations come first, followed by the metadata struct for this dst.
    self.register_bindings::<P>(&mut builder, inplace)?;
    let element = self.kernel_element(dst);
    let meta = self.metadata(dst, &element)?;
    builder.render_metadata(&meta);

    // Presumably these emit the offsetToNdIndex / ndIndexToOffset helpers that the
    // main body below calls. Verify against WgslKernelBuilder.
    builder.write_offset_to_index();
    builder.write_index_to_offset();

    builder.write_main(wgsl! {
        // One invocation per output element; dst_offset is this element's linear
        // position in D (64 invocations per workgroup are assumed here).
        let x_offset = workgroup_id.x * 64u;
        let dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
        if (dst_offset >= metadata.dst_numel) {
            return;
        }
        // Recover the multi-dimensional coordinate of this element from dst's strides.
        var dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride);
        let dim = metadata.dim;
        if (dst_index[dim] < metadata.cum0) {
            // Region already held in the cache: forward C -> D.
            // (The local is named src_offset, but it indexes C.)
            let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride);
            D[dst_offset] = C[src_offset];
            return;
        }
        if (dst_index[dim] < metadata.cum1) {
            // Region supplied by S: read it (coordinate shifted back by cum0),
            // store it into C at the unshifted coordinate, and emit it to D.
            let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride);
            dst_index[dim] -= metadata.cum0;
            let src_offset = ndIndexToOffset(dst_index, metadata.src_stride);
            let val = S[src_offset];
            C[cache_offset] = val;
            D[dst_offset] = val;
            return;
        }
    });

    let source = builder.build()?;
    Ok(source)
}