def remove_unnecessary_cast_node()

in scripts/float16.py [0:0]


def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
    """Remove redundant ``Cast(fp32 -> fp16)`` / ``Cast(fp16 -> fp32)`` pairs in place.

    The pass looks for the pattern::

        source --fp32--> cast_to_fp16 --fp16--> cast_to_fp32 --fp32--> consumers

    and rewires ``consumers`` to read ``source`` directly, then deletes the
    casts that are no longer needed.

    A pair is only collapsed when all of the following hold:

    * the first cast's input has value info (in ``graph_proto.value_info`` or
      ``graph_proto.input``) whose tensor element type is ``FLOAT``;
    * the first cast's ``to`` attribute equals ``FLOAT16`` and the second
      cast's ``to`` attribute equals ``FLOAT32`` (module-level constants);
    * the first cast's input is not produced by a ``Constant`` node;
    * the second cast's output is not a graph output and has at least one
      consumer among ``graph_proto.node`` (otherwise there is nothing that
      can be rewired, so the pair is left untouched).

    The first cast is deleted only when every consumer of its output was a
    removed second cast and its output is not a graph output; otherwise it is
    kept so that the remaining fp16 consumers stay valid.

    Only the nodes directly in ``graph_proto.node`` are inspected; references
    made from nested subgraphs are not tracked.

    Args:
        graph_proto: The graph to optimize. It is modified in place.

    Raises:
        RuntimeError: If a Cast node whose output is consumed does not have
            exactly one input.
    """
    # 1. index every tensor name by its producer and by its consumers, and
    # collect the cast nodes (in graph order).
    producer_by_output = {}
    consumers_by_input = {}
    cast_node_list = []
    for node in graph_proto.node:
        if node.op_type == "Cast":
            cast_node_list.append(node)
        for output_name in node.output:
            producer_by_output[output_name] = node
        for input_name in node.input:
            consumers_by_input.setdefault(input_name, []).append(node)

    graph_output_names = {output.name for output in graph_proto.output}

    name_to_value_info = {
        value_info.name: value_info
        for value_info in itertools.chain(graph_proto.value_info, graph_proto.input)
    }

    def get_type(name: str) -> Optional[int]:
        """Return the tensor element type recorded for `name`, or None."""
        value_info = name_to_value_info.get(name)
        if value_info is None:
            # `name` has no value info.
            return None
        # `value_info.type` is a TypeProto message; the integer data type
        # lives in its tensor_type.elem_type field.
        return value_info.type.tensor_type.elem_type

    def get_cast_to(cast_node) -> Optional[int]:
        """Return the `to` attribute of a Cast node, or None if it is missing."""
        # Look the attribute up by name: Cast may carry other attributes, so
        # `to` is not guaranteed to be the first one.
        for attribute in cast_node.attribute:
            if attribute.name == "to":
                return attribute.i
        return None

    def consumers_of(node):
        """Return the nodes that read any output of `node`."""
        return [
            consumer
            for output_name in node.output
            for consumer in consumers_by_input.get(output_name, [])
        ]

    # Nodes scheduled for deletion. Identity (not protobuf equality) is used
    # for membership because these are the very objects stored in the indexes.
    nodes_to_remove = []

    def is_marked(node) -> bool:
        return any(node is marked for marked in nodes_to_remove)

    # 2. find (cast_to_fp16, cast_to_fp32) pairs where
    # --fp32--> cast_to_fp16 --fp16--> cast_to_fp32, and bypass them.
    for cast_node in cast_node_list:
        downstream_nodes = consumers_of(cast_node)
        if not downstream_nodes:
            continue

        cast_node_name = cast_node.name
        if len(cast_node.input) != 1:
            raise RuntimeError(
                f"Cast node {cast_node_name} should have only one input, but has {len(cast_node.input)}."
            )

        # Read the input name live: an earlier iteration may have rewired it
        # when collapsing a preceding pair in a chain of casts.
        source_name = cast_node.input[0]
        if get_type(source_name) != onnx_proto.TensorProto.FLOAT:
            continue
        if get_cast_to(cast_node) != FLOAT16:
            continue

        # Casts fed by a 'Constant' node are never removed.
        upstream_node = producer_by_output.get(source_name)
        if upstream_node is not None and upstream_node.op_type == "Constant":
            continue

        for second_cast_node in downstream_nodes:
            if is_marked(second_cast_node):
                continue
            if (
                second_cast_node.op_type != "Cast"
                or get_cast_to(second_cast_node) != FLOAT32
            ):
                continue

            second_outputs = set(second_cast_node.output)
            if second_outputs & graph_output_names:
                # The graph output must keep its producer; leave the pair.
                continue
            second_consumers = consumers_of(second_cast_node)
            if not second_consumers:
                # Nothing visible to rewire, so removing the cast is not
                # known to be safe.
                continue

            # change "source --fp32--> cast_to_fp16 --fp16--> cast_to_fp32
            # --fp32--> downstream" to "source --fp32--> downstream".
            for consumer in second_consumers:
                for i, input_name in enumerate(consumer.input):
                    if input_name in second_outputs:
                        consumer.input[i] = source_name
                # Keep the index in sync so later rewiring of `source_name`
                # also reaches this consumer.
                consumers_by_input.setdefault(source_name, []).append(consumer)
            nodes_to_remove.append(second_cast_node)

        # The fp16 cast can only go if nothing else still needs its output.
        first_output_is_graph_output = bool(
            set(cast_node.output) & graph_output_names
        )
        if not first_output_is_graph_output and all(
            is_marked(consumer) for consumer in downstream_nodes
        ):
            nodes_to_remove.append(cast_node)

    # 3. remove the cast nodes that were bypassed
    for node in nodes_to_remove:
        graph_proto.node.remove(node)