scripts/float16.py (644 lines of code) (raw):

# MIT License # # Copyright (c) Microsoft Corporation, Hugging Face. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. from typing import Optional import itertools import numpy as np import onnx import packaging.version as pv import warnings from onnx import helper, numpy_helper from onnx import onnx_pb as onnx_proto import onnxslim.third_party.onnx_graphsurgeon as gs FLOAT32 = 1 FLOAT16 = 10 def _npfloat16_to_int(np_list): """ Convert numpy float16 to python int. :param np_list: numpy float16 list :return int_list: python int list """ return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list] def convert_np_to_float16(np_array, min_positive_val=1e-7, max_finite_val=1e4): """ Convert float32 numpy array to float16 without changing sign or finiteness. Positive values less than min_positive_val are mapped to min_positive_val. Positive finite values greater than max_finite_val are mapped to max_finite_val. Similar for negative values. NaN, 0, inf, and -inf are unchanged. """ def between(a, b, c): return np.logical_and(a < b, b < c) positive_values = np_array[np.where(np_array > 0)] if positive_values.shape[0] > 0: pos_max = positive_values.max() pos_min = positive_values.min() if pos_max >= max_finite_val: warnings.warn( "the float32 number {} will be truncated to {}".format( pos_max, max_finite_val ) ) if pos_min <= min_positive_val: warnings.warn( "the float32 number {} will be truncated to {}".format( pos_min, min_positive_val ) ) negative_values = np_array[np.where(np_array < 0)] if negative_values.shape[0] > 0: neg_max = negative_values.max() neg_min = negative_values.min() if neg_min <= -max_finite_val: warnings.warn( "the float32 number {} will be truncated to {}".format( neg_min, -max_finite_val ) ) if neg_max >= -min_positive_val: warnings.warn( "the float32 number {} will be truncated to {}".format( neg_max, -min_positive_val ) ) np_array = np.where( between(0, np_array, min_positive_val), min_positive_val, np_array ) np_array = np.where( between(-min_positive_val, np_array, 0), -min_positive_val, np_array ) np_array = np.where( between(max_finite_val, np_array, float("inf")), max_finite_val, np_array ) np_array = np.where( between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array ) return np.float16(np_array) def convert_tensor_float_to_float16(tensor, min_positive_val=1e-7, max_finite_val=1e4): """ Convert tensor float to float16. :param tensor: TensorProto object :return tensor_float16: converted TensorProto object """ if not isinstance(tensor, onnx_proto.TensorProto): raise ValueError( "Expected input type is an ONNX TensorProto but got %s" % type(tensor) ) if tensor.data_type == onnx_proto.TensorProto.FLOAT: tensor.data_type = onnx_proto.TensorProto.FLOAT16 # convert float_data (float type) to float16 and write to int32_data if tensor.float_data: float16_data = convert_np_to_float16( np.array(tensor.float_data), min_positive_val, max_finite_val ) int_list = _npfloat16_to_int(float16_data) tensor.int32_data[:] = int_list tensor.float_data[:] = [] # convert raw_data (bytes type) if tensor.raw_data: # convert n.raw_data to float float32_list = np.fromstring(tensor.raw_data, dtype="float32") # convert float to float16 float16_list = convert_np_to_float16( float32_list, min_positive_val, max_finite_val ) # convert float16 to bytes and write back to raw_data tensor.raw_data = float16_list.tostring() return tensor def make_value_info_from_tensor(tensor): shape = numpy_helper.to_array(tensor).shape return helper.make_tensor_value_info(tensor.name, tensor.data_type, shape) DEFAULT_OP_BLOCK_LIST = [ "ArrayFeatureExtractor", "Binarizer", "CastMap", "CategoryMapper", "DictVectorizer", "FeatureVectorizer", "Imputer", "LabelEncoder", "LinearClassifier", "LinearRegressor", "Normalizer", "OneHotEncoder", "RandomUniformLike", "SVMClassifier", "SVMRegressor", "Scaler", "TreeEnsembleClassifier", "TreeEnsembleRegressor", "ZipMap", "NonMaxSuppression", "TopK", "RoiAlign", "Resize", # 'Range', "CumSum", "Min", "Max", "Upsample", # NEW: "RandomNormalLike", # TODO: Ideally, "Cast" nodes should not be here, for the following reasons: # - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime". # - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it. # However, without it, the graphs produced are invalid. Eventually, we will resolve this. "Cast", ] def initial_checking(model, disable_shape_infer): func_infer_shape = None if not disable_shape_infer and pv.Version(onnx.__version__) >= pv.Version("1.2"): try: from onnx.shape_inference import infer_shapes func_infer_shape = infer_shapes finally: pass if not isinstance(model, onnx_proto.ModelProto): raise ValueError( "Expected model type is an ONNX ModelProto but got %s" % type(model) ) if func_infer_shape is not None: model = func_infer_shape(model) is_fp16_ready_flag = check_if_fp16_ready(model.graph) return model, func_infer_shape, is_fp16_ready_flag def convert_float_to_float16( model, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False, disable_shape_infer=False, op_block_list=None, node_block_list=None, check_fp16_ready=True, ): # create blocklists if op_block_list is None: op_block_list = DEFAULT_OP_BLOCK_LIST if node_block_list is None: node_block_list = [] op_block_list = set(op_block_list) node_block_list = set(node_block_list) global_input_name_dict = ( {} ) # key: input name, value: new output name after Cast node # basic checking, including shape inference model, func_infer_shape, is_fp16_ready_flag = initial_checking( model, disable_shape_infer ) if is_fp16_ready_flag and check_fp16_ready: raise ValueError( "The model is already converted to float16, if convert again, the model might be wrong. \n If you are sure to convert again, please set check_fp16_ready=False." ) graph_stack = [model.graph] is_top_level = True while graph_stack: next_level = [] for curr_graph in graph_stack: process_graph_input( curr_graph, is_top_level, keep_io_types, global_input_name_dict ) value_info_block_list = process_tensor_in_node( curr_graph, op_block_list, node_block_list, min_positive_val, max_finite_val, ) process_value_info(curr_graph, value_info_block_list) process_node_in_block_list( curr_graph, global_input_name_dict, op_block_list, node_block_list ) process_initializers( curr_graph, op_block_list, node_block_list, min_positive_val, max_finite_val, ) process_graph_output(curr_graph, is_top_level, keep_io_types) sub_graph_list = get_next_level_graph( curr_graph, op_block_list, node_block_list ) if len(sub_graph_list) > 0: next_level.extend(sub_graph_list) if not is_top_level: process_node_input_output(curr_graph, global_input_name_dict) is_top_level = False # Going to process sub-graph graph_stack = next_level remove_unnecessary_cast_node(model.graph) # Topologically sort the graph # NOTE: We do not perform another round of optimization as the model is already optimized graph = gs.import_onnx(model) graph.toposort() model = gs.export_onnx(graph) return model # Change the input/output of the node to the new output name after Cast node for sub-graph # Because there have NO value_info start from def process_node_input_output( graph: onnx_proto.GraphProto, global_input_name_dict: dict ): for node in graph.node: for i, input_name in enumerate(node.input): if input_name in global_input_name_dict: node.input[i] = global_input_name_dict[input_name] for i, output_name in enumerate(node.output): if output_name in global_input_name_dict: node.output[i] = global_input_name_dict[output_name] def process_graph_input( graph: onnx_proto.GraphProto, is_top_level: bool, is_io_fp32: bool, global_input_name_dict: dict, ): # The input dtype is float32, need to cast to fp16 if is_top_level and is_io_fp32: for graph_input in graph.input: # n_input is ValueInfoProto if graph_input.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: downstream_nodes = find_downstream_node_by_input_name( graph, graph_input.name ) for d_node in downstream_nodes: # More than one node may consume the model input, so we only create # a single cast node, and then reuse this node when needed. cast_exists = graph_input.name in global_input_name_dict if cast_exists: cast_node_output_name = global_input_name_dict[graph_input.name] else: cast_node_output_name = graph_input.name + "_fp16" add_cast_node( graph, [graph_input.name], [cast_node_output_name], cast_node_output_name, # Set node name same as output name FLOAT16, ) add_new_value_info( graph, graph_input, cast_node_output_name, onnx_proto.TensorProto.FLOAT16, ) for i, input_name in enumerate(d_node.input): if input_name == graph_input.name: d_node.input[i] = ( cast_node_output_name # Change the input of the second node ) global_input_name_dict[graph_input.name] = ( cast_node_output_name ) # For the sub-graph, don't do cast else: # Change the input dtype to fp16 without any cast for graph_input in graph.input: if graph_input.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: graph_input.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 def process_graph_output( graph: onnx_proto.GraphProto, is_top_level: bool, is_io_fp32: bool ): if is_top_level and is_io_fp32: # the output dtype is float32, need to cast to fp16 for i, graph_output in enumerate(graph.output): if graph_output.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: new_producer_name = graph_output.name + "_fp16" original_name = graph_output.name # The correct output name # Get the node(s) that produce the model output # These will most likely be fp16, but could be fp32 if the previous node is in block_list upstream_nodes = find_upstream_node_by_output_name(graph, original_name) assert len(upstream_nodes) == 1 # Should be only one node producer_node = upstream_nodes[0] for i, output_name in enumerate(producer_node.output): if output_name == original_name: producer_node.output[i] = new_producer_name cast_node_name = new_producer_name + "_input_cast" + str(i) add_cast_node( graph, [new_producer_name], [original_name], cast_node_name, onnx_proto.TensorProto.FLOAT, ) for value_info in graph.value_info: if original_name == value_info.name: value_info.type.tensor_type.elem_type = ( onnx_proto.TensorProto.FLOAT ) # Get the node(s) that consume the model output downstream_nodes = find_downstream_node_by_input_name( graph, original_name, include_subgraphs=False, ) # It is possible that the producer node is also input to downstream nodes # So, we update the inputs of these downstream nodes for d_node in downstream_nodes: for i, input_name in enumerate(d_node.input): if input_name == original_name: d_node.input[i] = new_producer_name else: # change the output dtype to fp16 in tensor for graph_output in graph.output: if graph_output.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: graph_output.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 def process_node_in_block_list( graph: onnx_proto.GraphProto, global_input_name_dict: dict, op_block_list: list, node_block_list: list, ): # NB: Important to create a copy of the nodes in the graph to avoid modifying # the graph in-place while iterating (causing an infinite loop) for node in list(graph.node): if (node.op_type in op_block_list) or (node.name in node_block_list): insert_cast32_before_node(graph, node, global_input_name_dict) insert_cast16_after_node(graph, node, global_input_name_dict) # Todo: global_input_name_dict still not fill value def insert_cast32_before_node( graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict ): for i, input_name in enumerate(node.input): for value_info in itertools.chain(graph.value_info, graph.input): if input_name == value_info.name: if ( value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT16 ): break cast_output_name = node.name + "_input_cast_" + str(i) add_new_value_info( graph, value_info, cast_output_name, onnx_proto.TensorProto.FLOAT ) cast_node_name = node.name + "_input_cast" + str(i) add_cast_node( graph, [input_name], [cast_output_name], cast_node_name, onnx_proto.TensorProto.FLOAT, ) node.input[i] = cast_output_name break # Todo: global_input_name_dict still not fill value def insert_cast16_after_node( graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict ): for i, output_name in enumerate(node.output): for value_info in itertools.chain(graph.value_info, graph.output): if output_name == value_info.name: if ( value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT ): break cast_input_name = node.name + "_output_cast_" + str(i) add_new_value_info( graph, value_info, cast_input_name, onnx_proto.TensorProto.FLOAT ) value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 cast_node_name = node.name + "_output_cast" + str(i) add_cast_node( graph, [cast_input_name], [output_name], cast_node_name, onnx_proto.TensorProto.FLOAT16, ) node.output[i] = cast_input_name break # Process tensor data in attribute of the node def process_tensor_in_node( graph: onnx_proto.GraphProto, op_block_list: list, node_block_list: list, min_positive_val, max_finite_val, ): value_info_block_list = set() # This is for later use, not in this step for node in graph.node: # NOTE: "Cast" operation cannot change its output type because it is strongly typed. if ( (node.op_type in op_block_list) or (node.name in node_block_list) or (node.op_type == "Cast") ): # if (node.op_type in op_block_list) or (node.name in node_block_list): # Only need to block the output value_info changing for output_name in node.output: value_info_block_list.add(output_name) else: for attr in node.attribute: # one tensor if attr.t.data_type == onnx_proto.TensorProto.FLOAT: attr.t.CopyFrom( convert_tensor_float_to_float16( attr.t, min_positive_val, max_finite_val ) ) # list of tensor for t in attr.tensors: if t.data_type == onnx_proto.TensorProto.FLOAT: t.CopyFrom( convert_tensor_float_to_float16( t, min_positive_val, max_finite_val ) ) return value_info_block_list # Change all the value info type from float32 to float16 if not in block list def process_value_info(graph: onnx_proto.GraphProto, value_info_block_list: list): for value_info in graph.value_info: if value_info.name in value_info_block_list: continue else: if value_info.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 # Initializer is 'edge' type, so doesn't have value_info def process_initializers( graph: onnx_proto.GraphProto, op_block_list, node_block_list, min_positive_val, max_finite_val, ): # Find the input of the block node, don't need to change this kind of initializer initializer_block_list = set() for node in graph.node: if (node.op_type in op_block_list) or (node.name in node_block_list): for ( input_name ) in ( node.input ): # some is initializer, some is value_info, can't distinguish but doesn't matter initializer_block_list.add(input_name) # Process initializers for initializer in graph.initializer: if initializer.name not in initializer_block_list: if initializer.data_type == onnx_proto.TensorProto.FLOAT: convert_tensor_float_to_float16( initializer, min_positive_val, max_finite_val ) def get_next_level_graph( graph: onnx_proto.GraphProto, op_block_list: list, node_block_list: list ): sub_graph_list = [] for node in graph.node: if node.op_type in op_block_list or node.name in node_block_list: continue for attr in node.attribute: # Check if sub-graph exist if len(attr.g.node) > 0: # single sub-graph sub_graph_list.append(attr.g) for g in attr.graphs: if len(g.node) > 0: # multiple sub-graphs sub_graph_list.append(g) return sub_graph_list def add_cast_node( graph: onnx_proto.GraphProto, inputs: list, outputs: list, node_name: str, to_type: int, ): new_node = [helper.make_node("Cast", inputs, outputs, to=to_type, name=node_name)] graph.node.extend(new_node) def add_new_value_info( graph: onnx_proto.GraphProto, exist_value_info: onnx_proto.ValueInfoProto, name: str, dtype: int, ): new_value_info = graph.value_info.add() new_value_info.CopyFrom(exist_value_info) new_value_info.name = name new_value_info.type.tensor_type.elem_type = dtype # Find the node that has the specified output name def find_upstream_node_by_output_name(graph: onnx_proto.GraphProto, output_name: str): nodes = [] for node in graph.node: if output_name in node.output: nodes.append(node) assert len(nodes) <= 1 # Suppose there is less than one node found return nodes # Find the node that has the specified input name, including in subgraphs def find_downstream_node_by_input_name( graph: onnx_proto.GraphProto, input_name: str, include_subgraphs=True ): nodes = [] # Check nodes in current graph for node in graph.node: if input_name in node.input: nodes.append(node) if not include_subgraphs: continue # Recursively check subgraphs in node attributes for attr in node.attribute: if attr.type == onnx_proto.AttributeProto.GRAPH: # Single subgraph if len(attr.g.node) > 0: nodes.extend(find_downstream_node_by_input_name(attr.g, input_name)) # Multiple subgraphs if attr.type == onnx_proto.AttributeProto.GRAPHS: for g in attr.graphs: if len(g.node) > 0: nodes.extend(find_downstream_node_by_input_name(g, input_name)) return nodes # Remove identity node def remove_identity_node_from_model(model: onnx_proto.ModelProto): remove_identity_node_from_graph(model.graph) try: from onnx.shape_inference import infer_shapes func_infer_shape = infer_shapes model = func_infer_shape(model) return model finally: pass # Remove identity node def remove_identity_node_from_graph(graph: onnx_proto.GraphProto): for curr_node in graph.node: if curr_node.op_type == "Identity": for input_name in curr_node.input: upstream_nodes = find_upstream_node_by_output_name(graph, input_name) for u_node in upstream_nodes: if u_node is not None: u_node.output[0] = curr_node.output[0] graph.node.remove(curr_node) def convert_float_to_float16_model_path( model_path, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False ): """ Convert tensor float type in the ONNX Model to tensor float16. *It is to fix an issue that infer_shapes func cannot be used to infer >2GB models. *But this function can be applied to all model sizes. :param model_path: ONNX Model path :return: converted ONNX ModelProto object Examples :: #Convert to ONNX ModelProto object and save model binary file: from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path new_onnx_model = convert_float_to_float16_model_path('model.onnx') onnx.save(new_onnx_model, 'new_model.onnx') """ disable_shape_infer = False if pv.Version(onnx.__version__) >= pv.Version("1.8"): try: # infer_shapes_path can be applied to all model sizes from onnx.shape_inference import infer_shapes_path import tempfile import os # shape_infer_model_path should be in the same folder of model_path with tempfile.NamedTemporaryFile( dir=os.path.dirname(model_path) ) as tmpfile: shape_infer_model_path = tmpfile.name infer_shapes_path(model_path, shape_infer_model_path) model = onnx.load(shape_infer_model_path) disable_shape_infer = True finally: pass if not disable_shape_infer: model = onnx.load(model_path) return convert_float_to_float16( model, min_positive_val, max_finite_val, keep_io_types, disable_shape_infer ) def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto): # 1. find all cast nodes in the graph cast_node_list = [] input_name_to_cast_node_dict = {} output_name_to_cast_node_dict = {} # using name as key to point to a node. because node object cannot be key name_to_node_dict = {} for node in graph_proto.node: if node.op_type == "Cast": # if node.name not in ["graph_input_cast0", "graph_output_cast0"]: cast_node_list.append(node) name_to_node_dict[node.name] = node for input_name in node.input: input_name_to_cast_node_dict[input_name] = node for output_name in node.output: output_name_to_cast_node_dict[output_name] = node # 2. find upstream and downstream node of the cast node cast_node_upstream_dict = {} # mapping cast node(name) to its upstream node cast_node_downstream_dict = {} # mapping cast node(name) to its downstream node for current_node in graph_proto.node: # find the downstream node(s) for input_name in current_node.input: if input_name in output_name_to_cast_node_dict: # found the downstream node of the cast node, might be multiple cast_node = output_name_to_cast_node_dict[input_name] if cast_node.name not in cast_node_downstream_dict: cast_node_downstream_dict[cast_node.name] = current_node else: # already exists one downstream node, make it a list existing_downstream_nodes = cast_node_downstream_dict[ cast_node.name ] if isinstance(existing_downstream_nodes, list): existing_downstream_nodes.append(current_node) else: # make a list existing_downstream_nodes = [ existing_downstream_nodes, current_node, ] cast_node_downstream_dict[cast_node.name] = ( existing_downstream_nodes ) # find the upstream node for output_name in current_node.output: if output_name in input_name_to_cast_node_dict: # found the upstream node of the cast node, should be unique cast_node = input_name_to_cast_node_dict[output_name] cast_node_upstream_dict[cast_node.name] = current_node # 3. remove the cast node which upstream is 'Constant' for cast_node_name, upstream_node in cast_node_upstream_dict.items(): cast_node = name_to_node_dict[cast_node_name] if upstream_node.op_type == "Constant": cast_node_list.remove(cast_node) # 4. find (cast_to_fp16, cast_to_fp32) pairs where --fp32--> cast_to_fp16 --fp16--> cast_to_fp32. remove_candidate = [] name_to_value_info = { value_info.name: value_info for value_info in itertools.chain(graph_proto.value_info, graph_proto.input) } def get_type(name: str) -> Optional[int]: if name in name_to_value_info: return name_to_value_info[name].type else: # `name` has no value info. return None for cast_node_name, downstream_node in cast_node_downstream_dict.items(): cast_node = name_to_node_dict[cast_node_name] if len(cast_node.input) != 1: raise RuntimeError( f"Cast node {cast_node_name} should have only one input, but has {len(cast_node.input)}." ) input_type = get_type(cast_node.input[0]) if input_type != onnx_proto.TensorProto.FLOAT: continue if isinstance(downstream_node, list): for dn in downstream_node: if ( dn.op_type == "Cast" and dn.attribute[0].i == 32 and cast_node.attribute[0].i == 16 and dn in cast_node_list and cast_node in cast_node_list ): remove_candidate.append((cast_node, dn)) else: if ( downstream_node.op_type == "Cast" and cast_node.attribute[0].i == FLOAT16 and downstream_node.attribute[0].i == FLOAT32 and downstream_node in cast_node_list and cast_node in cast_node_list ): remove_candidate.append((cast_node, downstream_node)) # 5. change "upstream --fp32--> cast_to_fp16 --fp16--> cast_to_fp32 --fp32--> downstream" to # "upstream --fp32--> downstream". for cast_node_pair in remove_candidate: first_cast_node = cast_node_pair[0] second_cast_node = cast_node_pair[1] upstream_node = cast_node_upstream_dict.get(first_cast_node.name) downstream_node = cast_node_downstream_dict.get(second_cast_node.name) if upstream_node is None and downstream_node is not None: # The upstream_node should be graph input out = first_cast_node.input[0] for i, input_name in enumerate(downstream_node.input): for output_name in second_cast_node.output: if input_name == output_name: # change the input as the upstream node's output downstream_node.input[i] = out elif upstream_node is not None and downstream_node is None: raise ValueError( "The downstream node of the second cast node should be graph output" ) else: # find the upstream node's output to first_cast_node out = None for output_name in upstream_node.output: if output_name == first_cast_node.input[0]: out = output_name break # find the downstream node's input as second_cast_node's output for i, input_name in enumerate(downstream_node.input): for output_name in second_cast_node.output: if input_name == output_name: # change the input as the upstream node's output downstream_node.input[i] = out # 6. remove the cast node pair for cast_node_pair in remove_candidate: graph_proto.node.remove(cast_node_pair[0]) graph_proto.node.remove(cast_node_pair[1]) # Check if the model is already converted to float16 def check_if_fp16_ready(graph_proto): # Check graph input and ouput is_value_info_fp16 = False for value_info in itertools.chain( graph_proto.output, graph_proto.input, graph_proto.value_info ): if value_info.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT16: is_value_info_fp16 = True break # Check initializer is_initializer_fp16 = False for initializer in graph_proto.initializer: if initializer.data_type == onnx_proto.TensorProto.FLOAT16: is_initializer_fp16 = True break # Check cast node has_cast_node_fp16 = False for node in graph_proto.node: if node.op_type == "Cast" and node.attribute[0].i == FLOAT16: has_cast_node_fp16 = True break # Any of above flags is True, return True if is_value_info_fp16 or is_initializer_fp16 or has_cast_node_fp16: return True # already converted to float16 else: return False # not converted to float16 yet