fn compute_mu()

in crates/ratchet-core/src/ops/norm/mod.rs [216:245]


    /// Appends the WGSL that computes the mean `mu` to the kernel's main body.
    ///
    /// The emitted code does, in order:
    /// 1. a strided loop in which each invocation adds `X[anchor + i]` into `threadSum`,
    ///    starting at `local_invocation_id.x` and stepping by the rendered
    ///    `workgroup_size.x`, while `i < reduction_len`;
    /// 2. a store of `threadSum` into `smem[local_invocation_id.x]`, with a
    ///    `workgroupBarrier()` before and after it;
    /// 3. one `block_sum(local_invocation_id.x, stride)` call per power-of-two stride,
    ///    from the largest power of two below `workgroup_size.x` down to 1;
    /// 4. `let mu = ...`, which divides the reduced `smem[0]` by `metadata.N`.
    ///
    /// `threadSum`, `X`, `anchor`, `smem`, `metadata` and `block_sum` are referenced by the
    /// generated code but are not declared by this function.
    ///
    /// # Arguments
    /// * `kernel_builder` - builder whose main body receives the generated WGSL.
    /// * `accessor` - only used when `P::W` is 2 or 4, where it is spliced in as
    ///   `accessor(1.)` and dotted with `smem[0]` (presumably a vector constructor, so the
    ///   dot product sums the lanes - confirm against the caller).
    /// * `reduction_len` - WGSL expression used as the upper bound of the accumulation loop.
    /// * `workgroup_size` - only its `x` component is read.
    ///
    /// # Panics
    /// Panics through `unreachable!()` if `P::W` is anything other than 1, 2 or 4.
    fn compute_mu<P: WgslPrimitive>(
        kernel_builder: &mut WgslKernelBuilder,
        accessor: String,
        reduction_len: &str,
        workgroup_size: &WorkgroupSize,
    ) {
        let block_size = workgroup_size.x.render();
        // `P::T::DT` is spliced in as a cast around `metadata.N`
        // (presumably the WGSL scalar type name - confirm).
        let dt = P::T::DT;
        kernel_builder.write_main(wgsl! {
            for (var i: u32 = local_invocation_id.x; i < 'reduction_len; i += 'block_size) {
                threadSum += X[anchor + i];
            }
            workgroupBarrier();
            smem[local_invocation_id.x] = threadSum;
            workgroupBarrier();
        });

        // `ilog2` panics on zero, so a workgroup with `x == 1` (or an underflowing `x == 0`)
        // used to abort here. With at most one invocation there is nothing to combine, so
        // no `block_sum` steps are emitted in that case.
        let max_exponent = workgroup_size.x.checked_sub(1).and_then(|n| n.checked_ilog2());
        if let Some(max_exponent) = max_exponent {
            // Strides run from the largest power of two down to 1.
            for stride in (0..=max_exponent).rev().map(|exp| 2u32.pow(exp)) {
                let stride = stride.render();
                kernel_builder.write_main(wgsl! { block_sum(local_invocation_id.x, 'stride); });
            }
        }

        let mu = match P::W {
            1 => wgsl! { let mu = smem[0] / 'dt(metadata.N); },
            2 | 4 => wgsl! {let mu = dot(smem[0], 'accessor(1.)) / 'dt(metadata.N); },
            _ => unreachable!(),
        };
        kernel_builder.write_main(mu);
    }