def replace_matmul_to_matmulinteger()

in optimum/amd/brevitas/export.py [0:0]


def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0):
    """Rewrite each matched pattern into a MatMulInteger -> Cast -> Mul chain.

    For every entry of ``found_nodes`` the function:

    1. registers three new initializers copied from the producers feeding the
       pattern's first node (the first one is transposed),
    2. removes every pattern node whose name does not contain
       ``"DynamicQuantizeLinear"``, followed by those producers,
    3. adds a ``MatMulInteger``, a ``Cast`` and two ``Mul`` nodes; the last
       ``Mul`` reuses the output name of the pattern's final node.

    Args:
        graph: Graph to edit; it must expose ``nodemap``, ``add_initial`` and
            ``remove_node``.
        found_nodes: One list of node names per match. Index 0 is read as the
            node whose three producers carry the values to copy, index 2 as
            the node whose outputs feed the new nodes, and the last index as
            the node whose output name is taken over.
        node_count: Offset for the numeric suffix of generated names; the
            first pattern uses ``node_count + 1``.

    Returns:
        The graph returned by the last ``create_nodes`` call, or ``graph``
        itself when ``found_nodes`` is empty.
    """
    for suffix, pattern in enumerate(found_nodes, start=node_count + 1):
        # Names of the tensors and nodes generated for this match.
        weight_init = f"dq_weights_0_{suffix}"
        scale_init = f"dq_weights_1_{suffix}"
        zero_point_init = f"dq_weights_2_{suffix}"
        matmul_integer_name = f"matmul_integer_{suffix}"
        cast_name = f"cast_{suffix}"
        mulscales_name = f"mulscales_{suffix}"
        mulvalues_name = f"mulvalues_{suffix}"

        dequant_node = graph.nodemap[pattern[0]]
        dynamic_quant_node = graph.nodemap[pattern[2]]

        # NOTE(review): presumably the producers are (weight, scale, zero point)
        # of a DequantizeLinear node — confirm against the pattern matcher.
        weight_producer = dequant_node.prevnodes[0]
        graph.add_initial(weight_init, weight_producer.value.transpose())
        graph.add_initial(scale_init, dequant_node.prevnodes[1].value)
        graph.add_initial(zero_point_init, dequant_node.prevnodes[2].value)

        last_node = graph.nodemap[pattern[-1]]

        # Drop the matched nodes, keeping the dynamic quantization node whose
        # outputs are consumed by the replacement nodes below.
        for node_name in pattern:
            if "DynamicQuantizeLinear" not in node_name:
                graph.remove_node(node_name)

        # The first producer is removed unconditionally; the other two only if
        # they are still registered in the graph.
        graph.remove_node(dequant_node.prevnodes[0].name)
        for producer_index in (1, 2):
            producer = dequant_node.prevnodes[producer_index]
            if producer.name in graph.nodemap:
                graph.remove_node(producer.name)

        # NOTE(review): input order follows ONNX MatMulInteger
        # (A, B, a_zero_point, b_zero_point), assuming output[2] of the
        # dynamic quantization node is its zero point — confirm.
        matmul_integer_inputs = [
            dynamic_quant_node.output[0],
            weight_init,
            dynamic_quant_node.output[2],
            zero_point_init,
        ]
        graph = create_nodes(graph, "MatMulInteger", matmul_integer_name, matmul_integer_inputs, [matmul_integer_name])

        graph = create_nodes(graph, "Cast", cast_name, [matmul_integer_name], [cast_name], to=ONNX_FLOAT32_IDENTIFIER)

        scale_inputs = [dynamic_quant_node.output[1], scale_init]
        graph = create_nodes(graph, "Mul", mulscales_name, scale_inputs, [mulscales_name])

        # The final product takes over the output name of the pattern's last node.
        graph = create_nodes(graph, "Mul", mulvalues_name, [mulscales_name, cast_name], [last_node.output[0]])

    return graph