fn render()

in crates/ratchet-core/src/ops/concat.rs [49:108]


    /// Assembles the WGSL kernel source for the concat operation.
    ///
    /// The emitted shader body works as follows for each invocation:
    /// 1. derive a flat destination offset from the workgroup id and the local
    ///    invocation index, returning early when it is `>= metadata.dst_numel`;
    /// 2. convert that offset into an n-d index using `metadata.dst_stride`;
    /// 3. compare the index along `metadata.dim` against the per-input bounds
    ///    `metadata.cum0`, `metadata.cum1`, ... in order, and copy one element
    ///    from the first input whose bound is not yet reached into `Y`.
    ///
    /// `inplace` is only forwarded to `register_bindings`; `dst` supplies the
    /// device and the metadata that get rendered into the kernel.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by `dst.device().try_gpu()`,
    /// `register_bindings`, `metadata`, or the final `build()` call.
    fn render<P: WgslPrimitive>(
        &self,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
    ) -> Result<KernelSource, OperationError> {
        let gpu = dst.device().try_gpu()?;
        let mut builder = WgslKernelBuilder::new(
            workgroup_size.clone(),
            rvec![
                BuiltIn::LocalInvocationIndex,
                BuiltIn::NumWorkgroups,
                BuiltIn::WorkgroupId,
            ],
            gpu.compute_features().clone(),
        );

        // Bindings first, then the metadata struct, then the two index helpers
        // that the main body below calls.
        self.register_bindings::<P>(&mut builder, inplace)?;
        let element = self.kernel_element(dst);
        let meta = self.metadata(dst, &element)?;
        builder.render_metadata(&meta);
        builder.write_offset_to_index();
        builder.write_index_to_offset();

        // Prologue: locate this invocation's destination element.
        // NOTE(review): the `64u` stride is hard-coded here; presumably it is
        // meant to match `workgroup_size` -- confirm against the caller.
        builder.write_main(wgsl! {
            let x_offset = workgroup_id.x * 64u;
            let dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
            if (dst_offset >= metadata.dst_numel) {
                return;
            }

            var dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride);
            let dim = metadata.dim;
        });

        // The first input needs no adjustment of the index along `dim`, so it
        // is emitted on its own rather than inside the loop.
        builder.write_main(wgsl! {
            if(dst_index[dim] < metadata.cum0) {
                let src_offset = ndIndexToOffset(dst_index, metadata.x0_stride);
                Y[dst_offset] = X0[src_offset];
                return;
            }
        });

        let ConcatKernels::Standard(inner) = self;
        let input_count = inner.inputs.len();

        // One guarded copy per remaining input. `enumerate` over `1..count`
        // yields (previous, current) input numbers: (0, 1), (1, 2), ...
        for (before, current) in (1..input_count).enumerate() {
            let lower = format!("metadata.cum{}", before);
            let upper = format!("metadata.cum{}", current);
            let strides = format!("metadata.x{}_stride", current);
            let input = format!("X{}", current);

            builder.write_main(wgsl! {
                if(dst_index[dim] < 'upper) {
                    dst_index[dim] -= 'lower;
                    let src_offset = ndIndexToOffset(dst_index, 'strides);
                    Y[dst_offset] = 'input[src_offset];
                    return;
                }
            });
        }

        let source = builder.build()?;
        Ok(source)
    }