in crates/ratchet-core/src/ops/unary.rs [91:158]
/// Builds the WGSL source for a standard unary kernel.
///
/// Steps, in order:
/// 1. Create a kernel builder with the workgroup size, the `WorkgroupId`,
///    `LocalInvocationIndex` and `NumWorkgroups` builtins, and the GPU's compute features.
/// 2. Register bindings. The `inplace` flag is forwarded to `register_bindings`, and
///    the generated body writes back into `X` when it is set and into `Y` otherwise.
/// 3. Render the metadata derived from `dst`.
/// 4. Emit any helper functions the selected `UnaryOp` needs.
/// 5. Write the entry-point body that applies the op's kernel expression per index.
///
/// `P::W` is interpolated into the bounds check as `metadata.numel / P::W`. Presumably it
/// is the vector width of the primitive, so each invocation handles one `P`-sized chunk.
/// TODO confirm.
///
/// # Errors
/// Propagates failures from `try_gpu`, `register_bindings`, `metadata` and
/// `WgslKernelBuilder::build`.
fn render<P: WgslPrimitive>(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    let gpu = dst.device().try_gpu()?;
    let builtins = rvec![
        BuiltIn::WorkgroupId,
        BuiltIn::LocalInvocationIndex,
        BuiltIn::NumWorkgroups
    ];
    let features = gpu.compute_features().clone();
    let mut builder = WgslKernelBuilder::new(workgroup_size.clone(), builtins, features);

    self.register_bindings::<P>(&mut builder, inplace)?;
    let element = self.kernel_element(dst);
    let meta = self.metadata(dst, &element)?;
    builder.render_metadata(&meta);

    let UnaryKernels::Standard(inner) = self;

    // Global helper functions. Gelu also emits the tanh helper, and Silu also emits the
    // sigmoid helper, each written before the op's own function (which presumably calls
    // it). Ops not listed here emit no globals.
    if matches!(inner.op, UnaryOp::Gelu | UnaryOp::Tanh) {
        builder.write_global(Unary::render_tanh::<P>());
    }
    if matches!(inner.op, UnaryOp::Gelu) {
        builder.write_global(Unary::render_gelu::<P>());
    }
    if matches!(inner.op, UnaryOp::Sigmoid | UnaryOp::Silu) {
        builder.write_global(Unary::render_sigmoid::<P>());
    }
    if matches!(inner.op, UnaryOp::Silu) {
        builder.write_global(Unary::render_silu::<P>());
    }
    if matches!(inner.op, UnaryOp::Relu) {
        builder.write_global(Unary::render_relu::<P>());
    }

    // `n` and `func` are interpolated into the WGSL below via `'n` / `'func`.
    let n = P::W;
    let func = inner.op.kernel_operation();

    // Flatten the 2-D workgroup grid into a linear index and exit early past the end.
    // NOTE(review): the stride is hard-coded to 64 invocations per workgroup, while the
    // builder is constructed from `workgroup_size`. Confirm that every caller dispatches
    // with a 64-thread workgroup.
    builder.write_main(wgsl! {
        let x_offset = workgroup_id.x * 64u;
        let index = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
        if (index >= metadata.numel / 'n) {
            return;
        }
    });

    // In-place kernels overwrite X; otherwise the result is written into Y.
    let body = if inplace {
        wgsl! {
            let val = X[index];
            X[index] = 'func(val);
        }
    } else {
        wgsl! {
            Y[index] = 'func(X[index]);
        }
    };
    builder.write_main(body);

    let source = builder.build()?;
    Ok(source)
}