in optimum/amd/brevitas/export.py [0:0]
def find_and_insert_matmulinteger(model_path: str):
# onnx_tool requires python 3.9+
if sys.version_info[0] == 3 and sys.version_info[1] <= 8:
raise RuntimeError("onnx_tool requires Python 3.9 or higher")
LOGGER.info("Rewriting ONNX Graph with MatMulInteger")
model_path = os.path.join(model_path, "model.onnx")
cfg = {"constant_folding": False, "node_rename": False, "if_fixed_branch": None, "fixed_topk": 0, "verbose": False}
onnx_model = onnx.load(model_path)
# Extract model output
original_output = copy.deepcopy(onnx_model.graph.output)
model = Model(onnx_model, cfg)
graph = model.graph
pattern = FusionPattern(MATMUL)
found_matmul_nodes = pattern.search_pattern(graph)
matmul_node_count = len(found_matmul_nodes)
LOGGER.info(f"Replacing {matmul_node_count} MatMul nodes with MatMulInteger")
graph = replace_matmul_to_matmulinteger(graph, found_matmul_nodes)
pattern = FusionPattern(GEMM)
found_gemm_nodes = pattern.search_pattern(graph)
gemm_node_count = len(found_gemm_nodes)
LOGGER.info(f"Replacing {gemm_node_count} Gemm nodes with MatMulInteger + Add")
graph = replace_gemm_to_matmulinteger(graph, found_gemm_nodes, matmul_node_count)
graph.graph_reorder_nodes()
LOGGER.info("Saving the new ONNX model")
full_path = Path(model_path)
graph = graph.make_graph_onnx(
graph.nodemap.keys(), "graph", graph.input, graph.output, with_initializer=True, with_shape_info=False
)
attr = {"producer_name": "onnx_tool"}
model_to_save = onnx.helper.make_model(graph, **attr)
# onnx_tools might remove the output nodes from the ONNX graph, so we need to restore it.
for out in original_output:
if out not in model_to_save.graph.output:
model_to_save.graph.output.append(out)
model_to_save.ir_version = model.mproto.ir_version
model_to_save.opset_import.pop()
for opset in model.mproto.opset_import:
model_to_save.opset_import.append(opset)
check_and_save_model(model_to_save, full_path)