in optimum/amd/brevitas/export.py [0:0]
def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0):
    """Swap every matched Gemm pattern for a MatMulInteger-based subgraph.

    Each entry of ``found_nodes`` is a list of node names in which index 0 is
    looked up as the dequantize node, index 1 as the dynamic-quantize node and
    the last entry as the Gemm node. For every entry the function:

    1. registers three initializers taken from the first three predecessors
       of the dequantize node (the first one transposed);
    2. removes every pattern node whose name does not contain
       ``"DynamicQuantizeLinear"``, then the first predecessor of the
       dequantize node, and the second/third predecessors when they are
       still present in ``graph.nodemap``;
    3. creates the chain ``MatMulInteger -> Cast -> Mul -> Add`` (plus a
       ``Mul`` combining the two scale tensors), with the final ``Add``
       writing to the first output name of the original Gemm node.

    Args:
        graph: Graph to rewrite; mutated through ``add_initial`` /
            ``remove_node`` and rebound to the result of ``create_nodes``.
        found_nodes: Matched patterns, one list of node names per match.
        node_count: Counter used as numeric suffix for generated names. It is
            incremented before use, so the first suffix is ``node_count + 1``.

    Returns:
        The graph returned by the last ``create_nodes`` call, or ``graph``
        itself when ``found_nodes`` is empty.
    """
    for pattern in found_nodes:
        node_count += 1

        gemm_node = graph.nodemap[pattern[-1]]
        # Last Gemm input is used as the bias operand of the final Add.
        bias = gemm_node.input[-1]
        dequant_node = graph.nodemap[pattern[0]]
        dyn_quant_node = graph.nodemap[pattern[1]]

        # Names of the tensors/nodes generated for this match.
        # NOTE(review): predecessors 1 and 2 are presumably the weight scale
        # and zero point (ONNX DequantizeLinear input order) - confirm.
        weight_init = f"dq_weights_0_{node_count}"
        scale_init = f"dq_weights_1_{node_count}"
        zero_point_init = f"dq_weights_2_{node_count}"
        matmul_name = f"matmul_integer_{node_count}"
        cast_name = f"cast_{node_count}"
        scales_name = f"mulscales_{node_count}"
        scaled_name = f"mulvalues_{node_count}"
        bias_add_name = f"addbias_{node_count}"

        # Copy the constants into initializers before their producers are removed.
        graph.add_initial(weight_init, dequant_node.prevnodes[0].value.transpose())
        for slot, init_name in ((1, scale_init), (2, zero_point_init)):
            graph.add_initial(init_name, dequant_node.prevnodes[slot].value)

        # Drop the matched nodes, keeping the dynamic-quantize node alive.
        for node_name in pattern:
            if "DynamicQuantizeLinear" not in node_name:
                graph.remove_node(node_name)

        # Drop the producers feeding the dequantize node; the second and third
        # are only removed when they are still registered in the graph.
        graph.remove_node(dequant_node.prevnodes[0].name)
        for slot in (1, 2):
            if dequant_node.prevnodes[slot].name in graph.nodemap:
                graph.remove_node(dequant_node.prevnodes[slot].name)

        # Integer matmul on the quantized activation and the stored weight.
        graph = create_nodes(
            graph,
            "MatMulInteger",
            matmul_name,
            [dyn_quant_node.output[0], weight_init, dyn_quant_node.output[2], zero_point_init],
            [matmul_name],
        )
        # Cast the integer accumulator to float32.
        graph = create_nodes(
            graph,
            "Cast",
            cast_name,
            [matmul_name],
            [cast_name],
            to=ONNX_FLOAT32_IDENTIFIER,
        )
        # Combine the dynamic-quantize scale output with the stored scale tensor.
        graph = create_nodes(
            graph,
            "Mul",
            scales_name,
            [dyn_quant_node.output[1], scale_init],
            [scales_name],
        )
        # Rescale the casted matmul result.
        graph = create_nodes(
            graph,
            "Mul",
            scaled_name,
            [scales_name, cast_name],
            [scaled_name],
        )
        # Add the bias and write to the original Gemm output so downstream
        # consumers keep their input name.
        graph = create_nodes(
            graph,
            "Add",
            bias_add_name,
            [bias, scaled_name],
            [gemm_node.output[0]],
        )

    return graph