fn render()

in crates/ratchet-core/src/ops/conv.rs [40:123]


    /// Builds the WGSL kernel source for this operation.
    ///
    /// The generated shader consists of:
    /// * a workgroup-shared array `F` (fixed length 4096) used as a filter cache,
    /// * `load_filters_into_smem`, which copies filter weights from `W` into `F`,
    /// * `inner`, which accumulates over `metadata.Cin` iterations and writes one
    ///   element of `Y`,
    /// * a main body that computes per-invocation indices, fills `F`, and
    ///   dispatches to `inner` with a lane range chosen by the input position.
    ///
    /// NOTE(review): judging by the metadata names (`Cin`, `Lin`, `Lout`, `KS`,
    /// `stride`, `padding`) this appears to be a 1-D convolution; `inner` always
    /// reads exactly three filter taps, so it presumably assumes `KS == 3` —
    /// confirm against the caller.
    ///
    /// # Parameters
    /// * `inplace` - forwarded unchanged to `register_bindings`.
    /// * `dst` - output tensor; supplies the GPU device and is passed to
    ///   `metadata` / `kernel_element`.
    /// * `workgroup_size` - handed to the kernel builder; its `x` component is
    ///   also interpolated into the main body to compute a linear invocation index.
    ///
    /// # Errors
    /// Propagates any error returned by `try_gpu`, `register_bindings`,
    /// `metadata`, or `kernel_builder.build`.
    fn render<P: WgslPrimitive>(
        &self,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
    ) -> Result<KernelSource, OperationError> {
        let device = dst.device().try_gpu()?;
        let mut kernel_builder = WgslKernelBuilder::new(
            workgroup_size.clone(),
            rvec![
                BuiltIn::GlobalInvocationId,
                BuiltIn::LocalInvocationId,
                BuiltIn::WorkgroupId,
            ],
            device.compute_features().clone(),
        );
        self.register_bindings::<P>(&mut kernel_builder, inplace)?;
        kernel_builder.render_metadata(&self.metadata(dst, &self.kernel_element(dst))?);

        // Element type token interpolated into the shader as `'dt`
        // (presumably the WGSL scalar type name for `P::T` — confirm).
        let dt = P::T::DT;
        // Workgroup-shared filter cache, filled by `load_filters_into_smem`.
        // NOTE(review): the length is hard-coded to 4096 and nothing in this block
        // checks `metadata.F_numel` against it.
        kernel_builder.write_global(wgsl! {
            var<workgroup> F: array<'dt, 4096u>;
        });

        kernel_builder.write_global(wgsl! {
            // Computes one output element: for every `i` in `0..metadata.Cin`, loads up to
            // three inputs from `X` and three filter taps from `F`, multiply-accumulates
            // them, then writes the lane sum plus `B[bias_index]` to `Y[output_index]`.
            // `start`/`end` (inclusive) select which lanes of `inp` are loaded; lanes outside
            // that range keep their initial zero.
            // NOTE(review): `filter_index` is accepted but never read in this function.
            // NOTE(review): the `0f` literals are f32-suffixed; confirm this is valid when
            // `'dt` is not f32.
            fn inner(input_index: u32, filter_index: u32, output_index: u32, bias_index: u32, start: u32, end: u32) {
                var inp = vec3<'dt>(0f);
                var kernel = vec3<'dt>(0f);
                var acc = vec3<'dt>(0f);
                for(var i = 0u; i < metadata.Cin; i++) {
                    // Shift the window left by `metadata.padding`.
                    // NOTE(review): this is unsigned arithmetic; when `input_index` and `i` are
                    // both 0 the subtraction goes below zero, presumably relying on u32
                    // wraparound cancelling out once `j >= padding` is added below — confirm.
                    let input_start = input_index + (i * metadata.Lin) - metadata.padding;
                    //We only populate the input between the provided indices, used for padding
                    for(var j = start; j <= end; j++) {
                        inp[j] = X[input_start + j];
                    }

                    // Exactly three consecutive taps are read, independent of `metadata.KS`.
                    let filter_start = i * metadata.KS;
                    kernel.x = F[filter_start];
                    kernel.y = F[filter_start + 1u];
                    kernel.z = F[filter_start + 2u];

                    acc = fma(inp, kernel, acc);
                }
                Y[output_index] = acc.x + acc.y + acc.z + B[bias_index];
            }

            //Each thread may load more than 1 element into shared memory
            // Invocation `local_invocation_id.x` copies up to `metadata.Fperthread` consecutive
            // elements, reading `W` from `filter_index + x * Fperthread` and writing `F` from
            // `x * Fperthread`; destinations at or beyond `metadata.F_numel` are skipped.
            fn load_filters_into_smem(local_invocation_id: vec3<u32>, filter_index: u32) {
                let windex = filter_index + (local_invocation_id.x * metadata.Fperthread);
                let findex = (local_invocation_id.x * metadata.Fperthread);
                for(var i=0u; i < metadata.Fperthread; i++) {
                    if findex + i < metadata.F_numel {
                        F[findex + i] = W[windex + i];
                    }
                }
            }
        });

        let wgsx = workgroup_size.x.render();
        kernel_builder.write_main(wgsl!{
            // Linear x index of this invocation, scaled by the stride.
            let input_index = (workgroup_id.x * 'wgsx + local_invocation_id.x) * metadata.stride;
            // `workgroup_id.y` selects which block of `metadata.F_numel` weights in `W` to cache.
            let filter_index = (workgroup_id.y * metadata.F_numel);
            load_filters_into_smem(local_invocation_id, filter_index);
            // Wait until every invocation in the workgroup has finished writing `F`.
            workgroupBarrier();

            if input_index >= metadata.Lin {
                //Break after loading because all threads may be needed for loading F
                return;
            }

            let output_index = (workgroup_id.x * 'wgsx + local_invocation_id.x) + (workgroup_id.y * metadata.Lout);
            let bias_index = workgroup_id.y;

            // Choose which lanes of the 3-wide window are loaded; unloaded lanes stay zero.
            // NOTE(review): the first branch is an exact equality test, so with
            // `metadata.stride > 1` it may never match — confirm supported strides.
            if input_index == metadata.Lin - metadata.padding {
                // Lanes 0..=1 only; lane 2 is left at zero.
                inner(input_index, filter_index, output_index, bias_index, 0u, 1u);
            } else if input_index == 0u {
                // Lanes 1..=2 only; lane 0 is left at zero.
                inner(input_index, filter_index, output_index, bias_index, 1u, 2u);
            } else {
                // All three lanes are loaded.
                inner(input_index, filter_index, output_index, bias_index, 0u, 2u);
            }
        });

        Ok(kernel_builder.build()?)
    }