in models/vision_transformer.py [0:0]
def flops(self, x):
    """Estimate the FLOPs of one encoder block for an input of shape
    (seq_len, batch_size, hidden_dim), counting each multiply-accumulate
    in a matrix product as one FLOP and scaling the result by batch size."""
    flops = 0
    seq_len, batch_size, hidden_dim = x.shape
    # element-wise work per sample: ln_1 (* 2), x + input, ln_2 (* 2), x + y
    num_elems = x.numel() // batch_size
    flops += num_elems * 6
    # self-attention
    # the accounting relies on head_dim * num_heads == hidden_dim, so
    # (hidden_dim // num_heads) * num_heads collapses to hidden_dim
    # (see the sanity check after this method)
    flops += 3 * seq_len * (hidden_dim + 1) * hidden_dim  # q/k/v projections with bias
    flops += hidden_dim * seq_len  # scaling
    flops += hidden_dim * seq_len * seq_len  # attention weights (q @ k^T)
    flops += self.num_heads * seq_len * seq_len  # softmax
    flops += hidden_dim * seq_len * seq_len  # attention application (weights @ v)
    flops += seq_len * (hidden_dim + 1) * hidden_dim  # output projection with bias
    # mlp
    mlp_dim = self.mlp.linear_1.out_features
    flops += seq_len * (hidden_dim + 1) * mlp_dim  # linear_1
    flops += seq_len * mlp_dim  # activation
    flops += seq_len * (mlp_dim + 1) * hidden_dim  # linear_2
    return flops * batch_size
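
A quick check of the head-collapse identity used in the comments above: with head_dim = hidden_dim // num_heads, the per-head cost of the attention-weight matmul, num_heads * seq_len * seq_len * head_dim, reduces to hidden_dim * seq_len * seq_len. A minimal sketch, using illustrative ViT-B/16-like shapes (197 tokens, hidden size 768, 12 heads) that are not taken from this file:

# Sanity check of the collapse used in flops(); the shape constants below are
# illustrative ViT-B/16-like values, not values read from this file.
seq_len, hidden_dim, num_heads = 197, 768, 12
head_dim = hidden_dim // num_heads

per_head = num_heads * seq_len * seq_len * head_dim  # q @ k^T, summed over heads
collapsed = hidden_dim * seq_len * seq_len           # form used in flops() above
assert per_head == collapsed                         # both equal 29_805_312 here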