fn render()

in crates/ratchet-core/src/ops/softmax.rs [55:180]


    /// Renders the WGSL source of the softmax kernel for primitive `P`.
    ///
    /// Each workgroup handles one row of `X`. The row start is computed from
    /// `workgroup_id.y` and `workgroup_id.x`; `metadata.M` is presumably the
    /// number of rows per batch (TODO confirm against the dispatch code).
    /// The row is processed in three passes, each strided by the workgroup size:
    ///
    /// 1. reduce the row maximum into the workgroup variable `maximum`;
    /// 2. reduce `exp(x - maximum)` into the workgroup variable `sum`;
    /// 3. overwrite every element of `X` with `exp(x - maximum) / sum`.
    ///
    /// Subtracting the maximum before `exp` keeps the exponentials bounded by 1.
    /// For vectorised primitives (`P::W` of 2 or 4), the lanes of the reduced
    /// `smem[0]` are folded into the scalar `maximum` / `sum` after the tree reduction.
    ///
    /// The tree reduction reads `smem[index + stride]` for strides
    /// `workgroup_size.x / 2`, `/ 4`, …, `1`, with `smem` sized to
    /// `workgroup_size.x`. `workgroup_size.x` is therefore assumed to be a power of two.
    ///
    /// `inplace` is only forwarded to `register_bindings`; the generated code
    /// always writes its result back into `X`.
    /// NOTE(review): confirm this is intended when `inplace == false`.
    ///
    /// # Errors
    ///
    /// Returns [`OperationError::CompileError`] when `P::W` is not 1, 2 or 4.
    /// Propagates failures from GPU device lookup, binding/metadata registration
    /// and the final kernel build.
    fn render<P: WgslPrimitive>(
        &self,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
    ) -> Result<KernelSource, OperationError> {
        let device = dst.device().try_gpu()?;
        let mut kernel_builder = WgslKernelBuilder::new(
            workgroup_size.clone(),
            rvec![
                BuiltIn::GlobalInvocationId,
                BuiltIn::LocalInvocationId,
                BuiltIn::WorkgroupId,
            ],
            device.compute_features().clone(),
        );
        self.register_bindings::<P>(&mut kernel_builder, inplace)?;
        kernel_builder.render_metadata(&self.metadata(dst, &self.kernel_element(dst))?);

        let dt = P::T::DT;
        let accessor = P::render_type();
        let block_size = workgroup_size.x.render();
        let min_float = P::T::MIN;

        // Resolve everything that depends on the vector width in one place:
        // - the metadata field holding the row length in units of `P`;
        // - the statements folding the lanes of `smem[0]` into the scalar results.
        let (reduce_var, finalize_max, finalize_sum) = match P::W {
            1 => (
                "metadata.N",
                wgsl! { maximum = smem[0]; },
                wgsl! { sum = smem[0]; },
            ),
            2 => (
                "metadata.ND2",
                wgsl! { maximum = max(smem[0].x, smem[0].y); },
                wgsl! { sum = dot(smem[0], 'accessor(1.)); },
            ),
            4 => (
                "metadata.ND4",
                wgsl! { maximum = max(smem[0].x, max(smem[0].y, max(smem[0].z, smem[0].w))); },
                wgsl! { sum = dot(smem[0], 'accessor(1.)); },
            ),
            _ => {
                return Err(OperationError::CompileError(
                    "Invalid dimension".to_string(),
                ))
            }
        };

        kernel_builder.write_global(wgsl! {
            var<workgroup> smem: array<'accessor, 'block_size>;
            var<workgroup> maximum: 'dt;
            var<workgroup> sum: 'dt;
        });

        // One tree-reduction step each. The barrier sits outside the `if` so every
        // invocation reaches it (uniform control flow).
        kernel_builder.write_global(wgsl! {
            fn block_sum(index: u32, stride: u32) {
                if index < stride {
                    smem[index] += smem[index + stride];
                }
                workgroupBarrier();
            }

            fn block_max(index: u32, stride: u32) {
                if index < stride {
                    smem[index] = max(smem[index], smem[index + stride]);
                }
                workgroupBarrier();
            }
        });

        // Strides of the tree reduction: workgroup_size.x / 2, / 4, ..., 1.
        // A single-invocation workgroup needs no reduction step, so the sequence is
        // empty in that case. The previous `(x - 1).ilog2()` panicked for x == 1.
        let strides: Vec<_> =
            std::iter::successors(Some(workgroup_size.x / 2), |&stride| Some(stride / 2))
                .take_while(|&stride| stride > 0)
                .collect();

        kernel_builder.write_main(wgsl! {
            let batch_stride = workgroup_id.y * metadata.M * 'reduce_var;
            let row_start = batch_stride + workgroup_id.x * 'reduce_var;
            let index = local_invocation_id.x;
        });

        // Pass 1: per-invocation partial maxima, then the tree reduction.
        kernel_builder.write_main(wgsl! {
            smem[index] = 'accessor('min_float);
            for (var i: u32 = index; i < 'reduce_var; i += 'block_size) {
                smem[index] = max(smem[index], X[row_start + i]);
            }
            workgroupBarrier();
        });
        for &stride in &strides {
            let stride = stride.render();
            kernel_builder.write_main(wgsl! { block_max(index, 'stride); });
        }
        kernel_builder.write_main(wgsl! {
            if index == 0 {
                'finalize_max
            }
            workgroupBarrier();
        });

        // Pass 2: partial sums of exp(x - maximum), then the tree reduction.
        kernel_builder.write_main(wgsl! {
            smem[index] = 'accessor(0.);
            for (var i: u32 = index; i < 'reduce_var; i += 'block_size) {
                smem[index] += exp(X[row_start + i] - maximum);
            }
            workgroupBarrier();
        });
        for &stride in &strides {
            let stride = stride.render();
            kernel_builder.write_main(wgsl! { block_sum(index, 'stride); });
        }
        kernel_builder.write_main(wgsl! {
            if index == 0 {
                'finalize_sum
            }
            workgroupBarrier();
        });

        // Pass 3: normalise the row in place. Each invocation touches only its own
        // strided elements, so no further synchronisation is needed.
        kernel_builder.write_main(wgsl! {
            for(var i: u32 = index; i < 'reduce_var; i += 'block_size) {
                var val = X[row_start + i];
                X[row_start + i] = exp(val - maximum) / sum;
            }
        });
        Ok(kernel_builder.build()?)
    }