fn render()

in crates/ratchet-core/src/ops/unary.rs [91:158]


    /// Renders the WGSL source for this elementwise unary kernel.
    ///
    /// * `inplace` - when `true` the generated main body overwrites `X[index]` with the
    ///   result; otherwise it reads `X[index]` and stores the result into `Y[index]`.
    /// * `dst` - output tensor; provides the GPU device (whose compute features are handed
    ///   to the kernel builder) and the inputs for `kernel_element` / `metadata`.
    /// * `workgroup_size` - workgroup dimensions passed to `WgslKernelBuilder::new`.
    ///
    /// The shader uses the `workgroup_id`, `local_invocation_index` and `num_workgroups`
    /// builtins to form a flat `index`. It returns early when
    /// `index >= metadata.numel / P::W`, then applies the op's `kernel_operation()`.
    ///
    /// # Errors
    /// Propagates failures from `try_gpu`, `register_bindings`, `metadata` and
    /// `WgslKernelBuilder::build`.
    ///
    /// NOTE(review): the index math hard-codes 64 invocations per workgroup; this presumably
    /// has to agree with `workgroup_size` and the dispatch computed elsewhere — confirm.
    fn render<P: WgslPrimitive>(
        &self,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
    ) -> Result<KernelSource, OperationError> {
        let UnaryKernels::Standard(kernel) = self;

        let gpu = dst.device().try_gpu()?;
        let mut builder = WgslKernelBuilder::new(
            workgroup_size.clone(),
            rvec![
                BuiltIn::WorkgroupId,
                BuiltIn::LocalInvocationIndex,
                BuiltIn::NumWorkgroups
            ],
            gpu.compute_features().clone(),
        );

        self.register_bindings::<P>(&mut builder, inplace)?;
        let element = self.kernel_element(dst);
        let meta = self.metadata(dst, &element)?;
        builder.render_metadata(&meta);

        // Global helper functions. Shared helpers are written before the op-specific ones:
        // `tanh` precedes `gelu` and `sigmoid` precedes `silu`. Presumably `gelu`/`silu`
        // call those helpers.
        if matches!(kernel.op, UnaryOp::Gelu | UnaryOp::Tanh) {
            builder.write_global(Unary::render_tanh::<P>());
        }
        if matches!(kernel.op, UnaryOp::Sigmoid | UnaryOp::Silu) {
            builder.write_global(Unary::render_sigmoid::<P>());
        }
        match kernel.op {
            UnaryOp::Gelu => {
                builder.write_global(Unary::render_gelu::<P>());
            }
            UnaryOp::Silu => {
                builder.write_global(Unary::render_silu::<P>());
            }
            UnaryOp::Relu => {
                builder.write_global(Unary::render_relu::<P>());
            }
            // Remaining ops need no global helpers.
            _ => {}
        }

        // `n` is interpolated into the bound check below. The element count is divided by
        // `P::W`, presumably the number of scalars packed into one `P` — TODO confirm.
        let n = P::W;

        builder.write_main(wgsl! {
            let x_offset = workgroup_id.x * 64u;
            let index = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
            if (index >= metadata.numel / 'n) {
                return;
            }
        });

        let func = kernel.op.kernel_operation();
        let body = if inplace {
            wgsl! {
                let val = X[index];
                X[index] = 'func(val);
            }
        } else {
            wgsl! {
                Y[index] = 'func(X[index]);
            }
        };
        builder.write_main(body);

        let source = builder.build()?;
        Ok(source)
    }