def flops()

in models/vision_transformer.py


    def flops(self, x):
        flops = 0
        seq_len, batch_size, hidden_dim = x.shape

        num_elems = x.numel() // batch_size
        flops += num_elems * 6  # elementwise ops: ln_1 (2), x + input (1), ln_2 (2), x + y (1)

        # self_attention
        # calculations are based on the fact that head_dim * num_heads = hidden_dim
        # so we collapse (hidden_dim // num_heads) * num_heads to hidden_dim
        flops += 3 * seq_len * (hidden_dim + 1) * hidden_dim  # q, k, v projections with bias
        flops += hidden_dim * seq_len  # scaling
        flops += hidden_dim * seq_len * seq_len  # attention scores: q @ k^T over all heads
        flops += self.num_heads * seq_len * seq_len  # softmax
        flops += hidden_dim * seq_len * seq_len  # applying attention to values: attn @ v
        flops += seq_len * (hidden_dim + 1) * hidden_dim  # out projection with bias

        # mlp
        mlp_dim = self.mlp.linear_1.out_features
        flops += seq_len * (hidden_dim + 1) * mlp_dim  # linear_1
        flops += seq_len * mlp_dim  # activation
        flops += seq_len * (mlp_dim + 1) * hidden_dim  # linear_2
        return flops * batch_size
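
For a quick sanity check without instantiating the block, the same per-sample terms can be evaluated directly from the hyperparameters. The sketch below is a hypothetical standalone helper that mirrors the sums in flops() above (each matrix-multiply term counts one unit per multiply-add, bias included); the ViT-B/16-style settings at the end (197 tokens including the class token, hidden_dim 768, 12 heads, mlp_dim 3072) are illustrative assumptions, not values taken from this file.

    def encoder_block_flops(seq_len, hidden_dim, num_heads, mlp_dim):
        # Per-sample count mirroring flops() above (batch_size = 1).
        flops = seq_len * hidden_dim * 6                      # ln_1, ln_2, two residual adds
        flops += 3 * seq_len * (hidden_dim + 1) * hidden_dim  # q, k, v projections with bias
        flops += hidden_dim * seq_len                         # scaling
        flops += hidden_dim * seq_len * seq_len               # attention scores: q @ k^T
        flops += num_heads * seq_len * seq_len                # softmax
        flops += hidden_dim * seq_len * seq_len               # attn @ v
        flops += seq_len * (hidden_dim + 1) * hidden_dim      # out projection with bias
        flops += seq_len * (hidden_dim + 1) * mlp_dim         # linear_1
        flops += seq_len * mlp_dim                            # activation
        flops += seq_len * (mlp_dim + 1) * hidden_dim         # linear_2
        return flops

    # Assumed ViT-B/16-style settings: 196 patch tokens + 1 class token.
    per_block = encoder_block_flops(seq_len=197, hidden_dim=768, num_heads=12, mlp_dim=3072)
    print(f"~{per_block / 1e9:.2f} GFLOPs per block, ~{12 * per_block / 1e9:.2f} for 12 blocks")

Note that this covers only the encoder blocks; the patch-embedding projection and the classification head would have to be counted separately.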