in core/maxframe/tensor/indexing/core.py [0:0]
def calc_shape(tensor_shape, index):
    """Compute the shape produced by indexing a tensor of ``tensor_shape`` with ``index``.

    Each item of ``index`` must be one of:

    * a boolean ndarray / tensor: consumes ``ind.ndim`` input axes and produces
      one output axis whose length is the number of True values (``np.nan``
      when the index is not an in-memory ``np.ndarray``);
    * a non-boolean ndarray / tensor (fancy index): consumes one input axis;
    * a ``slice``: consumes one input axis and produces one output axis;
    * an integer: consumes one input axis and produces none;
    * ``None``: produces a new output axis of length 1.

    Anything else trips the ``assert``, so Ellipsis is not handled here.
    Callers presumably expand it beforehand (NOTE(review): confirm).

    All fancy indexes are broadcast together. The broadcast dimensions are
    inserted at the output position of the first fancy index.
    NOTE(review): numpy moves them to the front when fancy indexes are not
    adjacent. Confirm callers handle or forbid that case.

    Parameters
    ----------
    tensor_shape : tuple
        Shape of the tensor being indexed. Entries may be ``np.nan`` (unknown).
    index : iterable
        Index items as described above.

    Returns
    -------
    list
        The resulting shape. Unknown sizes are ``np.nan``.

    Raises
    ------
    IndexError
        If a boolean index does not match the dimensions it indexes, an
        integer or in-memory integer-array index is out of bounds, or the
        fancy indexes cannot be broadcast together.
    """
    shape = []
    in_axis = 0
    # Position in ``shape`` at which the broadcast fancy-index dims are inserted.
    fancy_index = None
    fancy_index_shapes = []
    for ind in index:
        if isinstance(ind, TENSOR_TYPE + (np.ndarray,)) and ind.dtype == np.bool_:
            # bool
            shape.append(np.nan if not isinstance(ind, np.ndarray) else int(ind.sum()))
            for i, (t_size, size) in enumerate(
                zip(ind.shape, tensor_shape[in_axis : ind.ndim + in_axis])
            ):
                if not np.isnan(t_size) and not np.isnan(size) and t_size != size:
                    raise IndexError(
                        f"boolean index did not match indexed array along dimension {in_axis + i}; "
                        f"dimension is {size} but corresponding boolean dimension is {t_size}"
                    )
            in_axis += ind.ndim
        elif isinstance(ind, TENSOR_TYPE + (np.ndarray,)):
            if fancy_index is None:
                # Every output dim appended so far (including the 1s added for
                # ``None``) precedes the fancy dims, so insert right after them.
                fancy_index = len(shape)
            size = tensor_shape[in_axis]
            if isinstance(ind, np.ndarray) and not np.isnan(size):
                # Valid positions are [-size, size), as in numpy.
                out_of_range = (ind >= size) | (ind < -size)
                if out_of_range.any():
                    bad = ind[out_of_range][0]
                    raise IndexError(
                        f"index {bad} is out of bounds for axis {in_axis} with size {size}"
                    )
            fancy_index_shapes.append(ind.shape)
            in_axis += 1
        elif isinstance(ind, slice):
            size = tensor_shape[in_axis]
            shape.append(np.nan if np.isnan(size) else calc_sliced_size(size, ind))
            in_axis += 1
        elif isinstance(ind, Integral):
            size = tensor_shape[in_axis]
            # Valid positions are [-size, size), as in numpy.
            if not np.isnan(size) and not -size <= ind < size:
                raise IndexError(
                    f"index {ind} is out of bounds for axis {in_axis} with size {size}"
                )
            in_axis += 1
        else:
            assert ind is None
            shape.append(1)

    if fancy_index is not None:
        if any(np.isnan(np.prod(s)) for s in fancy_index_shapes):
            # Some fancy index has an unknown size, so every broadcast dim is
            # unknown. Broadcasting yields the largest rank among the indexes.
            fancy_index_shape = (np.nan,) * max(len(s) for s in fancy_index_shapes)
        else:
            try:
                fancy_index_shape = broadcast_shape(*fancy_index_shapes)
            except ValueError:
                shapes_repr = " ".join(str(s) for s in fancy_index_shapes)
                raise IndexError(
                    "shape mismatch: indexing arrays could not be broadcast together "
                    f"with shapes {shapes_repr}"
                ) from None
        # Splice the broadcast dims into the output shape.
        shape[fancy_index:fancy_index] = list(fancy_index_shape)
    return shape