in crates/ratchet-core/src/ops/norm/mod.rs [112:212]
/// Builds the WGSL source of the normalization kernel for the primitive `P`.
///
/// The generated shader proceeds in four phases:
/// 1. obtain `mu`: the literal `0.` for `NormOp::RMSNorm`, otherwise whatever
///    `Self::compute_mu` emits (presumably the mean; defined outside this block),
/// 2. every invocation accumulates `(X - mu)^2` over a strided slice of the
///    reduction axis and stores its partial sum in workgroup memory `smem`,
/// 3. the partial sums are folded into `smem[0]` with repeated `block_sum` calls,
/// 4. each element is scaled by `inverseSqrt(sigma + eps)` and written to `Y`,
///    multiplied by `S` (RMSNorm) or scaled and shifted by `S` and `B` (otherwise).
///
/// `inplace` is only forwarded to `register_bindings`. `workgroup_size.x` is used
/// as the number of cooperating invocations and as the length of `smem`.
///
/// # Errors
/// Propagates the failure when `dst` is not on a GPU device (`try_gpu`), when
/// registering bindings or computing metadata fails, or when the kernel builder
/// fails to build.
///
/// # Panics
/// Panics if `P::W` is not 1, 2 or 4.
fn render<P: WgslPrimitive>(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    let device = dst.device().try_gpu()?;
    let mut kernel_builder = WgslKernelBuilder::new(
        workgroup_size.clone(),
        rvec![
            BuiltIn::GlobalInvocationId,
            BuiltIn::LocalInvocationId,
            BuiltIn::WorkgroupId,
        ],
        device.compute_features().clone(),
    );
    self.register_bindings::<P>(&mut kernel_builder, inplace)?;
    kernel_builder.render_metadata(&self.metadata(dst, &self.kernel_element(dst))?);

    let NormKernels::Standard(inner) = self;
    // Evaluated once; it selects both the `mu` source and the final store expression.
    let is_rms_norm = matches!(inner, NormOp::RMSNorm(_));

    // Pick the metadata field that matches the width of `P`.
    let reduction_len = match P::W {
        1 => "metadata.N",
        2 => "metadata.ND2",
        4 => "metadata.ND4",
        v => panic!("Invalid reduction length: {}", v),
    };
    let dt = P::T::DT;
    let accessor = P::render_type();
    let block_size = workgroup_size.x.render();

    // Workgroup-shared scratch: one partial sum per invocation, plus a scalar `sum`
    // that is not referenced in this function (presumably used by `compute_mu`).
    kernel_builder.write_global(wgsl! {
        var<workgroup> smem: array<'accessor, 'block_size>;
        var<workgroup> sum: 'dt;
    });

    // One reduction step: the lower `stride` invocations absorb the upper half.
    kernel_builder.write_global(wgsl! {
        fn block_sum(index: u32, stride: u32) {
            if index < stride {
                smem[index] += smem[index + stride];
            }
            workgroupBarrier();
        }
    });

    // Offset of the row handled by this workgroup.
    kernel_builder.write_main(wgsl!{
        let anchor = (workgroup_id.y * metadata.M * 'reduction_len) + workgroup_id.x * 'reduction_len;
    });
    kernel_builder.write_main(wgsl! { var threadSum = 'accessor(0.); });

    if is_rms_norm {
        // RMSNorm does not centre the input.
        kernel_builder.write_main(wgsl! { let mu = 0.; });
    } else {
        Self::compute_mu::<P>(
            &mut kernel_builder,
            accessor.clone(),
            reduction_len,
            workgroup_size,
        );
    };

    // Per-invocation sum of squared deviations, published to `smem`.
    kernel_builder.write_main(wgsl! {
        threadSum = 'accessor(0.);
        for (var i: u32 = local_invocation_id.x; i < 'reduction_len; i += 'block_size) {
            let val = X[anchor + i] - mu;
            threadSum = fma(val, val, threadSum);
        }
        workgroupBarrier();
        smem[local_invocation_id.x] = threadSum;
        workgroupBarrier();
    });

    // Halving strides down to 1. `ilog2` panics on 0, so a workgroup of a single
    // invocation (x - 1 == 0) emits no steps: `smem[0]` already holds the total.
    // NOTE(review): the stride sequence looks like it assumes `workgroup_size.x` is a
    // power of two; otherwise `index + stride` can exceed the length of `smem` - confirm.
    if let Some(steps) = workgroup_size.x.saturating_sub(1).checked_ilog2() {
        for stride in (0..=steps).rev().map(|x| 2u32.pow(x)) {
            let stride = stride.render();
            kernel_builder.write_main(wgsl! { block_sum(local_invocation_id.x, 'stride); });
        }
    }

    // Vector accessors still hold one partial sum per lane; `dot` with ones adds them up.
    let sigma = match P::W {
        1 => wgsl! { let sigma = smem[0] / 'dt(metadata.N); },
        2 | 4 => wgsl! {let sigma = dot(smem[0], 'accessor(1.)) / 'dt(metadata.N); },
        _ => unreachable!(),
    };
    kernel_builder.write_main(sigma);

    let loop_core = if is_rms_norm {
        wgsl! { Y[anchor + i] = val * S[i]; }
    } else {
        wgsl! { Y[anchor + i] = fma(val, S[i], B[i]); }
    };

    // Normalize every element of the row and store it.
    kernel_builder.write_main(wgsl! {
        let denom = inverseSqrt(sigma + 'accessor(metadata.eps));
        for(var i: u32 = local_invocation_id.x; i < 'reduction_len; i += 'block_size) {
            let val = (X[anchor + i] - mu) * denom;
            'loop_core
        }
    });

    Ok(kernel_builder.build()?)
}