in candle-metal-kernels/src/tests.rs [1338:1439]
/// Runs `run_mlx_gemm` on tiny inputs filled with the ramp `0, 1, 2, ...` and
/// compares the rounded output with hand-computed matrix products.
///
/// Scenarios, in order:
/// 1. f32, one batch of a 2x3 by 3x4 product;
/// 2. f32, two batches of the same shape;
/// 3. f32, one batch with a non-zero offset on the rhs buffer;
/// 4. bf16, one batch;
/// 5. f16, one batch.
///
/// Both operands always use contiguous strides: `[m * k, k, 1]` for the lhs
/// and `[n * k, n, 1]` for the rhs.
fn mlx_gemm() {
    // f32, single batch.
    {
        let (b, m, n, k) = (1, 2, 4, 3);
        let left: Vec<f32> = (0..b * m * k).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..b * n * k).map(|i| i as f32).collect();
        let left_stride = [m * k, k, 1];
        let right_stride = [n * k, n, 1];
        let product = run_mlx_gemm(
            GemmDType::F32,
            (b, m, n, k),
            &left,
            &left_stride,
            0,
            &right,
            &right_stride,
            0,
        );
        let expected = vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0];
        assert_eq!(approx(product, 4), expected);
    }

    // f32, two batches: the first batch reproduces the single-batch result.
    {
        let (b, m, n, k) = (2, 2, 4, 3);
        let left: Vec<f32> = (0..b * m * k).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..b * n * k).map(|i| i as f32).collect();
        let left_stride = [m * k, k, 1];
        let right_stride = [n * k, n, 1];
        let product = run_mlx_gemm(
            GemmDType::F32,
            (b, m, n, k),
            &left,
            &left_stride,
            0,
            &right,
            &right_stride,
            0,
        );
        let expected = vec![
            20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
            518.0, 548.0, 578.0,
        ];
        assert_eq!(approx(product, 4), expected);
    }

    // OFFSET
    {
        // The buffers are still generated for two batches, but the kernel is
        // asked for a single batch.
        let (b, m, n, k) = (2, 2, 4, 3);
        let left: Vec<f32> = (0..b * m * k).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..b * n * k).map(|i| i as f32).collect();
        let left_stride = [m * k, k, 1];
        let right_stride = [n * k, n, 1];
        // Batch size is forced to 1 and the rhs is offset by 12 elements; the
        // offset is multiplied by 4, the number of bytes in an f32.
        let right_offset = 12 * 4;
        let product = run_mlx_gemm(
            GemmDType::F32,
            (1, m, n, k),
            &left,
            &left_stride,
            0,
            &right,
            &right_stride,
            right_offset,
        );
        let expected = vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0];
        assert_eq!(approx(product, 4), expected);
    }

    // bgemm sanity test
    {
        let (b, m, n, k) = (1, 2, 4, 3);
        let left: Vec<bf16> = (0..b * m * k)
            .map(|i| i as f32)
            .map(bf16::from_f32)
            .collect();
        let right: Vec<bf16> = (0..b * n * k)
            .map(|i| i as f32)
            .map(bf16::from_f32)
            .collect();
        let left_stride = [m * k, k, 1];
        let right_stride = [n * k, n, 1];
        let product = run_mlx_gemm(
            GemmDType::BF16,
            (b, m, n, k),
            &left,
            &left_stride,
            0,
            &right,
            &right_stride,
            0,
        );
        let expected = vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0];
        assert_eq!(approx_bf16(product, 4), expected);
    }

    // hgemm sanity test
    {
        let (b, m, n, k) = (1, 2, 4, 3);
        let left: Vec<f16> = (0..b * m * k)
            .map(|i| i as f32)
            .map(f16::from_f32)
            .collect();
        let right: Vec<f16> = (0..b * n * k)
            .map(|i| i as f32)
            .map(f16::from_f32)
            .collect();
        let left_stride = [m * k, k, 1];
        let right_stride = [n * k, n, 1];
        let product = run_mlx_gemm(
            GemmDType::F16,
            (b, m, n, k),
            &left,
            &left_stride,
            0,
            &right,
            &right_stride,
            0,
        );
        let expected = vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0];
        assert_eq!(approx_f16(product, 4), expected);
    }
}