fn metadata()

in crates/ratchet-core/src/ops/reindex/mod.rs [166:222]


    /// Assembles the kernel metadata for this reindex operation.
    ///
    /// The first source tensor's shape and `dst`'s shape are both passed through
    /// `Shape::promote(.., 4)`; element counts and strides are derived from the
    /// promoted shapes, and everything is packed into `UVec4` / `u32` fields.
    /// Each variant then appends its own extra field:
    ///
    /// * `Permute`   - the promoted permutation as four `u32` axes.
    /// * `Slice`     - the start of every slice range, right-aligned into four
    ///   slots with leading zeros.
    /// * `Broadcast` - nothing beyond the shared fields.
    ///
    /// The `KernelElement` argument is not used.
    ///
    /// # Panics
    ///
    /// Panics if the operation has no source tensor, if the promoted permutation
    /// does not contain exactly four entries, or if the slice carries more than
    /// four ranges.
    fn metadata(&self, dst: &Tensor, _: &KernelElement) -> Result<Self::Metadata, OperationError> {
        let ReindexKernels::Standard(inner) = self;

        let srcs = inner.srcs();
        let input = srcs.first().unwrap();

        let src_dims = Shape::promote(input.shape().clone(), 4);
        let dst_dims = Shape::promote(dst.shape().clone(), 4);

        // Element counts must be read off the promoted shapes before they are
        // converted into their packed vector form below.
        let src_numel = src_dims.numel() as u32;
        let dst_numel = dst_dims.numel() as u32;

        let src_stride = UVec4::from(&Strides::from(&src_dims));
        let dst_stride = UVec4::from(&Strides::from(&dst_dims));

        let src_shape = UVec4::from(&src_dims);
        let dst_shape = UVec4::from(&dst_dims);

        let meta = match inner {
            Reindex::Permute(p) => {
                let axes: Vec<u32> = p.promote().iter().map(|&axis| axis as u32).collect();
                let perm: [u32; 4] = axes.try_into().unwrap();
                ReindexMeta::Permute(PermuteMeta::new(
                    src_shape,
                    dst_shape,
                    src_stride,
                    dst_stride,
                    src_numel,
                    dst_numel,
                    perm.into(),
                ))
            }
            Reindex::Slice(s) => {
                let starts: Vec<u32> = s.indices().iter().map(|r| r.start as u32).collect();
                // Right-align the starts so the leading (padded) slots stay at zero.
                let pad = 4 - starts.len();
                let mut offsets = [0u32; 4];
                offsets[pad..].copy_from_slice(&starts);
                ReindexMeta::Slice(SliceMeta::new(
                    src_shape,
                    dst_shape,
                    src_stride,
                    dst_stride,
                    src_numel,
                    dst_numel,
                    UVec4::from(offsets),
                ))
            }
            Reindex::Broadcast(_) => ReindexMeta::Broadcast(BroadcastMeta::new(
                src_shape, dst_shape, src_stride, dst_stride, src_numel, dst_numel,
            )),
        };

        Ok(meta)
    }