in tools/converter/source/onnx/onnx_model_graph_opt.py [0:0]
def mha_mhca_detected(self, node, mha):
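# Detect an MHA / MHCA attention subgraph starting from `node`.
# Returns (detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, flag);
# the trailing flag appears to mark the ControlNet (Einsum/ScatterND based) patterns: it is True for
# those branches and False for the plain UNet ones. num_dynamic_q / num_dynamic_kv count the Shape
# consumers hanging off the GEMM outputs when axes are dynamic.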
# Go from the V GEMM down to the S*V MatMul and all the way back up to the K GEMM.
# For MHCA, the inputs of the two MatMuls (K and V) must come from the same source.
# For MHA, the K and V inputs must not be equal.
if node.op == "MatMul" and len(node.outputs) == 1:
if node.o().op == 'Shape':
if node.o(1).op == 'Shape':
num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2
else:
num_dynamic_kv = 1
# For cross-attention, if the batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well
num_dynamic_q = num_dynamic_kv
else:
num_dynamic_kv = 0
num_dynamic_q = 0
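# Consumers of the GEMM output are probed in order; after skipping the num_dynamic_kv Shape
# consumers, `o` should be the Reshape that starts the attention pattern.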
o = node.o(num_dynamic_kv)
# General Unet fmha/fmhca
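# Expected chain from the V GEMM: Reshape -> Transpose -> Reshape -> MatMul(S, V), where S is
# Softmax(Mul(MatMul(Q', K'^T))). The K branch is walked back through its Transpose/Reshape chain
# to the K GEMM, and the Q branch is traced the same way and verified by the sanity check below.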
if o.op == "Reshape" and \
o.o().op == "Transpose" and \
o.o().o().op == "Reshape" and \
o.o().o().o().op == "MatMul" and \
o.o().o().o().i(0).op == "Softmax" and \
o.o().o().o().i(1).op == "Reshape" and \
o.o().o().o().i(0).i().op == "Mul" and \
o.o().o().o().i(0).i().i().op == "MatMul" and \
o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \
o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \
o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \
o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \
o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \
o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \
node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name:
# "len(node.outputs) == 1" to make sure we are not in the already fused node
node_q = o.o().o().o().i(0).i().i().i(0).i().i().i()
node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i()
node_v = node
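# get_useful_output_index presumably skips any Shape consumers of the S*V MatMul so that the
# downstream path ending at the final Transpose (checked below) is followed.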
num_dynamic_transpose = self.get_useful_output_index(o.o().o().o())
final_tranpose = o.o().o().o().o(num_dynamic_transpose).o()
# Sanity check to make sure that the graph looks like expected
if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
#print("node_v:",node_v.name)
return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, False
# Unet fmhca (KV one source, start from Q)
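# In this branch `node` is the Q GEMM: its Reshape/Transpose/Reshape chain feeds the Q*K MatMul,
# while K and V are traced from that MatMul and from the Softmax consumer, respectively.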
if o.op == "Reshape" and \
o.o().op == "Transpose" and \
o.o().o().op == "Reshape":
node_reshape = o.o().o()
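# Count any Shape consumers of this Reshape; the Q*K MatMul is the next consumer after them.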
num_loc_q = 0
if node_reshape.o().op == 'Shape':
num_loc_q = 1
if node_reshape.o(1).op == 'Shape':
num_loc_q = 2
if node_reshape.o(2).op == 'Shape':
num_loc_q = 3
if o.o().o().o(num_loc_q).op == "MatMul" and \
o.o().o().o(num_loc_q).o().op == "Mul":
node_softmax = o.o().o().o(num_loc_q).o().o()
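# The scale Mul may be followed by an extra Add (e.g. an attention bias) and the Softmax
# output by a Cast; step over both if present to reach the Softmax and its consumer.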
if node_softmax.op == "Add":
node_softmax = node_softmax.o()
if node_softmax.op == "Softmax":
if node_softmax.o().op == "Cast":
node_softmax = node_softmax.o()
node_qk_matmul = o.o().o().o(num_loc_q)
# check K
if node_qk_matmul.i(1).op == "Transpose" and \
node_qk_matmul.i(1).i().op == "Reshape" and \
node_qk_matmul.i(1).i().i().op == "Transpose" and \
node_qk_matmul.i(1).i().i().i().op == "Reshape" and \
node_qk_matmul.i(1).i().i().i().i().op == "MatMul":
# check V
if node_softmax.o().i(1).op == "Reshape" and \
node_softmax.o().i(1).i().op == "Transpose" and \
node_softmax.o().i(1).i().i().op == "Reshape" and \
node_softmax.o().i(1).i().i().i().op == "MatMul" and \
node.name != node_softmax.o().i(1).i().i().i().name:
# "len(node.outputs) == 1" to make sure we are not in the already fused node
node_q = node
node_k = node_qk_matmul.i(1).i().i().i().i()
node_v = node_softmax.o().i(1).i().i().i()
num_dynamic_transpose = self.get_useful_output_index(node_softmax.o())
final_tranpose = node_softmax.o().o(num_dynamic_transpose).o()
# Sanity check to make sure that the graph looks like expected
if node_v.op == "MatMul" and final_tranpose.op == "Transpose":
#print("node_v:",node_v.name)
return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, False
# ControlNet fmha
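# In the ControlNet export the attention appears as Einsum ops: Q comes in through a Slice,
# K through a Reshape/Transpose chain with the scale Mul applied before it, the Q*K Einsum feeds
# the Softmax, a second Einsum multiplies by V, and the result is written back through ScatterND nodes.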
if o.op == "Reshape" and \
o.o().op == "Transpose" and \
o.o().o().op == "Reshape" and \
o.o().o().o().op == "Einsum":
node_softmax = o.o().o().o().i(0)
if node_softmax.op == "Cast":
node_softmax = node_softmax.i()
if node_softmax.op == "Softmax" and \
o.o().o().o().i(1).op == "Reshape" and \
node_softmax.i().op == "Einsum" and \
node_softmax.i().i(0).op == "Slice" and \
node_softmax.i().i(1).op == "Reshape" and \
node_softmax.i().i(1).i().op == "Transpose" and \
node_softmax.i().i(1).i().i().op == "Reshape" and \
node_softmax.i().i(1).i().i().i().op == "Mul" and \
node_softmax.i().i(1).i().i().i().i().op == "MatMul" and \
node.name != node_softmax.i().i(1).i().i().i().i().name:
# "len(node.outputs) == 1" to make sure we are not in the already fused node
node_q = node_softmax.i().i(0).i().i().i().i()
node_k = node_softmax.i().i(1).i().i().i().i()
node_v = node
qkv_einsum = o.o().o().o()
scatter_next_node = qkv_einsum.o().o().o()
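# Follow the downstream chain of ScatterND nodes; each may have several consumers, so probe
# the first few outputs for another ScatterND before falling back to the first consumer.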
while scatter_next_node.op == "ScatterND":
if scatter_next_node.o(0).op == "ScatterND":
scatter_next_node = scatter_next_node.o(0)
elif len(scatter_next_node.outputs[0].outputs) > 1 and scatter_next_node.o(1).op == "ScatterND":
scatter_next_node = scatter_next_node.o(1)
elif len(scatter_next_node.outputs[0].outputs) > 2 and scatter_next_node.o(2).op == "ScatterND":
scatter_next_node = scatter_next_node.o(2)
elif len(scatter_next_node.outputs[0].outputs) > 3 and scatter_next_node.o(3).op == "ScatterND":
scatter_next_node = scatter_next_node.o(3)
elif len(scatter_next_node.outputs[0].outputs) > 4 and scatter_next_node.o(4).op == "ScatterND":
scatter_next_node = scatter_next_node.o(4)
else:
scatter_next_node = scatter_next_node.o()
node_ahead_reshape = scatter_next_node.i()
loop_num = 0
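# Advance until we sit on the ScatterND/Cast whose consumer is the Reshape (or its Shape) that
# precedes the final Transpose; loop_num caps the search at five hops on unexpected graphs.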
while ((node_ahead_reshape.op != "ScatterND" and node_ahead_reshape.op != "Cast") or (node_ahead_reshape.o().op != "Reshape" and node_ahead_reshape.o().op != "Shape")) and loop_num < 5:
loop_num = loop_num + 1
node_ahead_reshape = node_ahead_reshape.o()
num_dynamic_transpose = self.get_useful_output_index(node_ahead_reshape)
final_tranpose = node_ahead_reshape.o(num_dynamic_transpose).o()
# Sanity check to make sure that the graph looks like expected
if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
#print("node_v:",node_v.name)
return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, True
# ControlNet fmhca
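# ControlNet cross-attention: same Einsum/ScatterND structure as the branch above, but here `node`
# appears to be the Q GEMM, reaching the first Einsum through a Slice after its Reshape chain.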
node_reshape = o.o().o()
num_loc_q = 0
if node_reshape.o().op == 'Shape':
num_loc_q = 1
if node_reshape.o(1).op == 'Shape':
num_loc_q = 2
if node_reshape.o(2).op == 'Shape':
num_loc_q = 3
if o.op == "Reshape" and \
o.o().op == "Transpose" and \
o.o().o().op == "Reshape" and \
o.o().o().o(num_loc_q).op == "Slice" and \
o.o().o().o(num_loc_q).o().op == "Einsum":
node_softmax = o.o().o().o(num_loc_q).o().o()
if node_softmax.op == "Softmax" and \
o.o().o().o(num_loc_q).o().i(1).op == "Reshape" and \
node_softmax.i().op == "Einsum" and \
node_softmax.i().i(0).op == "Slice" and \
node_softmax.i().i(1).op == "Reshape" and \
node_softmax.i().i(1).i().op == "Transpose" and \
node_softmax.i().i(1).i().i().op == "Reshape" and \
node_softmax.i().i(1).i().i().i().op == "Mul" and \
node_softmax.i().i(1).i().i().i().i().op == "MatMul" and \
node.name != node_softmax.i().i(1).i().i().i().i().name:
# "len(node.outputs) == 1" to make sure we are not in the already fused node
node_q = node_softmax.i().i(0).i().i().i().i()
node_k = node_softmax.i().i(1).i().i().i().i()
if node_softmax.o().op == "Cast":
node_softmax = node_softmax.o()
node_v = node_softmax.o().i(1).i().i().i()
#print("node_softmax:", node_softmax.name)
#print("node_v:",node_v.name)
scatter_next_node = node_softmax.o().o().o().o().o()
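# Same ScatterND walk as in the ControlNet fmha branch above.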
while scatter_next_node.op == "ScatterND":
if scatter_next_node.o(0).op == "ScatterND":
scatter_next_node = scatter_next_node.o(0)
elif len(scatter_next_node.outputs[0].outputs) > 1 and scatter_next_node.o(1).op == "ScatterND":
scatter_next_node = scatter_next_node.o(1)
elif len(scatter_next_node.outputs[0].outputs) > 2 and scatter_next_node.o(2).op == "ScatterND":
scatter_next_node = scatter_next_node.o(2)
elif len(scatter_next_node.outputs[0].outputs) > 3 and scatter_next_node.o(3).op == "ScatterND":
scatter_next_node = scatter_next_node.o(3)
elif len(scatter_next_node.outputs[0].outputs) > 4 and scatter_next_node.o(4).op == "ScatterND":
scatter_next_node = scatter_next_node.o(4)
else:
scatter_next_node = scatter_next_node.o()
node_ahead_reshape = scatter_next_node.i()
loop_num = 0
while ((node_ahead_reshape.op != "ScatterND" and node_ahead_reshape.op != "Cast") or (node_ahead_reshape.o().op != "Reshape" and node_ahead_reshape.o().op != "Shape")) and loop_num < 5:
loop_num = loop_num + 1
node_ahead_reshape = node_ahead_reshape.o()
num_dynamic_transpose = self.get_useful_output_index(node_ahead_reshape)
final_tranpose = node_ahead_reshape.o(num_dynamic_transpose).o()
# Sanity check to make sure that the graph looks like expected
if node_v.op == "MatMul" and final_tranpose.op == "Transpose":
#print("node_v:",node_v.name)
return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, True
return False, 0, 0, None, None, None, None, False