fn render()

in crates/ratchet-core/src/ops/reindex/mod.rs [48:124]


    /// Renders the WGSL source for a reindex kernel (permute, slice or broadcast).
    ///
    /// The generated shader runs one invocation per output element. Each invocation:
    /// 1. decodes its flat destination offset into a 4D index,
    /// 2. maps that to a 4D source index according to the reindex variant,
    /// 3. flattens the source index (plus a slice start, for `Slice`) back to a 1D offset,
    /// 4. copies `X[src_offset]` into `Y[dst_offset]`.
    ///
    /// # Errors
    /// Propagates failures from binding registration, metadata construction and the
    /// final kernel build.
    ///
    /// # Panics
    /// Panics if `dst.device().try_gpu()` fails (it is `unwrap`ped).
    fn render<P: WgslPrimitive>(
        &self,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
    ) -> Result<KernelSource, OperationError> {
        let gpu = dst.device().try_gpu().unwrap();
        let mut builder = WgslKernelBuilder::new(
            workgroup_size.clone(),
            rvec![
                BuiltIn::LocalInvocationIndex,
                BuiltIn::WorkgroupId,
                BuiltIn::NumWorkgroups,
            ],
            gpu.compute_features().clone(),
        );
        self.register_bindings::<P>(&mut builder, inplace)?;

        let element = self.kernel_element(dst);
        let meta = self.metadata(dst, &element)?;
        builder.render_metadata(&meta);

        // Vector width of the primitive; the bounds check below divides `dst_numel` by it.
        let width = P::W;

        // A strided flatten that also adds a per-dimension start offset (used by slicing).
        builder.write_global(wgsl! {
            fn ndIndexToOffset(index: vec4<u32>, src_offsets: vec4<u32>, stride: vec4<u32>) -> u32 {
                var offset: u32 = 0u;
                offset = dot(index + src_offsets, stride);
                return offset;
            }
        });
        builder.write_offset_to_index();

        // One invocation per output element; out-of-range invocations exit early.
        // NOTE(review): the flat offset assumes 64 invocations per workgroup along x and
        // does not consult `workgroup_size` — confirm the dispatch always uses 64.
        builder.write_main(wgsl! {
            let x_offset = workgroup_id.x * 64u;
            var dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
            if (dst_offset >= metadata.dst_numel / 'width) {
                return;
            }

            let dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride);

        });

        // Each variant supplies:
        // (a) WGSL that derives `src_index` from `dst_index`, and
        // (b) the base offset added to `src_index` before flattening (non-zero only for slices).
        let ReindexKernels::Standard(op) = self;
        let (to_src_index, src_base) = match op {
            // Scatter each destination coordinate into the source axis named by `metadata.perm`.
            Reindex::Permute(_) => (
                wgsl! {
                    var src_index = vec4<u32>(0u);
                    src_index[metadata.perm[0]] = dst_index[0];
                    src_index[metadata.perm[1]] = dst_index[1];
                    src_index[metadata.perm[2]] = dst_index[2];
                    src_index[metadata.perm[3]] = dst_index[3];
                },
                wgsl! { vec4<u32>(0u) },
            ),
            // Same coordinates, shifted by the slice start offsets during flattening.
            Reindex::Slice(_) => (
                wgsl! { var src_index = dst_index; },
                wgsl! { metadata.src_offsets },
            ),
            // Source dimensions of size 1 are broadcast: their coordinate is pinned to 0.
            Reindex::Broadcast(_) => (
                wgsl! {
                    var src_index = select(dst_index, vec4<u32>(0u), metadata.src_shape == vec4<u32>(1u));
                },
                wgsl! { vec4<u32>(0u) },
            ),
        };
        builder.write_main(to_src_index);

        // Flatten the source index and perform the copy.
        builder.write_main(wgsl! {
            let src_offset = ndIndexToOffset(src_index, 'src_base, metadata.src_stride);
            Y[dst_offset] = X[src_offset];
        });
        Ok(builder.build()?)
    }