in crates/ratchet-core/src/cpu/gemm.rs [30:57]
fn calculate_skips(
lhs_shape: &Shape,
lhs_strides: &[isize],
rhs_shape: &Shape,
rhs_strides: &[isize],
rank: usize,
m: usize,
n: usize,
k: usize,
) -> Result<(usize, usize)> {
let lhs_skip: usize = match lhs_strides[..rank - 2] {
[s1, stride] if s1 == stride * lhs_shape[1] as isize => stride as usize,
[_, stride] if lhs_shape[0] == 1 => stride as usize,
[stride, _] if lhs_shape[1] == 1 => stride as usize,
[stride] => stride as usize,
[] => m * k,
_ => Err(anyhow!("non-contiguous lhs"))?,
};
let rhs_skip: usize = match rhs_strides[..rank - 2] {
[s1, stride] if s1 == stride * rhs_shape[1] as isize => stride as usize,
[_, stride] if rhs_shape[0] == 1 => stride as usize,
[stride, _] if rhs_shape[1] == 1 => stride as usize,
[stride] => stride as usize,
[] => n * k,
_ => Err(anyhow!("non-contiguous rhs"))?,
};
Ok((lhs_skip, rhs_skip))
}