in crates/ratchet-core/src/ops/matmul/gemm.rs [305:359]
/// Writes the WGSL element getters for both matmul operands into `builder`.
///
/// The dtype of the left-hand operand selects which getters are generated:
///
/// * `F32` / `F16`: `getA` and `getB` both load from their buffer at the
///   flattened 3D index divided by `P::W` and convert the loaded value to
///   the type rendered by `P::render_type()`.
/// * `Q8_0F` / `Q8_0H`: `getA` returns `unpack(..)` of the word at
///   `index / 4` as a `vec4`, `getAbsMax` reads `scale[index / 32]`, and
///   `getB` returns a single scalar of `P::T::DT`.
///
/// Note that the shape of `getB` is chosen from the *lhs* dtype; the rhs
/// dtype is never inspected here.
///
/// # Errors
///
/// Returns `InvariantError::UnsupportedDType` (converted into
/// `OperationError`) when the lhs dtype is none of the above. In that case
/// only `write_unpack` has been called on `builder`; no getter is written.
fn write_getters<P: WgslPrimitive>(
    &self,
    _: &Tensor,
    builder: &mut WgslKernelBuilder,
) -> Result<(), OperationError> {
    let lhs = &self.lhs;

    // These names are interpolated into the WGSL templates below via the
    // `'name` syntax, so they must keep exactly these identifiers.
    let accessor = P::render_type();
    let W = P::W;
    let dt = P::T::DT;

    builder.write_unpack(lhs.dt());

    // Build both getter snippets in one dispatch so an unsupported dtype is
    // rejected exactly once, before anything is written as a global.
    let (a_getters, b_getter) = match lhs.dt() {
        DType::F32 | DType::F16 => {
            let get_a = wgsl! {
                fn getA(d0 : i32, d1 : i32, d2 : i32) -> 'accessor {
                    return 'accessor(A[getAIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 'W]);
                }
            };
            let get_b = wgsl! {
                fn getB(d0 : i32, d1 : i32, d2 : i32) -> 'accessor {
                    return 'accessor(B[getBIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 'W]);
                }
            };
            (get_a, get_b)
        }
        DType::Q8_0F(_) | DType::Q8_0H(_) => {
            // NOTE(review): the `/ 4` and `/ 32` divisors presumably reflect
            // four packed values per word and one scale per 32 values;
            // confirm against the Q8_0 layout.
            let get_a = wgsl! {
                fn getA(d0 : i32, d1 : i32, d2 : i32) -> vec4<'dt> {
                    return unpack(A[getAIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 4]);
                }
                fn getAbsMax(d0 : i32, d1 : i32, d2 : i32) -> 'dt {
                    let abs_index = getAIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 32;
                    return scale[abs_index];
                }
            };
            let get_b = wgsl! {
                fn getB(d0 : i32, d1 : i32, d2 : i32) -> 'dt {
                    return 'dt(B[getBIndexFromCoords3D(vec3<i32>(d0,d1,d2)) / 'W]);
                }
            };
            (get_a, get_b)
        }
        _ => return Err(InvariantError::UnsupportedDType(lhs.dt()).into()),
    };

    // Emission order matches the original: A's getters first, then B's.
    builder.write_global(a_getters);
    builder.write_global(b_getter);
    Ok(())
}