fn render()

in crates/ratchet-core/src/ops/norm/mod.rs [112:212]


    /// Builds the WGSL source for the normalization kernel.
    ///
    /// The emitted shader does the following, in order:
    /// 1. computes `anchor`, the offset into `X`/`Y` derived from `workgroup_id`;
    /// 2. obtains `mu`: the literal `0.` for `NormOp::RMSNorm`, otherwise whatever
    ///    `Self::compute_mu` emits;
    /// 3. accumulates `(X - mu)^2` per thread and stores the partial sums in the
    ///    workgroup array `smem`;
    /// 4. folds `smem` into `smem[0]` with a sequence of `block_sum` calls;
    /// 5. derives `sigma` by dividing the reduced value by `metadata.N`;
    /// 6. writes `(X - mu) * inverseSqrt(sigma + eps)` to `Y`, multiplied by `S`
    ///    (RMSNorm) or fused with `S` and `B` (all other variants).
    ///
    /// # Arguments
    /// * `inplace` - forwarded to `register_bindings`.
    /// * `dst` - output tensor; supplies the device and the kernel metadata.
    /// * `workgroup_size` - its `x` component is used as the length of `smem`,
    ///   the stride of the per-thread loops and the extent of the reduction.
    ///
    /// # Errors
    /// Propagates failures from `dst.device().try_gpu()`, `register_bindings`,
    /// `metadata` and `kernel_builder.build()`.
    ///
    /// # Panics
    /// Panics if `P::W` is not 1, 2 or 4.
    fn render<P: WgslPrimitive>(
        &self,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
    ) -> Result<KernelSource, OperationError> {
        let device = dst.device().try_gpu()?;
        let mut kernel_builder = WgslKernelBuilder::new(
            workgroup_size.clone(),
            rvec![
                BuiltIn::GlobalInvocationId,
                BuiltIn::LocalInvocationId,
                BuiltIn::WorkgroupId,
            ],
            device.compute_features().clone(),
        );
        self.register_bindings::<P>(&mut kernel_builder, inplace)?;
        kernel_builder.render_metadata(&self.metadata(dst, &self.kernel_element(dst))?);

        let NormKernels::Standard(inner) = self;

        // The loop bound is expressed in units of the accessor width `P::W`.
        let reduction_len = match P::W {
            1 => "metadata.N",
            2 => "metadata.ND2",
            4 => "metadata.ND4",
            v => panic!("Invalid reduction length: {}", v),
        };

        let dt = P::T::DT;
        let accessor = P::render_type();
        let BLOCK_SIZE = workgroup_size.x.render();

        // Workgroup-shared scratch: one accumulator slot per thread, plus a scalar.
        kernel_builder.write_global(wgsl! {
            var<workgroup> smem: array<'accessor, 'BLOCK_SIZE>;
            var<workgroup> sum: 'dt;
        });

        // One step of the tree reduction over `smem`; the barrier is executed by
        // every thread regardless of whether it took part in the addition.
        kernel_builder.write_global(wgsl! {
            fn block_sum(index: u32, stride: u32) {
                if index < stride {
                    smem[index] += smem[index + stride];
                }
                workgroupBarrier();
            }
        });

        kernel_builder.write_main(wgsl!{
            let anchor = (workgroup_id.y * metadata.M * 'reduction_len) + workgroup_id.x * 'reduction_len;
        });

        kernel_builder.write_main(wgsl! { var threadSum = 'accessor(0.); });
        if matches!(inner, NormOp::RMSNorm(_)) {
            // RMSNorm does not centre the input, so the mean is pinned to zero.
            kernel_builder.write_main(wgsl! { let mu = 0.; });
        } else {
            Self::compute_mu::<P>(
                &mut kernel_builder,
                accessor.clone(),
                reduction_len,
                workgroup_size,
            );
        };

        // Per-thread sum of squared deviations, then publish it to `smem`.
        kernel_builder.write_main(wgsl! {
            threadSum = 'accessor(0.);
            for (var i: u32 = local_invocation_id.x; i < 'reduction_len; i += 'BLOCK_SIZE) {
                let val = X[anchor + i] - mu;
                threadSum = fma(val, val, threadSum);
            }
            workgroupBarrier();
            smem[local_invocation_id.x] = threadSum;
            workgroupBarrier();
        });

        // Emit `block_sum` calls with strides 2^steps, ..., 2, 1.
        // `ilog2` panics on zero, so the checked forms are used: a block size of 1
        // emits no reduction steps at all (smem[0] already holds the only partial
        // sum), and a block size of 0 cannot underflow.
        // NOTE(review): `block_sum` reads `smem[index + stride]` without a bounds
        // check, so this appears to assume `workgroup_size.x` is a power of two —
        // confirm against the callers.
        let steps = workgroup_size
            .x
            .checked_sub(1)
            .and_then(|n| n.checked_ilog2());
        if let Some(steps) = steps {
            for i in (0..=steps).rev().map(|x| 2u32.pow(x)) {
                let v = i.render();
                kernel_builder.write_main(wgsl! { block_sum(local_invocation_id.x, 'v); });
            }
        }

        // For vector accessors the lanes of smem[0] are summed via a dot product
        // with an all-ones vector before dividing by the element count.
        let sigma = match P::W {
            1 => wgsl! { let sigma = smem[0] / 'dt(metadata.N); },
            2 | 4 => wgsl! {let sigma = dot(smem[0], 'accessor(1.)) / 'dt(metadata.N); },
            _ => unreachable!(),
        };
        kernel_builder.write_main(sigma);

        // RMSNorm only scales; every other variant also adds the bias `B`.
        let loop_core = if matches!(inner, NormOp::RMSNorm(_)) {
            wgsl! { Y[anchor + i] = val * S[i]; }
        } else {
            wgsl! { Y[anchor + i] = fma(val, S[i], B[i]); }
        };

        kernel_builder.write_main(wgsl! {
            let denom = inverseSqrt(sigma + 'accessor(metadata.eps));
            for(var i: u32 = local_invocation_id.x; i < 'reduction_len; i += 'BLOCK_SIZE) {
                let val = (X[anchor + i] - mu) * denom;
                'loop_core
            }
        });
        Ok(kernel_builder.build()?)
    }