in crates/ratchet-core/src/ops/reindex/mod.rs [48:124]
/// Assembles the WGSL source of a reindex kernel (permute, slice or broadcast).
///
/// The generated shader handles one output element per invocation: it derives a
/// flat destination offset from the workgroup/invocation ids, expands it to a 4D
/// index, maps that index into the source tensor according to the concrete
/// [`Reindex`] variant, and finally copies `X[src_offset]` into `Y[dst_offset]`.
///
/// # Errors
///
/// Propagates any error returned by `register_bindings`, `metadata`, or the
/// builder's `build` step.
///
/// # Panics
///
/// Panics if `dst.device().try_gpu()` returns an error (the result is unwrapped).
fn render<P: WgslPrimitive>(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    // `Standard` is matched irrefutably, so unwrap the concrete op up front.
    let ReindexKernels::Standard(inner) = self;

    let device = dst.device().try_gpu().unwrap();
    let features = device.compute_features().clone();
    let mut builder = WgslKernelBuilder::new(
        workgroup_size.clone(),
        rvec![
            BuiltIn::LocalInvocationIndex,
            BuiltIn::WorkgroupId,
            BuiltIn::NumWorkgroups,
        ],
        features,
    );

    // Bindings first, then the metadata, before any shader code is emitted.
    self.register_bindings::<P>(&mut builder, inplace)?;
    let kernel_element = self.kernel_element(dst);
    let metadata = self.metadata(dst, &kernel_element)?;
    builder.render_metadata(&metadata);

    // Interpolated into the bounds check below as the divisor of `dst_numel`.
    let n = P::W;

    // Custom 4D -> 1D helper that also applies a per-dimension source offset
    // (only non-zero for slices, see `src_offsets` further down).
    builder.write_global(wgsl! {
        //Converts 4D index into 1D offset
        fn ndIndexToOffset(index: vec4<u32>, src_offsets: vec4<u32>, stride: vec4<u32>) -> u32 {
            var offset: u32 = 0u;
            offset = dot(index + src_offsets, stride);
            return offset;
        }
    });
    // Presumably emits the `offsetToNdIndex` helper called in main — confirm
    // against `WgslKernelBuilder`.
    builder.write_offset_to_index();

    // NOTE(review): the index math hard-codes `64u` while `workgroup_size` is only
    // forwarded to the builder; presumably 64 is the invocation count per
    // workgroup — confirm that callers never dispatch with a different size.
    builder.write_main(wgsl! {
        //Dispatch 1 thread per output element
        //dst_offset is index into the output buffer (1D)
        let x_offset = workgroup_id.x * 64u;
        var dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
        if (dst_offset >= metadata.dst_numel / 'n) {
            return;
        }
        //Convert 1D offset into 4D index
        let dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride);
    });

    // Variant-specific mapping from the destination index to the source index.
    let src_index_decl = match inner {
        Reindex::Permute(_) => wgsl! {
            var src_index = vec4<u32>(0u);
            src_index[metadata.perm[0]] = dst_index[0];
            src_index[metadata.perm[1]] = dst_index[1];
            src_index[metadata.perm[2]] = dst_index[2];
            src_index[metadata.perm[3]] = dst_index[3];
        },
        Reindex::Slice(_) => wgsl! { var src_index = dst_index; },
        Reindex::Broadcast(_) => wgsl! {
            // Broadcasting is valid if dims are equal, or if one of the dims is 1
            var src_index = select(dst_index, vec4<u32>(0u), metadata.src_shape == vec4<u32>(1u));
        },
    };
    builder.write_main(src_index_decl);

    // Only slices read from a shifted origin; every other op uses a zero offset.
    let src_offsets = match inner {
        Reindex::Slice(_) => wgsl! { metadata.src_offsets },
        Reindex::Permute(_) | Reindex::Broadcast(_) => wgsl! { vec4<u32>(0u) },
    };
    builder.write_main(wgsl! {
        //Convert 4D index into 1D offset
        let src_offset = ndIndexToOffset(src_index, 'src_offsets, metadata.src_stride);
        //Read from input buffer and write to output buffer
        Y[dst_offset] = X[src_offset];
    });

    let source = builder.build()?;
    Ok(source)
}