in crates/ratchet-core/src/cpu/reindex.rs [184:222]
/// Broadcasts `src` (laid out with shape `src_shape`) into `result` (laid out
/// with shape `dst_shape`).
///
/// Both shapes are promoted to rank 4 via `Shape::promote`, mirroring the GPU
/// implementation, so all index arithmetic works on fixed-size `[usize; 4]`
/// arrays. For every flat position `i` in `result`:
/// 1. the destination N-d index is recovered from the destination strides;
/// 2. every axis on which the (promoted) source has extent 1 is pinned to
///    index 0, which is what makes that axis broadcast;
/// 3. the resulting source N-d index is turned back into a flat offset into
///    `src` using the source strides, and that element is copied.
///
/// # Panics
///
/// Panics if a promoted shape or its strides cannot be converted into a
/// `[usize; 4]` (NOTE(review): presumably this happens when the input rank
/// exceeds 4; confirm against `Shape::promote`). Also panics if a computed
/// source offset is out of bounds for `src`.
fn generic_broadcast<T: TensorDType>(
    src: &[T],
    result: &mut [T],
    src_shape: &Shape,
    dst_shape: &Shape,
) {
    // Promote both shapes to rank 4 (same as the GPU impl) so every index below
    // is a `[usize; 4]`.
    let src_shape = &Shape::promote(src_shape.clone(), 4);
    let dst_shape = &Shape::promote(dst_shape.clone(), 4);
    let src_strides = &Strides::from(src_shape);
    let dst_strides = &Strides::from(dst_shape);
    let src_shape: [usize; 4] = src_shape.try_into().unwrap();
    let src_strides: [usize; 4] = src_strides.try_into().unwrap();
    let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();

    // `true` where the source axis has a real extent and must follow the
    // destination index; `false` where its extent is 1 and it is broadcast.
    let follows_dst: [bool; 4] = src_shape.map(|extent| extent != 1);

    for (i, out) in result.iter_mut().enumerate() {
        let dst_index: [usize; 4] = offset_to_ndindex(i, dst_strides);
        // Collapse broadcast axes to 0 so they always read the single source slice.
        let src_index: [usize; 4] =
            std::array::from_fn(|d| if follows_dst[d] { dst_index[d] } else { 0 });
        let src_offset = nd_index_to_offset(src_index, src_strides);
        *out = src[src_offset];
    }
}