def find_and_insert_matmulinteger()

in optimum/amd/brevitas/export.py [0:0]


def find_and_insert_matmulinteger(model_path: str):
    """Rewrite ``<model_path>/model.onnx`` so MatMul and Gemm nodes become MatMulInteger.

    The model is loaded and wrapped in an onnx_tool ``Model``. Nodes matched by
    the ``MATMUL`` fusion pattern are rewritten with
    ``replace_matmul_to_matmulinteger``. Nodes matched by the ``GEMM`` pattern
    are rewritten with ``replace_gemm_to_matmulinteger`` (logged as
    "MatMulInteger + Add"). The resulting graph is serialized back to ONNX and
    written to the same file through ``check_and_save_model``.

    Args:
        model_path: Directory containing the exported ``model.onnx`` file.

    Raises:
        RuntimeError: If running on Python older than 3.9, which onnx_tool
            does not support.
    """
    # onnx_tool requires python 3.9+
    if sys.version_info < (3, 9):
        raise RuntimeError("onnx_tool requires Python 3.9 or higher")

    LOGGER.info("Rewriting ONNX Graph with MatMulInteger")
    onnx_path = os.path.join(model_path, "model.onnx")
    cfg = {"constant_folding": False, "node_rename": False, "if_fixed_branch": None, "fixed_topk": 0, "verbose": False}

    onnx_model = onnx.load(onnx_path)

    # Keep a copy of the declared graph outputs: onnx_tool may drop some of them
    # when the graph is re-serialized below.
    original_outputs = copy.deepcopy(onnx_model.graph.output)

    model = Model(onnx_model, cfg)
    graph = model.graph

    found_matmul_nodes = FusionPattern(MATMUL).search_pattern(graph)
    matmul_node_count = len(found_matmul_nodes)
    LOGGER.info(f"Replacing {matmul_node_count} MatMul nodes with MatMulInteger")
    graph = replace_matmul_to_matmulinteger(graph, found_matmul_nodes)

    found_gemm_nodes = FusionPattern(GEMM).search_pattern(graph)
    gemm_node_count = len(found_gemm_nodes)
    LOGGER.info(f"Replacing {gemm_node_count} Gemm nodes with MatMulInteger + Add")
    # NOTE(review): matmul_node_count is presumably used as an offset so the nodes created
    # here do not clash with the MatMulInteger nodes created above — confirm in the helper.
    graph = replace_gemm_to_matmulinteger(graph, found_gemm_nodes, matmul_node_count)

    graph.graph_reorder_nodes()

    LOGGER.info("Saving the new ONNX model")
    graph_proto = graph.make_graph_onnx(
        graph.nodemap.keys(), "graph", graph.input, graph.output, with_initializer=True, with_shape_info=False
    )
    model_to_save = onnx.helper.make_model(graph_proto, producer_name="onnx_tool")

    # onnx_tool might remove output nodes from the ONNX graph, so restore them.
    # Compare by name rather than by full protobuf equality: an output that survived
    # with different type/shape info must not be appended a second time, which would
    # leave the graph with duplicate output names.
    existing_output_names = {out.name for out in model_to_save.graph.output}
    for out in original_outputs:
        if out.name not in existing_output_names:
            model_to_save.graph.output.append(out)
            existing_output_names.add(out.name)

    # Carry over the original IR version and opset imports instead of the defaults
    # injected by make_model; clearing (rather than popping one entry) does not rely
    # on how many defaults make_model added.
    model_to_save.ir_version = model.mproto.ir_version
    del model_to_save.opset_import[:]
    model_to_save.opset_import.extend(model.mproto.opset_import)

    check_and_save_model(model_to_save, Path(onnx_path))