def decide_unify_split()

in core/maxframe/tensor/utils.py [0:0]


def decide_unify_split(*splits):
    """
    Decide a common chunk split that is compatible with every given split.

    Each positional argument is a tuple of chunk sizes along one dimension.
    The returned split places a chunk boundary wherever any of the inputs has
    one. Every input split can therefore be recovered by merging adjacent
    chunks of the result. For example, ``(2, 2)`` and ``(3, 1)`` unify to
    ``(2, 1, 1)``.

    Broadcasting: a split that is a single chunk of size 1 (``(1,)``) is
    ignored, as is an empty split ``()``. For example,
    ``decide_unify_split((1,), (5,))`` returns ``(5,)``. If every input is
    ignored this way, the first raw input is returned unchanged.

    :param splits: chunk-size tuples for the same dimension.
    :return: the unified chunk-size tuple; ``()`` when called with no splits.
    :raises ValueError: if any remaining split contains unknown (NaN) chunk
        sizes, or if the remaining splits do not sum to the same total size.
    """
    # deque gives O(1) removal from the left; list.pop(0) shifts every
    # remaining element and made the merge below quadratic in chunk count.
    from collections import deque

    # TODO (jisheng): In the future, we need more sophisticated way to decide the rechunk split
    # right now, for (2, 2) and (3, 1), we get the rechunk split as (2, 1, 1)
    if not splits:
        return ()
    raw_splits = splits
    # support broadcasting rules: a lone size-1 chunk (and an empty split)
    # does not constrain the result, so it is dropped before unifying.
    # decide_unify_splits((1,), (5,))  --> (5,)
    splits = set(s for s in splits if ((len(s) > 1) or (len(s) == 1 and s[0] != 1)))
    if len(splits) == 1:
        return splits.pop()
    if len(splits) == 0:
        # every input was broadcastable; fall back to the first one as given
        return raw_splits[0]

    if any(np.isnan(sum(s)) for s in splits):
        raise ValueError(f"Tensor chunk sizes are unknown: {splits}")
    if len(set(sum(s) for s in splits)) > 1:
        raise ValueError(f"Splits not of same size: {splits}")

    queues = [deque(s) for s in splits]
    size = sum(queues[0])
    cum = 0

    res = []
    while cum < size:
        # The next unified chunk ends at the nearest upcoming boundary among
        # all splits; shrink every split's current chunk by that amount and
        # advance the splits whose current chunk is now exhausted.
        m = min(q[0] for q in queues)
        res.append(m)
        for q in queues:
            q[0] -= m
            if q[0] == 0:
                q.popleft()

        cum += m

    return tuple(res)