in crates/ratchet-core/src/ops/reindex/mod.rs [166:222]
/// Builds the [`ReindexMeta`] describing the wrapped reindex operation.
///
/// The first source tensor's shape and `dst`'s shape are each passed through
/// `Shape::promote(_, 4)`. Their element counts, strides and dimensions are
/// then packed into `u32` / `UVec4` values that every variant receives.
/// `Permute` additionally carries the promoted permutation, and `Slice` carries
/// the `start` of each slice index, right-aligned into four slots with leading
/// zeros.
///
/// # Panics
///
/// Panics if the operation reports no source tensors, if the promoted
/// permutation does not contain exactly four entries, or if a slice carries
/// more than four index ranges.
fn metadata(&self, dst: &Tensor, _: &KernelElement) -> Result<Self::Metadata, OperationError> {
    let ReindexKernels::Standard(inner) = self;

    let inputs = inner.srcs();
    let input = inputs.first().unwrap();

    let src_promoted = Shape::promote(input.shape().clone(), 4);
    let dst_promoted = Shape::promote(dst.shape().clone(), 4);

    // Element counts are narrowed to `u32` with a plain `as` cast.
    let (src_numel, dst_numel) = (
        src_promoted.numel() as u32,
        dst_promoted.numel() as u32,
    );

    let src_stride = UVec4::from(&Strides::from(&src_promoted));
    let dst_stride = UVec4::from(&Strides::from(&dst_promoted));
    let src_shape = UVec4::from(&src_promoted);
    let dst_shape = UVec4::from(&dst_promoted);

    let meta = match inner {
        Reindex::Permute(p) => {
            let axes: Vec<u32> = p.promote().iter().map(|&axis| axis as u32).collect();
            // The conversion only succeeds when exactly four axes are present.
            let perm = <[u32; 4]>::try_from(axes).unwrap();
            ReindexMeta::Permute(PermuteMeta::new(
                src_shape,
                dst_shape,
                src_stride,
                dst_stride,
                src_numel,
                dst_numel,
                perm.into(),
            ))
        }
        Reindex::Slice(s) => {
            let starts: Vec<_> = s.indices().iter().map(|range| range.start).collect();
            // Right-align the starts: the leading `pad` slots stay zero.
            let pad = 4 - starts.len();
            let mut offsets = [0u32; 4];
            for (slot, start) in offsets[pad..].iter_mut().zip(starts) {
                *slot = start as u32;
            }
            ReindexMeta::Slice(SliceMeta::new(
                src_shape,
                dst_shape,
                src_stride,
                dst_stride,
                src_numel,
                dst_numel,
                UVec4::from(offsets),
            ))
        }
        Reindex::Broadcast(_) => ReindexMeta::Broadcast(BroadcastMeta::new(
            src_shape, dst_shape, src_stride, dst_stride, src_numel, dst_numel,
        )),
    };

    Ok(meta)
}