in core/maxframe/tensor/array_utils.py [0:0]
def as_same_device(inputs, device=None, ret_extra=False, copy_if_not_writeable=False):
    """
    Bring every object in ``inputs`` onto one common device.

    Parameters
    ----------
    inputs : sequence
        Objects to place on the same device. Only entries that have an
        ``ndim`` attribute with ``ndim > 0`` take part in choosing the
        target device, but every entry is converted.
    device : int, optional
        Target device id. When ``-1``, each input is passed through ``_get``;
        any other value is passed to ``move_to_device(i, device)``. When
        ``None``, the device is picked by ``_most_nbytes_device`` from the
        ``(device_id, nbytes)`` pairs of the non-scalar inputs. Inputs
        without a device id count as ``-1``. If that call raises
        ``ValueError``, the device falls back to ``-1`` (presumably this
        happens when there are no non-scalar inputs; confirm against
        ``_most_nbytes_device``).
    ret_extra : bool
        If True, also return the chosen device and an array module.
    copy_if_not_writeable : bool
        If True, replace every output that ``_is_array_writeable`` reports
        as not writeable with ``out.copy()``. For ``sparse.SparseMatrix`` /
        ``sparse.SparseVector`` outputs, the ``data``, ``indices`` and
        ``indptr`` buffers are checked too. If any of them is not writeable,
        the output is rebuilt from a copy of its ``spmatrix``.

    Returns
    -------
    list or tuple
        ``outputs`` when ``ret_extra`` is False. Otherwise
        ``(outputs, device, m)``, where ``m`` is:

        - ``sparse`` if any input is sparse;
        - else ``get_array_module`` of the first non-scalar input;
        - else ``np``.
    """
    # Scalars are filtered out: they should not influence the device choice.
    input_tensors = [i for i in inputs if hasattr(i, "ndim") and i.ndim > 0]
    has_sparse = any(issparse(i) for i in inputs)

    if device is None:
        try:
            # Look up ``.device.id`` defensively. Since NumPy 2.0, ndarrays
            # expose ``.device`` as the string "cpu" (Array API
            # compatibility), which has no ``.id``. Anything without a
            # device id is treated as host memory (-1).
            device = _most_nbytes_device(
                (getattr(getattr(i, "device", None), "id", -1), i.nbytes)
                for i in input_tensors
            )
        except ValueError:
            device = -1

    if device == -1:
        outputs = [_get(i) for i in inputs]
    else:
        outputs = [move_to_device(i, device) for i in inputs]

    if copy_if_not_writeable:
        new_outputs = []
        for out in outputs:
            if not _is_array_writeable(out):
                new_outputs.append(out.copy())
            elif isinstance(out, (sparse.SparseMatrix, sparse.SparseVector)) and (
                not _is_array_writeable(out.data)
                or not _is_array_writeable(out.indices)
                or not _is_array_writeable(out.indptr)
            ):
                # The wrapper itself may look writeable while one of its
                # component buffers is read-only, so copy the whole matrix.
                new_outputs.append(type(out)(out.spmatrix.copy(), shape=out.shape))
            else:
                new_outputs.append(out)
        outputs = new_outputs

    if not ret_extra:
        return outputs

    if has_sparse:
        m = sparse
    elif input_tensors:
        # NOTE(review): the module comes from the original input, not the
        # moved/fetched output. If move_to_device/_get change the array
        # type, this may not match ``outputs``; confirm this is intended.
        m = get_array_module(input_tensors[0])
    else:
        m = np
    return outputs, device, m