in scripts/float16.py [0:0]
def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
    """Remove redundant ``fp32 -> Cast(to=FLOAT16) -> Cast(to=FLOAT)`` pairs in place.

    For every eligible pair, the nodes that consumed the second Cast's output are
    rewired to read the first Cast's fp32 input directly, and both Cast nodes are
    deleted from ``graph_proto.node``.

    A Cast to FLOAT16 (the "first" Cast) is eligible only when:
      * its single input is declared FLOAT in ``value_info``, the graph inputs or
        the graph outputs (tensors without a declared tensor type are skipped);
      * its input is not produced by a ``Constant`` node;
      * its output is not a graph output; and
      * it has at least one consumer, and every consumer is a Cast to FLOAT whose
        output is not a graph output. If any consumer is something else, the
        first Cast is kept, so no node is left with a dangling input.

    Only nodes directly in ``graph_proto.node`` are inspected. A tensor referenced
    from a nested subgraph (e.g. an ``If``/``Loop`` body) is not seen as a consumer.
    NOTE(review): confirm that no caller relies on such implicit captures of the
    removed fp32 Cast outputs.

    Args:
        graph_proto: Graph to simplify; it is modified in place.

    Raises:
        RuntimeError: If a Cast-to-FLOAT16 node does not have exactly one input.
    """
    float_type = onnx_proto.TensorProto.FLOAT
    float16_type = onnx_proto.TensorProto.FLOAT16

    # Index each tensor name by the node producing it and the nodes consuming it.
    # Tensor names are unique in a graph, so (unlike node names, which may be
    # empty or repeated) they identify nodes reliably.
    producers = {}
    consumers = {}
    for node in graph_proto.node:
        for output_name in node.output:
            producers[output_name] = node
        for input_name in node.input:
            consumers.setdefault(input_name, []).append(node)

    graph_output_names = {value_info.name for value_info in graph_proto.output}

    # Declared element type per tensor. ``value_info.type`` is a TypeProto, so the
    # element type has to be read from its ``tensor_type`` field. Comparing the
    # TypeProto itself with an int never matches.
    elem_types = {}
    for value_info in itertools.chain(
        graph_proto.value_info, graph_proto.input, graph_proto.output
    ):
        if value_info.type.HasField("tensor_type"):
            elem_types[value_info.name] = value_info.type.tensor_type.elem_type

    def cast_target(node) -> Optional[int]:
        """Return the ``to`` attribute of a Cast node, or None if it is absent."""
        # Look the attribute up by name: a Cast may carry more attributes than
        # ``to``, so ``attribute[0]`` is not guaranteed to be it.
        for attr in node.attribute:
            if attr.name == "to":
                return attr.i
        return None

    # Output name of a removed Cast-to-FLOAT -> fp32 tensor to read instead.
    replacements = {}
    # Cast nodes to delete, keyed by their output name so each is removed once.
    nodes_to_remove = {}

    for node in graph_proto.node:
        if node.op_type != "Cast" or cast_target(node) != float16_type:
            continue
        if len(node.input) != 1:
            raise RuntimeError(
                f"Cast node {node.name} should have only one input, but has {len(node.input)}."
            )
        if len(node.output) != 1:
            continue
        source_name = node.input[0]
        fp16_name = node.output[0]

        if elem_types.get(source_name) != float_type:
            continue
        # Casts fed by Constant nodes are intentionally left untouched.
        upstream = producers.get(source_name)
        if upstream is not None and upstream.op_type == "Constant":
            continue
        # The fp16 tensor must not be observable outside the graph.
        if fp16_name in graph_output_names:
            continue
        downstream = consumers.get(fp16_name, [])
        if not downstream or not all(
            dn.op_type == "Cast"
            and cast_target(dn) == float_type
            and len(dn.output) == 1
            and dn.output[0] not in graph_output_names
            for dn in downstream
        ):
            continue

        for second_cast in downstream:
            replacements[second_cast.output[0]] = source_name
            nodes_to_remove[second_cast.output[0]] = second_cast
        nodes_to_remove[fp16_name] = node

    if not nodes_to_remove:
        return

    def resolve(name: str) -> str:
        """Follow chained replacements (one pair feeding another) to the final tensor."""
        seen = set()
        while name in replacements and name not in seen:
            seen.add(name)
            name = replacements[name]
        return name

    # Rewire every reader of a removed fp32 Cast output in a single pass. Doing
    # this after all pairs are collected keeps it independent of node order.
    for node in graph_proto.node:
        for i, input_name in enumerate(node.input):
            if input_name in replacements:
                node.input[i] = resolve(input_name)

    for node in nodes_to_remove.values():
        graph_proto.node.remove(node)