in core/maxframe/tensor/utils.py [0:0]
def decide_unify_split(*splits):
    """Return a single chunk split that every given split can be rechunked to.

    Each argument is a sequence of chunk sizes along the same axis.

    * Splits that are empty or exactly ``(1,)`` are dropped. This supports
      broadcasting, e.g. ``decide_unify_split((1,), (5,)) == (5,)``.
    * If every split is dropped, the first raw split is returned.
    * If exactly one distinct split remains, it is returned as-is.
    * Otherwise, the chunk boundaries of all remaining splits are merged.
      For example, ``(2, 2)`` and ``(3, 1)`` unify to ``(2, 1, 1)``.
    * Zero-length chunks inside an input split are kept in the result.

    Parameters
    ----------
    *splits
        Chunk-size sequences to unify.

    Returns
    -------
    tuple
        The unified split. It is ``()`` when no splits are given and
        ``(0,)`` when the inputs cover zero elements.

    Raises
    ------
    ValueError
        If any remaining split has unknown (NaN) sizes, or if the remaining
        splits do not all sum to the same total.
    """
    # TODO (jisheng): In the future, we need more sophisticated way to decide the rechunk split
    # right now, for (2, 2) and (3, 1), we get the rechunk split as (2, 1, 1)
    from collections import deque

    if not splits:
        return ()
    raw_splits = splits
    # support broadcasting rules
    # decide_unify_splits((1,), (5,)) --> (5,)
    splits = {s for s in splits if len(s) > 1 or (len(s) == 1 and s[0] != 1)}
    if len(splits) == 1:
        return splits.pop()
    if len(splits) == 0:
        return raw_splits[0]
    if any(np.isnan(sum(s)) for s in splits):
        raise ValueError(f"Tensor chunk sizes are unknown: {splits}")
    if len(set(sum(s) for s in splits)) > 1:
        raise ValueError(f"Splits not of same size: {splits}")

    # deque gives O(1) removal of the consumed head chunk; list.pop(0) made
    # the merge quadratic in the number of chunks.
    queues = [deque(s) for s in splits]
    size = sum(queues[0])
    if size == 0:
        # Without this guard the loop below never runs and the result would
        # be () -- no chunks at all -- for an axis that should have one
        # empty chunk.
        return (0,)

    cum = 0
    res = []
    while cum < size:
        # The next unified chunk ends at the nearest boundary among all
        # inputs, i.e. the smallest remaining head chunk.
        step = min(q[0] for q in queues)
        res.append(step)
        for q in queues:
            q[0] -= step
            if q[0] == 0:
                q.popleft()
        cum += step
    return tuple(res)