fn write_getters()

in crates/ratchet-core/src/ops/matmul/gemm.rs [305:359]


    /// Emits the WGSL operand-getter functions used by the GEMM kernel into `builder`.
    ///
    /// Writes, in order:
    /// 1. the unpack helper for the LHS dtype (`builder.write_unpack`);
    /// 2. the LHS getters:
    ///    - `F32` / `F16`: `getA(d0, d1, d2)` loads `A` at the flattened 3-D index divided by
    ///      `P::W` and converts it to `P::render_type()`;
    ///    - `Q8_0F` / `Q8_0H`: `getA` returns `unpack(...)` of the `A` element at index / 4 as
    ///      a `vec4` of `P::T::DT`, and `getAbsMax` returns `scale[index / 32]`;
    /// 3. the RHS getter `getB`, loading `B` at the flattened index divided by `P::W`.
    ///
    /// The variant of `getB` is chosen by the dtype of the LHS (`self.lhs`), not of the RHS.
    /// NOTE(review): this presumably relies on a quantized LHS always being paired with a
    /// float RHS, so `getB` returns a scalar `P::T::DT` in that case — confirm against callers.
    ///
    /// # Errors
    /// Returns `InvariantError::UnsupportedDType` (converted into `OperationError`) when the
    /// LHS dtype is not `F32`, `F16`, `Q8_0F` or `Q8_0H`.
    fn write_getters<P: WgslPrimitive>(
        &self,
        _: &Tensor,
        builder: &mut WgslKernelBuilder,
    ) -> Result<(), OperationError> {
        let lhs = &self.lhs;
        // `wgsl!` interpolates `'accessor`, `'W` and `'dt` from these locals by name,
        // so these identifiers must not be renamed.
        let accessor = P::render_type();
        let W = P::W;
        let dt = P::T::DT;
        builder.write_unpack(lhs.dt());

        // One dispatch on the LHS dtype selects both getters, so the unsupported-dtype
        // error is produced in exactly one place.
        let (a_getters, b_getter) = match lhs.dt() {
            DType::F32 | DType::F16 => (
                wgsl! {
                    fn getA(d0 : i32, d1 : i32, d2 : i32) -> 'accessor {
                        return 'accessor(A[getAIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 'W]);
                    }
                },
                wgsl! {
                    fn getB(d0 : i32, d1 : i32, d2 : i32) -> 'accessor {
                        return 'accessor(B[getBIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 'W]);
                    }
                },
            ),
            DType::Q8_0F(_) | DType::Q8_0H(_) => (
                wgsl! {
                    fn getA(d0 : i32, d1 : i32, d2 : i32) -> vec4<'dt> {
                        return unpack(A[getAIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 4]);
                    }

                    fn getAbsMax(d0 : i32, d1 : i32, d2 : i32) -> 'dt {
                        let abs_index = getAIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 32;
                        return scale[abs_index];
                    }
                },
                wgsl! {
                    fn getB(d0 : i32, d1 : i32, d2 : i32) -> 'dt {
                        return 'dt(B[getBIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 'W]);
                    }
                },
            ),
            _ => return Err(InvariantError::UnsupportedDType(lhs.dt()).into()),
        };

        // Keep the original emission order: LHS getters first, then the RHS getter.
        builder.write_global(a_getters);
        builder.write_global(b_getter);

        Ok(())
    }