in crates/ratchet-core/src/ops/matmul/gemm.rs [197:224]
/// Builds the GEMM dispatch: the workgroup grid plus the per-workgroup size.
///
/// The grid's x count is derived from one axis of the RHS shape and its y count
/// from one axis of the LHS shape, each divided by a tile edge of 32 through
/// `WorkgroupCount::div_ceil`. The z count is `self.spec.stacks()`. The
/// `trans_lhs` / `trans_rhs` flags choose which axis of each shape is read.
/// The tensor argument is not used.
///
/// NOTE(review): the tile edge of 32 and the `8 x 8 x 1` workgroup size
/// presumably mirror constants in the GEMM shader — confirm they stay in sync.
fn calculate_dispatch(&self, _: &Tensor) -> Result<crate::Workload, OperationError> {
    // GEMM
    let tile_edge = 32;
    let (lhs_dims, rhs_dims) = (self.spec.lhs_shape(), self.spec.rhs_shape());

    // A set transpose flag flips which axis is read from the matching shape.
    let lhs_axis = if self.trans_lhs { 1 } else { 0 };
    let rhs_axis = if self.trans_rhs { 0 } else { 1 };
    let lhs_extent = lhs_dims[lhs_axis];
    let rhs_extent = rhs_dims[rhs_axis];

    // x follows the RHS extent, y follows the LHS extent, z follows the stack count.
    let workgroup_count = wgc![
        WorkgroupCount::div_ceil(rhs_extent as _, tile_edge) as _,
        WorkgroupCount::div_ceil(lhs_extent, tile_edge) as _,
        self.spec.stacks() as _
    ];

    Ok(Workload {
        workgroup_count,
        workgroup_size: wgs![8, 8, 1],
    })
}