in crates/ratchet-core/src/ops/norm/mod.rs [216:245]
/// Appends to the kernel's main body the WGSL that computes `mu`, i.e. the sum of
/// `X[anchor + i]` for `i` in `0..reduction_len`, divided by `metadata.N`.
///
/// Three fragments are written through `write_main`, in this order:
/// 1. a loop in which each invocation starts at `local_invocation_id.x`, advances by
///    the workgroup width, and accumulates into `threadSum`; the partial sum is then
///    stored into `smem[local_invocation_id.x]` between two `workgroupBarrier()` calls;
/// 2. one `block_sum(local_invocation_id.x, stride)` call per stride, with strides
///    running over descending powers of two down to 1;
/// 3. the final `let mu = ...;` statement.
///
/// `X`, `anchor`, `threadSum`, `smem`, `metadata` and `block_sum` are not declared by
/// this function. NOTE(review): they are presumably emitted by the caller before this
/// fragment; confirm against the kernel builder.
///
/// # Parameters
/// * `kernel_builder` - builder that receives the generated WGSL.
/// * `accessor` - rendered as `accessor(1.)` when `P::W` is 2 or 4. Presumably the
///   vector type constructor matching `P`, so the `dot` folds the lanes of `smem[0]`
///   into a scalar; TODO confirm.
/// * `reduction_len` - WGSL expression used as the exclusive upper bound of the loop.
/// * `workgroup_size` - only `x` is read: it is the loop stride and determines how
///   many `block_sum` steps are emitted.
///
/// # Panics
/// Panics if `P::W` is not 1, 2 or 4.
fn compute_mu<P: WgslPrimitive>(
    kernel_builder: &mut WgslKernelBuilder,
    accessor: String,
    reduction_len: &str,
    workgroup_size: &WorkgroupSize,
) {
    let block_size = workgroup_size.x.render();
    let dt = P::T::DT;

    // Per-invocation partial sums, published to `smem` for the reduction below.
    kernel_builder.write_main(wgsl! {
        for (var i: u32 = local_invocation_id.x; i < 'reduction_len; i += 'block_size) {
            threadSum += X[anchor + i];
        }
        workgroupBarrier();
        smem[local_invocation_id.x] = threadSum;
        workgroupBarrier();
    });

    // `ilog2(0)` panics, which is what a workgroup width of 1 used to hit. With a
    // single invocation `smem[0]` already holds the full sum, so no reduction step
    // is emitted in that case.
    if let Some(steps) = (workgroup_size.x - 1).checked_ilog2() {
        for stride in (0..=steps).rev().map(|exp| 2u32.pow(exp)) {
            let rendered_stride = stride.render();
            kernel_builder
                .write_main(wgsl! { block_sum(local_invocation_id.x, 'rendered_stride); });
        }
    }

    // Width 1 divides the reduced value directly; widths 2 and 4 first take the dot
    // product of `smem[0]` with `accessor(1.)`.
    let mu = match P::W {
        1 => wgsl! { let mu = smem[0] / 'dt(metadata.N); },
        2 | 4 => wgsl! {let mu = dot(smem[0], 'accessor(1.)) / 'dt(metadata.N); },
        _ => unreachable!(),
    };
    kernel_builder.write_main(mu);
}