def reshape_mapping()

in tinynn/converter/operators/optimize.py [0:0]


def reshape_mapping(shape_1, shape_2):
    """Pair up groups of dimensions of two shapes related by a reshape.

    Walks ``shape_1`` and ``shape_2`` in lockstep and partitions both into
    consecutive groups, so that the product of each group of ``shape_1``
    equals the product of the paired group of ``shape_2``.

    - When a new group starts and the next dims on both sides are equal,
      they form a one-to-one group on their own.
    - Trailing size-1 dims left on one side after the other side is
      exhausted are attached to the last group on that side.
    - If no group exists yet (e.g. reshaping ``[1]`` to ``[]``), the
      trailing dims get a new group paired with an empty group on the
      other side.

    Args:
        shape_1: Sequence of ints, the shape before the reshape.
        shape_2: Sequence of ints, the shape after the reshape.

    Returns:
        A 4-tuple ``(mapping_l, mapping_r, non_one_mapping_l,
        non_one_mapping_r)``.

        - ``mapping_l[k]`` and ``mapping_r[k]`` are the lists of dimension
          indices into ``shape_1`` and ``shape_2`` that form group ``k``.
        - The ``non_one_*`` lists hold the same groups with size-1 dims
          removed. They keep only groups that still have at least one dim
          on both sides.

    Raises:
        AssertionError: If a dim left over after the other shape is
            exhausted is not 1, e.g. when the two shapes do not describe the
            same number of elements.
    """

    def _expect_one(shape, idx, name):
        # Raised explicitly (instead of ``assert``) so the check is not
        # stripped under ``python -O``; AssertionError is kept for
        # compatibility with callers of the original implementation.
        if shape[idx] != 1:
            raise AssertionError(f'{name}[{idx}] = {shape[idx]} is left over after matching and is not 1')

    len_1 = len(shape_1)
    len_2 = len(shape_2)
    i = 0
    j = 0
    acc_l = 1
    start_l = 0
    acc_r = 1
    start_r = 0
    mapping_l = []
    mapping_r = []
    # Which side must take more dims to close the currently open group:
    # 'l' -> extend the shape_1 side, 'r' -> extend the shape_2 side,
    # None -> both sides take their next dim.
    sign = None
    while i < len_1 or j < len_2:
        if i < len_1 and j < len_2:
            if start_l == i and start_r == j and shape_1[i] == shape_2[j]:
                # Fresh group whose first dims already match: one-to-one group.
                mapping_l.append([i])
                mapping_r.append([j])
                acc_l = 1
                acc_r = 1
                i += 1
                j += 1
                start_l = i
                start_r = j
                sign = None
            else:
                # Only the side that was just extended folds in its new dim.
                if sign in ('l', None):
                    acc_l = shape_1[i] * acc_l
                if sign in ('r', None):
                    acc_r = shape_2[j] * acc_r
                if acc_l == acc_r:
                    # Products agree: close the group on both sides.
                    mapping_l.append(list(range(start_l, i + 1)))
                    mapping_r.append(list(range(start_r, j + 1)))
                    acc_l = 1
                    acc_r = 1
                    i += 1
                    j += 1
                    start_l = i
                    start_r = j
                    sign = None
                elif acc_l < acc_r:
                    sign = 'l'
                    i += 1
                else:
                    sign = 'r'
                    j += 1
        elif i < len_1:
            # shape_2 is exhausted: only trailing size-1 dims may remain.
            _expect_one(shape_1, i, 'shape_1')
            if not mapping_l:
                # No group was ever closed (e.g. reshape [1] -> []): open a
                # pair so the leftover dim has a group to join.
                mapping_l.append([])
                mapping_r.append([])
            mapping_l[-1].append(i)
            i += 1
        else:
            # shape_1 is exhausted: only trailing size-1 dims may remain.
            _expect_one(shape_2, j, 'shape_2')
            if not mapping_r:
                mapping_l.append([])
                mapping_r.append([])
            mapping_r[-1].append(j)
            j += 1

    non_one_mapping_l = []
    non_one_mapping_r = []
    for group_l, group_r in zip(mapping_l, mapping_r):
        kept_l = [d for d in group_l if shape_1[d] != 1]
        kept_r = [d for d in group_r if shape_2[d] != 1]
        # Groups made only of size-1 dims on either side carry no data layout
        # information, so they are dropped from the non-one view.
        if kept_l and kept_r:
            non_one_mapping_l.append(kept_l)
            non_one_mapping_r.append(kept_r)
    return mapping_l, mapping_r, non_one_mapping_l, non_one_mapping_r