def mha_mhca_detected()

in tools/converter/source/onnx/onnx_model_graph_opt.py


    def mha_mhca_detected(self, node, mha):
        # Go from the V GEMM down to the S*V MatMul and all the way up to the K GEMM.
        # If we are looking for MHCA, the inputs of the two MatMuls (K and V) must be equal.
        # If we are looking for MHA, the inputs (K and V) must not be equal.
        
        if node.op == "MatMul" and len(node.outputs) == 1:
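            # Count how many of this MatMul's consumers are Shape ops (added by dynamic-shape exports);
            # the real data consumer is indexed past them below.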
            if node.o().op == 'Shape':
                if node.o(1).op == 'Shape':
                    num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2
                else:
                    num_dynamic_kv = 1
                # For Cross-Attention, if batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well
                num_dynamic_q = num_dynamic_kv  # if mha else num_dynamic_kv + 1
            else:
                num_dynamic_kv = 0
                num_dynamic_q = 0
            o = node.o(num_dynamic_kv)
            
            # General Unet fmha/fmhca
            if o.op == "Reshape" and \
                o.o().op == "Transpose" and \
                o.o().o().op == "Reshape" and \
                o.o().o().o().op == "MatMul" and \
                o.o().o().o().i(0).op == "Softmax" and \
                o.o().o().o().i(1).op == "Reshape" and \
                o.o().o().o().i(0).i().op == "Mul" and \
                o.o().o().o().i(0).i().i().op == "MatMul" and \
                o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \
                o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \
                o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \
                o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \
                o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \
                o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \
                node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name:
                # "len(node.outputs) == 1" to make sure we are not in the already fused node
                node_q = o.o().o().o().i(0).i().i().i(0).i().i().i()
                node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i()
                node_v = node
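                # Pick the consumer index of the S*V MatMul that leads past any Shape consumers to the final output Transpose.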
                num_dynamic_transpose = self.get_useful_output_index(o.o().o().o())
                final_tranpose = o.o().o().o().o(num_dynamic_transpose).o()
                # Sanity check to make sure that the graph looks like expected
                if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
                    #print("node_v:",node_v.name)                    
                    return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, False

            # Unet fmhca (KV one source, start from Q)
            if o.op == "Reshape" and \
                o.o().op == "Transpose" and \
                o.o().o().op == "Reshape":
                node_reshape = o.o().o()
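                # Count Shape consumers of the Q-side Reshape so the Q*K MatMul can be indexed past them.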
                num_loc_q = 0
                if node_reshape.o().op == 'Shape':
                    num_loc_q = 1
                    if node_reshape.o(1).op == 'Shape':
                        num_loc_q = 2
                        if node_reshape.o(2).op == 'Shape':
                            num_loc_q = 3
                if o.o().o().o(num_loc_q).op == "MatMul" and \
                    o.o().o().o(num_loc_q).o().op == "Mul":
                    node_softmax = o.o().o().o(num_loc_q).o().o()
                    if node_softmax.op == "Add":
                        node_softmax = node_softmax.o()
                    if node_softmax.op == "Softmax":
                        if node_softmax.o().op == "Cast":
                            node_softmax = node_softmax.o()
                    node_qk_matmul = o.o().o().o(num_loc_q)
                    '''
                    print(node_qk_matmul.i(1).op)
                    print(node_qk_matmul.i(1).i().op)
                    print(node_qk_matmul.i(1).i().i().op)
                    print(node_qk_matmul.i(1).i().i().i().op)
                    print(node_qk_matmul.i(1).i().i().i().i().op)
                    '''
                    # check K
                    if node_qk_matmul.i(1).op == "Transpose" and \
                        node_qk_matmul.i(1).i().op == "Reshape" and \
                        node_qk_matmul.i(1).i().i().op == "Transpose" and \
                        node_qk_matmul.i(1).i().i().i().op == "Reshape" and \
                        node_qk_matmul.i(1).i().i().i().i().op == "MatMul":
                        '''
                        print(node_softmax.o().i(1).op)
                        print(node_softmax.o().i(1).i().op)
                        print(node_softmax.o().i(1).i().i().op)
                        print(node_softmax.o().i(1).i().i().i().op)
                        '''
                        # check V
                        if node_softmax.o().i(1).op == "Reshape" and \
                            node_softmax.o().i(1).i().op == "Transpose" and \
                            node_softmax.o().i(1).i().i().op == "Reshape" and \
                            node_softmax.o().i(1).i().i().i().op == "MatMul" and \
                            node.name != node_softmax.o().i(1).i().i().i().name:
                            # "len(node.outputs) == 1" to make sure we are not in the already fused node
                            node_q = node
                            node_k = node_qk_matmul.i(1).i().i().i().i()
                            node_v = node_softmax.o().i(1).i().i().i()
                            num_dynamic_transpose = self.get_useful_output_index(node_softmax.o())
                            final_tranpose = node_softmax.o().o(num_dynamic_transpose).o()
                            # Sanity check to make sure that the graph looks like expected
                            '''
                            print(node_q.op)
                            print(node_k.op)
                            print(node_v.op)
                            print(final_tranpose.op)
                            '''
                            if node_v.op == "MatMul" and final_tranpose.op == "Transpose":
                                #print("node_v:",node_v.name)
                                return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, False
            # ControlNet fmha
            if o.op == "Reshape" and \
                o.o().op == "Transpose" and \
                o.o().o().op == "Reshape" and \
                o.o().o().o().op == "Einsum":
                node_softmax = o.o().o().o().i(0)
                if node_softmax.op == "Cast":
                    node_softmax = node_softmax.i()
                '''
                print(o.o().o().o().i(1).op)
                print(node_softmax.i().op)
                print(node_softmax.i().i(0).op)
                print(node_softmax.i().i(1).op)
                print(node_softmax.i().i(1).i().op)
                print(node_softmax.i().i(1).i().i().op)
                print(node_softmax.i().i(1).i().i().i().op)
                print(node_softmax.i().i(1).i().i().i().i().op)
                print(node_softmax.i().i(1).i().i().i().i().name)
                '''
                if node_softmax.op == "Softmax" and \
                    o.o().o().o().i(1).op == "Reshape" and \
                    node_softmax.i().op == "Einsum" and \
                    node_softmax.i().i(0).op == "Slice" and \
                    node_softmax.i().i(1).op == "Reshape" and \
                    node_softmax.i().i(1).i().op == "Transpose" and \
                    node_softmax.i().i(1).i().i().op == "Reshape" and \
                    node_softmax.i().i(1).i().i().i().op == "Mul" and \
                    node_softmax.i().i(1).i().i().i().i().op == "MatMul" and \
                    node.name != node_softmax.i().i(1).i().i().i().i().name:
                    # "len(node.outputs) == 1" to make sure we are not in the already fused node
                    node_q = node_softmax.i().i(0).i().i().i().i()
                    node_k = node_softmax.i().i(1).i().i().i().i()
                    node_v = node
                    
                    qkv_einsum = o.o().o().o()

                    scatter_next_node = qkv_einsum.o().o().o()
                    
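                    # Follow the chain of ScatterND consumers; the loop exits on the first non-ScatterND node after the chain.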
                    while scatter_next_node.op == "ScatterND":
                        if scatter_next_node.o(0).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(0)
                        elif len(scatter_next_node.outputs[0].outputs) > 1 and scatter_next_node.o(1).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(1)
                        elif len(scatter_next_node.outputs[0].outputs) > 2 and scatter_next_node.o(2).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(2)
                        elif len(scatter_next_node.outputs[0].outputs) > 3 and scatter_next_node.o(3).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(3)
                        elif len(scatter_next_node.outputs[0].outputs) > 4 and scatter_next_node.o(4).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(4)
                        else:
                            scatter_next_node = scatter_next_node.o()

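                    # Walk forward from that node's producer (at most 5 hops) to the ScatterND/Cast feeding the output Reshape,
                    # from which the final Transpose is then taken.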
                    node_ahead_reshape = scatter_next_node.i()
                    loop_num = 0
                    while ((node_ahead_reshape.op != "ScatterND" and node_ahead_reshape.op != "Cast") or (node_ahead_reshape.o().op != "Reshape" and node_ahead_reshape.o().op != "Shape")) and loop_num < 5:
                        #print(node_ahead_reshape.name)
                        loop_num = loop_num + 1
                        node_ahead_reshape = node_ahead_reshape.o()
                    num_dynamic_transpose = self.get_useful_output_index(node_ahead_reshape)
                    final_tranpose = node_ahead_reshape.o(num_dynamic_transpose).o()                   
                    # Sanity check to make sure that the graph looks like expected
                    '''
                    print(node_q.op)
                    print(node_k.op)
                    print(node_v.op)
                    print(final_tranpose.op)
                    '''
                    if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
                        #print("node_v:",node_v.name)
                        return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, True
       
            # ControlNet fmhca
            node_reshape = o.o().o()
            num_loc_q = 0
            if node_reshape.o().op == 'Shape':
                num_loc_q = 1
                if node_reshape.o(1).op == 'Shape':
                    num_loc_q = 2
                    if node_reshape.o(2).op == 'Shape':
                        num_loc_q = 3
            if o.op == "Reshape" and \
                o.o().op == "Transpose" and \
                o.o().o().op == "Reshape" and \
                o.o().o().o(num_loc_q).op == "Slice" and \
                o.o().o().o(num_loc_q).o().op == "Einsum":
                node_softmax = o.o().o().o(num_loc_q).o().o()
                '''
                print(node_softmax.op)
                print(o.o().o().o(num_loc_q).o().i(1).op)
                print(node_softmax.i().op)
                print(node_softmax.i().i(0).op)
                print(node_softmax.i().i(1).op)
                print(node_softmax.i().i(1).i().op)
                print(node_softmax.i().i(1).i().i().op)
                print(node_softmax.i().i(1).i().i().i().op)
                print(node_softmax.i().i(1).i().i().i().i().op)
                print(node_softmax.i().i(1).i().i().i().i().name)
                '''
                if node_softmax.op == "Softmax" and \
                    o.o().o().o(num_loc_q).o().i(1).op == "Reshape" and \
                    node_softmax.i().op == "Einsum" and \
                    node_softmax.i().i(0).op == "Slice" and \
                    node_softmax.i().i(1).op == "Reshape" and \
                    node_softmax.i().i(1).i().op == "Transpose" and \
                    node_softmax.i().i(1).i().i().op == "Reshape" and \
                    node_softmax.i().i(1).i().i().i().op == "Mul" and \
                    node_softmax.i().i(1).i().i().i().i().op == "MatMul" and \
                    node.name != node_softmax.i().i(1).i().i().i().i().name:
                    # "len(node.outputs) == 1" to make sure we are not in the already fused node
                    node_q = node_softmax.i().i(0).i().i().i().i()
                    node_k = node_softmax.i().i(1).i().i().i().i()
                    if node_softmax.o().op == "Cast":
                        node_softmax = node_softmax.o()
                    node_v = node_softmax.o().i(1).i().i().i()
                    #print("node_softmax:", node_softmax.name)
                    #print("node_v:",node_v.name)
                    
                    scatter_next_node = node_softmax.o().o().o().o().o()
                    
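                    # Same ScatterND-chain walk and output-Reshape search as in the ControlNet fmha branch above.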
                    while scatter_next_node.op == "ScatterND":
                        if scatter_next_node.o(0).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(0)
                        elif len(scatter_next_node.outputs[0].outputs) > 1 and scatter_next_node.o(1).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(1)
                        elif len(scatter_next_node.outputs[0].outputs) > 2 and scatter_next_node.o(2).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(2)
                        elif len(scatter_next_node.outputs[0].outputs) > 3 and scatter_next_node.o(3).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(3)
                        elif len(scatter_next_node.outputs[0].outputs) > 4 and scatter_next_node.o(4).op == "ScatterND":
                            scatter_next_node = scatter_next_node.o(4)
                        else:
                            scatter_next_node = scatter_next_node.o()
                    node_ahead_reshape = scatter_next_node.i()
                    loop_num = 0
                    while ((node_ahead_reshape.op != "ScatterND" and node_ahead_reshape.op != "Cast") or (node_ahead_reshape.o().op != "Reshape" and node_ahead_reshape.o().op != "Shape")) and loop_num < 5:
                        loop_num = loop_num + 1
                        #print(node_ahead_reshape.name)
                        node_ahead_reshape = node_ahead_reshape.o()
                    num_dynamic_transpose = self.get_useful_output_index(node_ahead_reshape)
                    final_tranpose = node_ahead_reshape.o(num_dynamic_transpose).o()
                    # Sanity check to make sure that the graph looks like expected
                    '''
                    print(node_q.op)
                    print(node_k.op)
                    print(node_v.op)
                    print(final_tranpose.op)
                    '''
                    if node_v.op == "MatMul" and final_tranpose.op == "Transpose":
                        #print("node_v:",node_v.name)
                        return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose, True
        return False, 0, 0, None, None, None, None, False
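
A minimal sketch of how this detector might be driven. It assumes the method lives on an optimizer object built around an onnx-graphsurgeon graph; the helper name collect_attention_blocks, the onnx_path parameter, and the mha=True default are illustrative and not taken from the file:

    import onnx
    import onnx_graphsurgeon as gs

    def collect_attention_blocks(optimizer, onnx_path, mha=True):
        # optimizer is assumed to be an instance of the class that defines
        # mha_mhca_detected(); every MatMul node is offered as a candidate V GEMM.
        graph = gs.import_onnx(onnx.load(onnx_path))
        blocks = []
        for node in graph.nodes:
            if node.op != "MatMul":
                continue
            found, num_q, num_kv, node_q, node_k, node_v, final_transpose, is_controlnet = \
                optimizer.mha_mhca_detected(node, mha=mha)
            if found:
                blocks.append((num_q, num_kv, node_q, node_k, node_v, final_transpose, is_controlnet))
        return blocks

The boolean in the last position of the returned tuple is False for the plain Unet fmha/fmhca patterns and True for the ControlNet (Einsum/ScatterND) patterns, which is why it is carried along here.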