in crates/ratchet-core/src/ops/matmul/gemm.rs [226:258]
/// Builds the kernel source for this op, specialised on the dtype of `self.lhs` and on the
/// kernel element chosen by `self.spec.select_kernel_element()`.
///
/// * `F32` / `F16` inputs are rendered with the selected element (`Scalar`, `Vec2`, `Vec4`).
/// * The quantized `Q8_0*` / `Q4_K*` dtypes ignore the selected element and are always
///   rendered as `Scalar`: `f32` for the `*F` variants, `f16` for the `*H` variants.
///
/// `inplace`, `dst` and `workgroup_size` are forwarded unchanged to `render`.
///
/// # Errors
///
/// Returns `InvariantError::UnsupportedDType` (converted with `into()`) when the
/// `(dtype, element)` pair matches none of the arms below; any error produced by `render`
/// is returned as-is.
fn build_kernel(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    // Every supported arm makes the same call and differs only in the element type
    // parameter, so the shared argument list lives in one place.
    macro_rules! render_as {
        ($element:ty) => {
            self.render::<$element>(inplace, dst, workgroup_size)
        };
    }

    let element = self.spec.select_kernel_element();
    match (self.lhs.dt(), element) {
        // Full-precision inputs: honour the selected element width.
        (DType::F32, KernelElement::Scalar) => render_as!(Scalar<f32>),
        (DType::F32, KernelElement::Vec2) => render_as!(Vec2<f32>),
        (DType::F32, KernelElement::Vec4) => render_as!(Vec4<f32>),

        // Half-precision inputs: same dispatch, `f16` payload.
        (DType::F16, KernelElement::Scalar) => render_as!(Scalar<f16>),
        (DType::F16, KernelElement::Vec2) => render_as!(Vec2<f16>),
        (DType::F16, KernelElement::Vec4) => render_as!(Vec4<f16>),

        // Quantized inputs: always scalar, whatever element was selected.
        (DType::Q8_0F(_) | DType::Q4_KF(_), _) => render_as!(Scalar<f32>),
        (DType::Q8_0H(_) | DType::Q4_KH(_), _) => render_as!(Scalar<f16>),

        // Anything else has no kernel specialisation here.
        _ => Err(InvariantError::UnsupportedDType(self.lhs.dt()).into()),
    }
}