def as_same_device()

in core/maxframe/tensor/array_utils.py [0:0]


def as_same_device(inputs, device=None, ret_extra=False, copy_if_not_writeable=False):
    """
    Bring every item of ``inputs`` onto a single device.

    Parameters
    ----------
    inputs : iterable
        Arrays, sparse arrays and/or scalars. Items without an ``ndim``
        attribute, or with ``ndim == 0``, are treated as scalars. Scalars are
        still passed through ``_get`` / ``move_to_device``. They are ignored
        when inferring ``device`` and when choosing the returned array module.
    device : int, optional
        Target device id. ``-1`` sends every input through ``_get``. Any other
        value sends every input through ``move_to_device(i, device)``. When
        ``None``, the device is inferred by ``_most_nbytes_device`` from
        ``(device_id, nbytes)`` pairs of the non-scalar inputs. Inputs without
        a device id count as ``-1``. The inferred device falls back to ``-1``
        if ``_most_nbytes_device`` raises ``ValueError``.
    ret_extra : bool, default False
        When True, also return the resolved device id and an array module.
    copy_if_not_writeable : bool, default False
        When True, outputs reported as not writeable by
        ``_is_array_writeable`` are replaced with copies. For
        ``sparse.SparseMatrix`` / ``sparse.SparseVector``, the check covers
        ``data``, ``indices`` and ``indptr``.

    Returns
    -------
    list or tuple
        ``outputs`` (one per input, in input order) when ``ret_extra`` is
        False. Otherwise ``(outputs, device, m)``, where ``m`` is:

        * ``sparse`` if any input is sparse;
        * else ``get_array_module`` of the first non-scalar *output*;
        * else ``np``.
    """
    # Materialize once: ``inputs`` is traversed several times below.
    inputs = list(inputs)
    tensor_positions = [
        pos for pos, i in enumerate(inputs) if hasattr(i, "ndim") and i.ndim > 0
    ]  # filter scalar
    input_tensors = [inputs[pos] for pos in tensor_positions]
    has_sparse = any(issparse(i) for i in inputs)

    if device is None:
        try:
            device = _most_nbytes_device(
                # Not ``i.device.id``: some arrays expose a ``device`` without
                # an ``id`` (NumPy >= 2.0 reports ``device == "cpu"``), which
                # would raise AttributeError. Treat those as host (-1).
                (getattr(getattr(i, "device", None), "id", -1), i.nbytes)
                for i in input_tensors
            )
        except ValueError:
            # NOTE(review): presumably raised when there are no non-scalar
            # inputs to choose from — confirm against _most_nbytes_device.
            device = -1

    if device == -1:
        outputs = [_get(i) for i in inputs]
    else:
        outputs = [move_to_device(i, device) for i in inputs]

    if copy_if_not_writeable:
        new_outputs = []
        for out in outputs:
            if not _is_array_writeable(out):
                out = out.copy()
            elif isinstance(out, (sparse.SparseMatrix, sparse.SparseVector)):
                # The sparse wrapper itself may report writeable while its
                # underlying buffers do not, so check each buffer.
                if (
                    not _is_array_writeable(out.data)
                    or not _is_array_writeable(out.indices)
                    or not _is_array_writeable(out.indptr)
                ):
                    out = type(out)(out.spmatrix.copy(), shape=out.shape)
            new_outputs.append(out)
        outputs = new_outputs

    if not ret_extra:
        return outputs

    if has_sparse:
        m = sparse
    elif tensor_positions:
        # Inspect the output, not the input: the input may have been moved
        # to a different device above. The module must match the arrays
        # the caller actually receives.
        m = get_array_module(outputs[tensor_positions[0]])
    else:
        m = np
    return outputs, device, m