onnx/summarization/bart_onnx/reduce_onnx_size.py (79 lines of code) (raw):
"""
Code to remove duplicate initializers to reduce ONNX model size.
"""
import os
import numpy
import onnx
def _is_equal_tensor_proto(a, b):
name_a = a.name
name_b = b.name
a.name = ""
b.name = ""
res = a == b
a.name = name_a
b.name = name_b
return res
def _node_replace_input_with(node_proto, name, new_name):
for i, input_name in enumerate(node_proto.input):
if input_name == name:
node_proto.input.insert(i, new_name)
node_proto.input.pop(i + 1)
if node_proto.op_type == "If":
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
_graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
if node_proto.op_type == "Loop":
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
def _graph_replace_input_with(graph_proto, name, new_name):
for n in graph_proto.node:
_node_replace_input_with(n, name, new_name)
def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
inits_with_data = list(model.graph.initializer)
inits = list(model_without_ext.graph.initializer)
for i, ref_i in ind_to_replace:
assert inits_with_data[i].name == inits[i].name
assert inits_with_data[ref_i].name == inits[ref_i].name
assert i > ref_i
name_i = inits[i].name
name_ref = inits[ref_i].name
model_without_ext.graph.initializer.remove(inits[i])
# for n in model.graph.node:
_graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
def remove_dup_initializers(onnx_file_path):
"""
Removes duplicate initializers from the model to reduce its size.
Writes a new file in the same directory as onnx_file_path and returns the path to that file.
"""
model_file_folder = os.path.dirname(onnx_file_path)
model_file_name = os.path.basename(onnx_file_path)
model = onnx.load(os.path.join(model_file_folder, model_file_name))
inits = list(model.graph.initializer)
dup_set = set()
dup_map = {}
ind_to_replace = []
total_reduced_size = 0
for i in range(len(inits)):
if i in dup_set:
continue
for j in range(i + 1, len(inits)):
if j in dup_set:
continue
if _is_equal_tensor_proto(inits[i], inits[j]):
dup_set.add(i)
dup_set.add(j)
dtype = inits[j].data_type
mem_size = numpy.prod(inits[j].dims)
if dtype == 1:
mem_size *= 4
elif dtype == 6:
mem_size *= 4
elif dtype == 7 or dtype == 11:
mem_size *= 8
else:
print("unexpected data type: ", dtype)
total_reduced_size += mem_size
name_i = inits[i].name
name_j = inits[j].name
if name_i in dup_map:
dup_map[name_i].append(name_j)
else:
dup_map[name_i] = [name_j]
ind_to_replace.append((j, i))
print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")
ind_to_replace = sorted(ind_to_replace)
_remove_dup_initializers_from_model(model, model, ind_to_replace)
optimized_model_file_name = "optimized_" + model_file_name
new_model = os.path.join(model_file_folder, optimized_model_file_name)
onnx.save(model, new_model)
return new_model