def full_flops()

in utils/flops-params_py.py [0:0]


def full_flops(
    dataset_size,
    hidden_size,
    num_heads,
    num_layers,
    seq_len=2048,
    vocab_size=32000,
    ffw_size=None,
    key_size=None,
):
    """Estimate total training FLOPs for a transformer language model.

    Counts forward-pass FLOPs for a single sequence of ``seq_len`` tokens
    operation by operation (embeddings, per-layer attention and feed-forward
    block, final logits). That figure is multiplied by 3 to account for the
    backward pass. It is then scaled by ``dataset_size / seq_len``, the number
    of sequences.

    Args:
        dataset_size: Training-set size, presumably in tokens. It is divided
            by ``seq_len`` to obtain the number of sequences.
        hidden_size: Model width (d_model).
        num_heads: Number of attention heads (used in the softmax term, and
            in the attention width when ``key_size`` is given).
        num_layers: Number of transformer layers.
        seq_len: Sequence length in tokens.
        vocab_size: Vocabulary size (embedding and output-logit matrices).
        ffw_size: Feed-forward inner width; defaults to ``4 * hidden_size``.
        key_size: Per-head key/query/value width. When ``None`` (default),
            the total attention width ``key_size * num_heads`` is taken to be
            ``hidden_size``, which matches the original formula exactly.

    Returns:
        Estimated training FLOPs as a float (true division by ``seq_len``).

    Raises:
        ZeroDivisionError: If ``seq_len`` is 0.
    """
    if ffw_size is None:
        ffw_size = 4 * hidden_size
    # Total width of the concatenated heads. It defaults to hidden_size so
    # existing callers get identical results.
    attn_dim = hidden_size if key_size is None else key_size * num_heads

    # Matrix products are counted as 2 FLOPs per multiply-accumulate.
    embeddings_flops = 2 * seq_len * vocab_size * hidden_size

    # Per-layer attention cost.
    attention_kqv_proj = 2 * 3 * seq_len * hidden_size * attn_dim
    attention_kq_logits = 2 * seq_len * seq_len * attn_dim
    attention_softmax = 3 * num_heads * seq_len * seq_len
    attention_softmax_q_red = 2 * seq_len * seq_len * attn_dim
    attention_final_layer = 2 * seq_len * attn_dim * hidden_size

    # Per-layer feed-forward block: up-projection plus down-projection.
    dense_flops = 2 * seq_len * (hidden_size * ffw_size + ffw_size * hidden_size)

    final_logits = 2 * seq_len * hidden_size * vocab_size

    per_layer_flops = (
        attention_kqv_proj
        + attention_kq_logits
        + attention_softmax
        + attention_softmax_q_red
        + attention_final_layer
        + dense_flops
    )
    total_flops = embeddings_flops + num_layers * per_layer_flops + final_logits

    # Forward + backward (backward taken as 2x forward) = 3x forward, scaled
    # by the number of sequences. The evaluation order is kept identical to
    # the original so float results do not change.
    return total_flops * 3 * dataset_size / seq_len