in crates/ratchet-core/src/ops/softmax.rs [55:180]
/// Renders the WGSL source for a row-wise softmax kernel.
///
/// Each workgroup processes one row. `workgroup_id.x` selects the row within a batch and
/// `workgroup_id.y` selects the batch, which is strided by `metadata.M` rows. The row length
/// iterated over is `metadata.N`, `metadata.ND2` or `metadata.ND4` depending on `P::W`
/// (presumably the row length in units of `P`-wide vectors — confirm against `metadata`).
///
/// The kernel makes three passes over the row. Each pass strides its invocations by the
/// workgroup width, and the first two finish with a tree reduction in the workgroup-shared
/// array `smem`:
///
/// 1. the row maximum, folded across vector lanes into `maximum`;
/// 2. the sum of `exp(x - maximum)`, folded across lanes into `sum`;
/// 3. the normalisation `exp(x - maximum) / sum`, written back into `X`.
///
/// Subtracting the maximum before `exp` keeps the exponentials from overflowing.
///
/// The generated code always writes its result into `X`. `inplace` is only forwarded to
/// `register_bindings`. NOTE(review): this presumably means softmax always runs in place —
/// confirm against `register_bindings`.
///
/// # Errors
///
/// Returns [`OperationError::CompileError`] in two cases:
///
/// - `workgroup_size.x` is not a power of two. The tree reduction halves the active range
///   each step and would otherwise read past the end of `smem`.
/// - `P::W` is not 1, 2 or 4.
///
/// Also propagates errors from `try_gpu`, `register_bindings`, `metadata` and `build`.
fn render<P: WgslPrimitive>(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    // The shared-memory tree reduction below is only correct for power-of-two widths.
    if !workgroup_size.x.is_power_of_two() {
        return Err(OperationError::CompileError(
            "Softmax workgroup size must be a power of two".to_string(),
        ));
    }

    let device = dst.device().try_gpu()?;
    let mut kernel_builder = WgslKernelBuilder::new(
        workgroup_size.clone(),
        rvec![
            BuiltIn::GlobalInvocationId,
            BuiltIn::LocalInvocationId,
            BuiltIn::WorkgroupId,
        ],
        device.compute_features().clone(),
    );
    self.register_bindings::<P>(&mut kernel_builder, inplace)?;
    kernel_builder.render_metadata(&self.metadata(dst, &self.kernel_element(dst))?);

    let dt = P::T::DT;
    let accessor = P::render_type();
    let block_size = workgroup_size.x.render();
    let min_float = P::T::MIN;

    kernel_builder.write_global(wgsl! {
        var<workgroup> smem: array<'accessor, 'block_size>;
        var<workgroup> maximum: 'dt;
        var<workgroup> sum: 'dt;
    });
    // One reduction step: the lower `stride` slots absorb the upper `stride` slots,
    // then every invocation synchronises before the next (smaller) step.
    kernel_builder.write_global(wgsl! {
        fn block_sum(index: u32, stride: u32) {
            if index < stride {
                smem[index] += smem[index + stride];
            }
            workgroupBarrier();
        }
        fn block_max(index: u32, stride: u32) {
            if index < stride {
                smem[index] = max(smem[index], smem[index + stride]);
            }
            workgroupBarrier();
        }
    });

    // Per vector width: the uniform holding the row length to iterate over, and how to
    // fold the lanes of the fully reduced `smem[0]` into the scalar `maximum` / `sum`.
    let (reduce_var, finalize_max, finalize_sum) = match P::W {
        1 => (
            "metadata.N",
            wgsl! { maximum = smem[0]; },
            wgsl! { sum = smem[0]; },
        ),
        2 => (
            "metadata.ND2",
            wgsl! { maximum = max(smem[0].x, smem[0].y); },
            wgsl! { sum = dot(smem[0], 'accessor(1.)); },
        ),
        4 => (
            "metadata.ND4",
            wgsl! { maximum = max(smem[0].x, max(smem[0].y, max(smem[0].z, smem[0].w))); },
            wgsl! { sum = dot(smem[0], 'accessor(1.)); },
        ),
        _ => {
            return Err(OperationError::CompileError(
                "Invalid dimension".to_string(),
            ))
        }
    };

    // Reduction strides block/2, block/4, ..., 1. The list is empty for a single-invocation
    // workgroup, where smem[0] already holds the result (the old `ilog2(x - 1)` panicked there).
    let strides: Vec<_> = std::iter::successors(Some(workgroup_size.x / 2), |&s| Some(s / 2))
        .take_while(|&s| s > 0)
        .collect();

    kernel_builder.write_main(wgsl! {
        let batch_stride = workgroup_id.y * metadata.M * 'reduce_var;
        let row_start = batch_stride + workgroup_id.x * 'reduce_var;
        let index = local_invocation_id.x;
    });

    // Pass 1: per-invocation partial maxima, then a tree reduction to smem[0].
    kernel_builder.write_main(wgsl! {
        smem[index] = 'accessor('min_float);
        for (var i: u32 = index; i < 'reduce_var; i += 'block_size) {
            smem[index] = max(smem[index], X[row_start + i]);
        }
        workgroupBarrier();
    });
    for &stride in &strides {
        let stride = stride.render();
        kernel_builder.write_main(wgsl! { block_max(index, 'stride); });
    }
    kernel_builder.write_main(wgsl! {
        if index == 0 {
            'finalize_max
        }
        workgroupBarrier();
    });

    // Pass 2: per-invocation partial sums of exp(x - max), then a tree reduction to smem[0].
    kernel_builder.write_main(wgsl! {
        smem[index] = 'accessor(0.);
        for (var i: u32 = index; i < 'reduce_var; i += 'block_size) {
            smem[index] += exp(X[row_start + i] - maximum);
        }
        workgroupBarrier();
    });
    for &stride in &strides {
        let stride = stride.render();
        kernel_builder.write_main(wgsl! { block_sum(index, 'stride); });
    }
    kernel_builder.write_main(wgsl! {
        if index == 0 {
            'finalize_sum
        }
        workgroupBarrier();
    });

    // Pass 3: normalise each element and write it back into X.
    kernel_builder.write_main(wgsl! {
        for(var i: u32 = index; i < 'reduce_var; i += 'block_size) {
            var val = X[row_start + i];
            X[row_start + i] = exp(val - maximum) / sum;
        }
    });

    Ok(kernel_builder.build()?)
}