fn write_readers_and_writers()

in crates/ratchet-core/src/ops/matmul/gemm.rs [361:460]


    /// Emits the `mm_readA`, `mm_readB` and `mm_write` WGSL helpers into `builder`.
    ///
    /// `fits` is unpacked as `(fit_a_outer, fit_inner, fit_b_outer)`:
    /// - `mm_readA` skips its bounds check only when `fit_a_outer && fit_inner`.
    /// - `mm_readB` skips its bounds check only when `fit_inner && fit_b_outer`.
    /// - `mm_write` skips its bounds check only when `fit_a_outer && fit_b_outer`.
    ///
    /// Otherwise the checks compare against `metadata.lhs_shape` /
    /// `metadata.rhs_shape` (readers) or `metadata.dim_lhs_outer` /
    /// `metadata.dim_rhs_outer` (writer).
    ///
    /// `mm_readB` and `mm_write` use `P::render_type()` as their element type.
    /// `mm_readA` uses it as well, except for `Q8_0F` / `Q8_0H` LHS tensors. For
    /// those it uses the `Vec4::<f32>` / `Vec4::<f16>` render type instead.
    ///
    /// # Errors
    /// Currently never fails; always returns `Ok(())`.
    fn write_readers_and_writers<P: WgslPrimitive>(
        &self,
        builder: &mut WgslKernelBuilder,
        fits: (bool, bool, bool),
    ) -> Result<(), OperationError> {
        let (fit_a_outer, fit_inner, fit_b_outer) = fits;
        let elem_ty = P::render_type();

        // ---- LHS reader ----
        // With `trans_lhs`, the `row`/`col` arguments handed to `getA` are
        // swapped, and the bounds check below swaps the `lhs_shape` axes to match.
        let fetch_a = if !self.trans_lhs {
            wgsl! { value = getA(batch, row, col); }
        } else {
            wgsl! { value = getA(batch, col, row); }
        };

        let read_a_body = match (fit_a_outer && fit_inner, self.trans_lhs) {
            (true, _) => fetch_a,
            (false, false) => wgsl! {
                if (row < metadata.lhs_shape.y && col < metadata.lhs_shape.z) {
                    'fetch_a
                }
            },
            (false, true) => wgsl! {
                if (row < metadata.lhs_shape.z && col < metadata.lhs_shape.y) {
                    'fetch_a
                }
            },
        };

        // Quantized LHS dtypes read through a Vec4 accessor rather than `P`'s type.
        let a_elem_ty = match self.lhs.dt() {
            DType::Q8_0F(_) => Vec4::<f32>::render_type(),
            DType::Q8_0H(_) => Vec4::<f16>::render_type(),
            _ => elem_ty.clone(),
        };

        builder.write_global(wgsl! {
            fn mm_readA(batch: i32, row: i32, col: i32) -> 'a_elem_ty {
                var value = 'a_elem_ty(0.0);
                'read_a_body
                return value;
            }
        });

        // ---- RHS reader ----
        // Same swapping scheme as the LHS reader, keyed on `trans_rhs`.
        let fetch_b = if !self.trans_rhs {
            wgsl! { value = getB(batch, row, col); }
        } else {
            wgsl! { value = getB(batch, col, row); }
        };

        let read_b_body = match (fit_inner && fit_b_outer, self.trans_rhs) {
            (true, _) => fetch_b,
            (false, false) => wgsl! {
                if (row < metadata.rhs_shape.y && col < metadata.rhs_shape.z) {
                    'fetch_b
                }
            },
            (false, true) => wgsl! {
                if (row < metadata.rhs_shape.z && col < metadata.rhs_shape.y) {
                    'fetch_b
                }
            },
        };

        builder.write_global(wgsl! {
            fn mm_readB(batch: i32, row: i32, col: i32) -> 'elem_ty {
                var value = 'elem_ty(0.0);
                'read_b_body
                return value;
            }
        });

        // ---- Output writer ----
        // The store is guarded against the output extents unless both outer
        // dimensions are known to fit.
        let needs_guard = !(fit_a_outer && fit_b_outer);
        let write_body = if needs_guard {
            wgsl! {
                if (row < metadata.dim_lhs_outer && col < metadata.dim_rhs_outer) {
                    var value = valueIn;
                    let coords = vec3<i32>(batch, row, col);
                    setOutputAtCoords(coords[0], coords[1], coords[2], value);
                }
            }
        } else {
            wgsl! {
                var value = valueIn;
                let coords = vec3<i32>(batch, row, col);
                setOutputAtCoords(coords[0], coords[1], coords[2], value);
            }
        };

        builder.write_global(wgsl! {
            fn mm_write(batch: i32, row: i32, col: i32, valueIn: 'elem_ty) {
                'write_body
            }
        });

        Ok(())
    }