in models/src/ptflops/flops_counter.py [0:0]
def multihead_attention_counter_hook(multihead_attention_module, input, output):
    flops = 0
    q, k, v = input
    # assumes the default sequence-first layout (batch_first=False),
    # i.e. q/k/v are shaped (seq_len, batch, feature_dim)
    batch_size = q.shape[1]
    num_heads = multihead_attention_module.num_heads
    embed_dim = multihead_attention_module.embed_dim
    kdim = multihead_attention_module.kdim
    vdim = multihead_attention_module.vdim
    if kdim is None:
        kdim = embed_dim
    if vdim is None:
        vdim = embed_dim
    # initial projections
    flops += (
        q.shape[0] * q.shape[2] * embed_dim
        + k.shape[0] * k.shape[2] * kdim
        + v.shape[0] * v.shape[2] * vdim
    )
    if multihead_attention_module.in_proj_bias is not None:
        flops += (q.shape[0] + k.shape[0] + v.shape[0]) * embed_dim
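    # Worked example (illustrative numbers, not from the source): with qlen=10,
    # klen=vlen=12 and embed_dim=kdim=vdim=64, the projections above cost
    # 10*64*64 + 12*64*64 + 12*64*64 = 139264 ops, plus (10+12+12)*64 = 2176 for the biases.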
    # attention heads: scale, matmul, softmax, matmul
    head_dim = embed_dim // num_heads
    head_flops = (
        q.shape[0] * head_dim                 # scaling of Q
        + head_dim * q.shape[0] * k.shape[0]  # Q @ K^T
        + q.shape[0] * k.shape[0]             # softmax over the score matrix
        + q.shape[0] * k.shape[0] * head_dim  # attention weights @ V
    )
    flops += num_heads * head_flops
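    # Continuing the example above: head_dim = 64 // 8 = 8 with num_heads=8, so one
    # head costs 10*8 + 8*10*12 + 10*12 + 10*12*8 = 80 + 960 + 120 + 960 = 2120 ops,
    # and all heads together add 8 * 2120 = 16960.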
    # final output projection; the bias term is counted unconditionally here
    flops += q.shape[0] * embed_dim * (embed_dim + 1)
    flops *= batch_size
    multihead_attention_module.__flops__ += int(flops)
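

# Illustrative check (a minimal sketch, not part of ptflops itself): exercising the
# hook directly on an nn.MultiheadAttention module. ptflops normally attaches such
# hooks through its own machinery (get_model_complexity_info); the shapes and the
# manual __flops__ initialisation below are assumptions made only for this example.
def _demo_multihead_attention_flops():
    import torch
    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=64, num_heads=8)  # default seq-first layout
    mha.__flops__ = 0
    mha.register_forward_hook(multihead_attention_counter_hook)

    q = torch.randn(10, 2, 64)  # (target_len, batch, embed_dim)
    k = torch.randn(12, 2, 64)  # (source_len, batch, kdim)
    v = torch.randn(12, 2, 64)  # (source_len, batch, vdim)
    mha(q, k, v)
    return mha.__flops__  # estimated op count accumulated by the hook for this call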