fn cat_contiguous()

in candle-core/src/tensor_cat.rs [151:237]


    fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[dim] = 0;
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == dim {
                    cat_dims[dim] += v2;
                }
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        let cat_target_dim_len = cat_dims[dim];
        let block_size: usize = cat_dims.iter().skip(1 + dim).product();
        let shape = Shape::from(cat_dims);
        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
        let mut dst_o = 0;
        for arg in args.iter() {
            let arg = arg.as_ref();
            let arg_dims = arg.shape().dims();
            let d1: usize = arg_dims.iter().take(dim).product();
            let d2 = block_size * arg_dims[dim];
            let dst_s = block_size * cat_target_dim_len;
            let src_o = arg.layout().start_offset();
            arg.storage().copy2d(
                &mut storage,
                d1,
                d2,
                /* src_s */ d2,
                dst_s,
                src_o,
                dst_o,
            )?;
            dst_o += d2;
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }