in gad/src/matrix.rs [88:116]
fn matmul(
&mut self,
v1: &af::Dim4,
v2: &af::Dim4,
prop1: MatProp,
prop2: MatProp,
) -> Result<af::Dim4> {
let tv1 = if prop1.transposed {
self.transpose(v1, false)?
} else {
*v1
};
let tv2 = if prop2.transposed {
self.transpose(v2, false)?
} else {
*v2
};
if tv1[1] != tv2[0] {
return Err(Error::dimensions(func_name!(), &[v1, v2]));
}
let r = match (tv1[2], tv1[3], tv2[2], tv2[3]) {
(1, 1, a, b) | (a, b, 1, 1) => [tv1[0], tv2[1], a, b],
(a, b, c, d) if a == c && b == d => [tv1[0], tv2[1], a, b],
_ => {
return Err(Error::dimensions(func_name!(), &[v1, v2]));
}
};
Ok(af::Dim4::new(&r))
}