def is_ending_with_noop_edge()

in tinynn/converter/operators/optimize.py [0:0]


def is_ending_with_noop_edge(edge: ig.Edge, graph_converter: ig.Graph, branch: bool = False):
    """Check whether ``edge`` ends in an operator that passes its data through unchanged.

    The edge qualifies when all of the following hold:

    * the source vertex has exactly one outgoing edge (or at least one when
      ``branch`` is true),
    * the target vertex has at least one outgoing edge and carries an operator,
    * the operator's first input is one of the source vertex's outputs, and
    * the operator is an identity for that input (see ``_is_noop_op``).

    Args:
        edge: The edge to inspect.
        graph_converter: The graph that ``edge`` belongs to.
        branch: When true, allow the source vertex to feed several consumers.

    Returns:
        Truthy if the target operator can be treated as a no-op, falsy otherwise.
    """
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    if branch:
        source_ok = source_vertex.outdegree() >= 1
    else:
        source_ok = source_vertex.outdegree() == 1

    # Check the cheap degree conditions before touching vertex attributes.
    if not source_ok or target_vertex.outdegree() < 1:
        return False

    op = target_vertex['op']
    if op is None or op.inputs[0].name not in source_vertex['outputs']:
        return False

    return _is_noop_op(target_vertex['node_type'], op)


def _is_noop_op(node_type, op):
    """Return whether ``op`` (of kind ``node_type``) leaves its first input unchanged."""

    def same_shape():
        return op.inputs[0].shape == op.outputs[0].shape

    if node_type in (
        ExtendedOperator.RESHAPE,
        ExtendedOperator.PAD,
        ExtendedOperator.PADV2,
        ExtendedOperator.MIRROR_PAD,
        ExtendedOperator.TILE,
        ExtendedOperator.SLICE,
    ):
        return same_shape()

    if node_type == ExtendedOperator.STRIDED_SLICE:
        if not same_shape():
            return False
        # A shape-preserving strided slice is only an identity when every stride
        # is positive; negative strides (e.g. x[::-1]) keep the shape but reverse
        # the data. Inputs are laid out as (input, begin, end, strides).
        if len(op.inputs) < 4:
            return True
        return bool((np.asarray(op.inputs[3].tensor) > 0).all())

    if node_type == ExtendedOperator.TRANSPOSE:
        # A valid permutation whose entries increase by exactly 1 is the identity.
        return bool((np.diff(op.inputs[1].tensor) == 1).all())

    if node_type == ExtendedOperator.CONCATENATION:
        return len(op.inputs) == 1 and len(op.outputs) == 1 and same_shape()

    if node_type == ExtendedOperator.GATHER:
        # Shape check first: it rejects rank-changing gathers (e.g. scalar
        # indices) before np.diff is applied to the indices.
        return same_shape() and bool((np.diff(op.inputs[1].tensor) == 1).all())

    if node_type == ExtendedOperator.CAST:
        return op.inDataType == op.outDataType

    return False