fn generic_broadcast()

in crates/ratchet-core/src/cpu/reindex.rs [184:222]


/// Fills `result` by broadcasting `src` (laid out as `src_shape`) to `dst_shape`.
///
/// Both shapes are first promoted to rank 4 via `Shape::promote`, and strides
/// are derived from the promoted shapes with `Strides::from`. Every flat offset
/// of `result` is turned into a 4-d destination index; on each axis where the
/// promoted source extent is 1 the source coordinate is forced to 0, otherwise
/// it follows the destination coordinate. The element at the resulting source
/// offset is then copied into `result`.
///
/// # Panics
///
/// Panics if the promoted source shape or either set of strides cannot be
/// converted into a `[usize; 4]`, or if a computed source offset falls outside
/// `src`.
fn generic_broadcast<T: TensorDType>(
    src: &[T],
    result: &mut [T],
    src_shape: &Shape,
    dst_shape: &Shape,
) {
    // Work in rank 4 throughout, same as the gpu impl.
    let src_shape = &Shape::promote(src_shape.clone(), 4);
    let dst_shape = &Shape::promote(dst_shape.clone(), 4);

    let src_strides = &Strides::from(src_shape);
    let dst_strides = &Strides::from(dst_shape);

    let src_dims: [usize; 4] = src_shape.try_into().unwrap();
    let src_strides: [usize; 4] = src_strides.try_into().unwrap();
    let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();

    // `true` on axes where the source actually varies; `false` on the
    // size-1 axes, which are pinned to coordinate 0 below.
    let follows_dst: [bool; 4] = src_dims.map(|extent| extent != 1);

    for (dst_offset, slot) in result.iter_mut().enumerate() {
        let dst_index = offset_to_ndindex(dst_offset, dst_strides);
        let src_index: [usize; 4] = std::array::from_fn(|axis| {
            if follows_dst[axis] {
                dst_index[axis]
            } else {
                0
            }
        });
        *slot = src[nd_index_to_offset(src_index, src_strides)];
    }
}