in crates/ratchet-core/src/ops/matmul/gemm.rs [61:94]
/// Registers the storage buffers and the uniform block used by the GEMM kernel.
///
/// Bindings are registered in this order, chosen by the LHS dtype:
/// - `F32` / `F16`: `A`, `B`, optional `bias`, `result` — each as `Array<P>`.
/// - `Q8_0F` / `Q8_0H`: `A` (`Array<Scalar<u32>>`), `scale` (`Array<P>`),
///   `B` (`Array<Scalar<P::T>>`), optional `bias` (`Array<P>`), `result` (`Array<P>`).
///
/// `bias` is only registered when `self.bias` is `Some`. All inputs are read-only;
/// `result` is read-write. The uniform block is registered last.
///
/// # Errors
///
/// Returns [`InvariantError::UnsupportedDType`] (converted into [`OperationError`]) when the
/// LHS dtype is none of the variants above. In that case nothing is registered.
fn register_bindings<P: WgslPrimitive>(
    &self,
    builder: &mut WgslKernelBuilder,
    _: bool,
) -> Result<(), OperationError> {
    let lhs_dt = self.lhs.dt();
    let float_arr = Array::<P>::default();
    let ro = BindingMode::ReadOnly;

    // Operand bindings depend on the LHS dtype.
    // NOTE(review): registration order presumably determines the binding indices seen by the
    // generated WGSL, so the per-dtype order below is kept identical to the previous code.
    match lhs_dt {
        DType::F32 | DType::F16 => {
            builder.register_storage("A", ro, float_arr);
            builder.register_storage("B", ro, float_arr);
        }
        DType::Q8_0F(_) | DType::Q8_0H(_) => {
            // Presumably `A` holds packed quantized values read as u32 words and `scale` holds
            // its dequantization scales — confirm against the kernel source.
            builder.register_storage("A", ro, Array::<Scalar<u32>>::default());
            builder.register_storage("scale", ro, float_arr);
            builder.register_storage("B", ro, Array::<Scalar<P::T>>::default());
        }
        // Bail out before registering anything for dtypes this kernel does not handle.
        _ => return Err(InvariantError::UnsupportedDType(lhs_dt).into()),
    }

    // Tail shared by every supported dtype: optional bias, then the writable output.
    if self.bias.is_some() {
        builder.register_storage("bias", ro, float_arr);
    }
    builder.register_storage("result", BindingMode::ReadWrite, float_arr);
    builder.register_uniform();
    Ok(())
}