def multihead_attention_counter_hook()

in models/src/ptflops/flops_counter.py


def multihead_attention_counter_hook(multihead_attention_module, input, output):
    flops = 0
    q, k, v = input
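    # q, k and v are expected in (seq_len, batch, embed_dim) layout (batch_first=False),
    # which is why the batch size is read from dim 1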
    batch_size = q.shape[1]

    num_heads = multihead_attention_module.num_heads
    embed_dim = multihead_attention_module.embed_dim
    kdim = multihead_attention_module.kdim
    vdim = multihead_attention_module.vdim
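    # fall back to embed_dim when separate key/value dimensions were not set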
    if kdim is None:
        kdim = embed_dim
    if vdim is None:
        vdim = embed_dim

    # initial projections of Q, K and V (counted as multiply-accumulates)
    flops += (
        q.shape[0] * q.shape[2] * embed_dim  # Q projection
        + k.shape[0] * k.shape[2] * kdim  # K projection
        + v.shape[0] * v.shape[2] * vdim  # V projection
    )
    if multihead_attention_module.in_proj_bias is not None:
        flops += (q.shape[0] + k.shape[0] + v.shape[0]) * embed_dim  # projection biases

    # attention heads: scale, matmul, softmax, matmul
    head_dim = embed_dim // num_heads
    head_flops = (
        q.shape[0] * head_dim  # scale Q
        + head_dim * q.shape[0] * k.shape[0]  # Q @ K^T
        + q.shape[0] * k.shape[0]  # softmax over the attention scores
        + q.shape[0] * k.shape[0] * head_dim  # attention weights @ V
    )

    flops += num_heads * head_flops

    # final projection, bias is always enabled
    flops += q.shape[0] * embed_dim * (embed_dim + 1)

    flops *= batch_size  # the counts above are per batch element
    multihead_attention_module.__flops__ += int(flops)
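
Because the hook uses the standard (module, input, output) forward-hook signature, it can be exercised in isolation. The following is a minimal sketch under a few assumptions: the import path mirrors the file location above, the hook is attached manually with register_forward_hook, and __flops__ is initialized by hand (bookkeeping the ptflops counter machinery normally handles when it instruments a model).

import torch
import torch.nn as nn

from ptflops.flops_counter import multihead_attention_counter_hook  # assumed import path

# Hypothetical stand-alone use of the hook; ptflops usually installs hooks
# and initializes the __flops__ attribute itself.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8)  # batch_first=False by default
mha.__flops__ = 0  # the hook accumulates into this attribute
handle = mha.register_forward_hook(multihead_attention_counter_hook)

seq_len, batch_size = 10, 4
x = torch.randn(seq_len, batch_size, 64)  # (L, N, E) layout, matching the hook's assumption
mha(x, x, x)  # forward hook fires with input == (q, k, v)

print(mha.__flops__)  # estimated multiply-accumulate count
handle.remove()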