# remove_redundant_transposes()
# from coremltools/converters/mil/backend/nn/passes/mlmodel_passes.py


def remove_redundant_transposes(spec):
    """
    Removes layers from model specification that are back to back transposes
    that compose to the identity.

    The spec is mutated in place: each maximal run of consecutive transpose
    layers whose composed permutation is the identity is bypassed (children
    rewired to the run's original input) and deleted.
    """

    def blob_name_to_layers(nn_layers):
        """
        Build blob-name lookup tables for the given layer list.

        Returns:
            input_to_parent_layers: {str: layer_proto_message}
                {blob name: the layer that produces it}
            output_to_layers: {str: [layer_proto_message]}
                {blob name: layers that consume it}
        """
        output_to_layers = {}
        for layer in nn_layers:
            for input_name in layer.input:
                output_to_layers.setdefault(input_name, []).append(layer)

        input_to_parent_layers = {}
        for layer in nn_layers:
            for output_name in layer.output:
                # "copy" layers may re-emit an already-produced blob name;
                # every other layer type must be the unique producer.
                if layer.WhichOneof("layer") != "copy":
                    assert output_name not in input_to_parent_layers, \
                        "'{}' blob is generated by more than 1 layers".format(output_name)
                input_to_parent_layers[output_name] = layer

        return input_to_parent_layers, output_to_layers

    def _delete_layers(nn_spec, layers_to_delete):
        """
        Given a neural network spec and sequences of transposes to remove,
        rewire the network to bypass those transposes and remove them from
        the spec.
        """
        nn_layers = nn_spec.layers
        _, output_to_layers = blob_name_to_layers(nn_layers)

        # First pass: rewire layers to bypass those that will be deleted.
        for layers in layers_to_delete:
            start_layer = layers[0]
            end_layer = layers[-1]

            # Replace each child's reference to end_layer's output with
            # start_layer's input, short-circuiting the whole sequence.
            children = output_to_layers[end_layer.output[0]]
            for child in children:
                idx = [
                    i
                    for i, child_input in enumerate(child.input)
                    if child_input == end_layer.output[0]
                ]
                assert len(idx) == 1
                child.input[idx[0]] = start_layer.input[0]

        # Second pass: delete the layers.
        for layers in layers_to_delete:
            for layer in layers:
                nn_layers.remove(layer)

    def _find_redundant_transposes(nn_spec):
        """
        Search the neural network spec for sequences of transposes that
        together are the identity, and return a list of those sequences.
        """
        nn_layers = nn_spec.layers
        layers_to_delete = []

        input_to_parent_layers, output_to_layers = blob_name_to_layers(nn_layers)

        for layer in nn_layers:
            if layer.WhichOneof("layer") != "transpose":
                continue
            # Only start with the last element of a transpose sequence: skip
            # any transpose whose single consumer is another transpose.
            consumers = output_to_layers.get(layer.output[0])
            if (
                consumers is not None
                and len(consumers) == 1
                and consumers[0].WhichOneof("layer") == "transpose"
            ):
                continue

            # Walk backwards through the chain of single-consumer transposes
            # feeding this layer to collect the full sequence.
            layers = []
            cursor = layer
            while True:
                # Only transposes whose output is consumed can be bypassed;
                # a network output has no children to rewire.
                if cursor.output[0] in output_to_layers:
                    layers.append(cursor)
                if cursor.input[0] not in input_to_parent_layers:
                    break
                cursor = input_to_parent_layers[cursor.input[0]]
                if cursor.WhichOneof("layer") != "transpose":
                    break
                if len(output_to_layers[cursor.output[0]]) != 1:
                    break
            layers = layers[::-1]

            if not layers:
                continue

            # Optimize for the number of layers which can be merged using
            # dynamic programming.
            def solve_dp(layers):
                """
                The resulting dp[i] means the maximum length of transpose sequence resulting
                in identity starting at index i
                For example, dp[0] = 0 means there is no sequence starting at 0 results in identity
                dp[10] = 5 means the longest identity sequence starts at 10 is 5,
                so [layers[10],layer[11],..,layer[14]] is the longest identity sequence start at 10.

                # dic: {tuple:int}
                # key is the net transpose axes pattern starting from the first layer
                # value is the highest id of the layer which has this pattern
                # e.g. if dic[(1,2,0)] = 34, it means that starting from the 1st layer,
                # the net transpose pattern  `(1,2,0)` is last seen at layer id 34. No layer after 34-th
                # layer will result in the net pattern `(1,2,0)`
                """
                dim = len(layers[0].transpose.axes)
                dp = [0] * len(layers)
                dic = {}
                axes = list(range(dim))
                dic[tuple(axes)] = 0
                for i in range(len(layers)):
                    # Compose the running permutation with this layer's axes.
                    axes = [axes[k] for k in layers[i].transpose.axes]
                    key = tuple(axes)
                    if key in dic:
                        # Seeing a repeated net pattern means the span
                        # (dic[key] .. i) composes to the identity.
                        dp[dic[key]] = i - dic[key] + 1
                    dic[key] = i + 1
                # Chain adjacent identity spans so dp[i] covers the longest
                # run of back-to-back identity sequences starting at i.
                for i in range(len(layers) - 1, -1, -1):
                    j = i + dp[i]
                    if j < len(layers):
                        dp[i] = dp[i] + dp[j]
                return dp

            dp = solve_dp(layers)

            """
            Once we know the maximum identity sequence starts at each index, we solve
            for the maximum total node we can remove.
            I think there must be lots of different solution for this, but I use DP again.
            sol_num[i] keeps track of the maximum number of nodes can be remove after index i
            For example, if sol_num[10] = 5, this means after index 10, we can at most remove 5 nodes.
            sol_bt[i] keeps the first starting point of identity sequence which results in the
            optimal solution after index i.
            For example, if sol_num[10] = 12, means that in order to get rid of the maxium number of
            nodes after 10, the first starting point is index 12.
            After construct sol_num and sol_bt by dynamic programming, we backtrack for the optimal
            solution using sol_bt.
            """
            sol_num = [0] * len(dp)
            sol_bt = [None] * len(dp)
            if dp[-1] != 0:
                sol_num[-1] = dp[-1]
                sol_bt[-1] = len(dp) - 1
            for i in range(len(sol_num) - 2, -1, -1):
                if dp[i] == 0:
                    # No identity sequence starts here; inherit the best
                    # solution from the next index.
                    sol_num[i] = sol_num[i + 1]
                    sol_bt[i] = sol_bt[i + 1]
                else:
                    # Either take the sequence starting at i (plus the best
                    # solution after it ends) or skip index i entirely.
                    num = dp[i]
                    j = i + dp[i]
                    if j < len(sol_num):
                        num += sol_num[j]
                    if num > sol_num[i + 1]:
                        sol_num[i] = num
                        sol_bt[i] = i
                    else:
                        sol_num[i] = sol_num[i + 1]
                        sol_bt[i] = sol_bt[i + 1]

            # Backtrack through sol_bt to collect the chosen sequences.
            cursor = 0
            while cursor < len(dp):
                if sol_bt[cursor] is None:
                    break
                cursor = sol_bt[cursor]
                tmp = [layers[i] for i in range(cursor, cursor + dp[cursor])]
                layers_to_delete.append(tmp)
                cursor += dp[cursor]

        return layers_to_delete

    nn_spec = _get_nn_spec(spec)
    layers_to_delete = _find_redundant_transposes(nn_spec)
    if layers_to_delete:
        _delete_layers(nn_spec, layers_to_delete)
        print("{} transpose pairs deleted".format(len(layers_to_delete)))