def params()

in utils/flops-params_py.py


def params(hidden_size, num_heads, num_layers, seq_len=2048, vocab_size=32000, ffw_size=None, relative_attention=False):
    """Count the trainable parameters of a Transformer with the given shape.

    num_heads does not affect the count: splitting hidden_size across heads
    leaves the total size of the attention projections unchanged.
    """
    if ffw_size is None:
        ffw_size = 4 * hidden_size
    per_layer = 4 * hidden_size * hidden_size   # attention: Q, K, V and output projections
    per_layer += 4 * hidden_size                # attention projection biases
    per_layer += 2 * ffw_size * hidden_size     # feed-forward: up- and down-projection
    per_layer += ffw_size + hidden_size         # feed-forward biases
    per_layer += 2 * hidden_size                # layer norm (gain and bias)
    if relative_attention:
        per_layer += hidden_size * hidden_size  # relative position embeddings according to Dai et al. (Transformer-XL)
    embeddings = hidden_size * vocab_size + vocab_size  # token embedding matrix plus a vocab-size bias term
    if not relative_attention:
        embeddings += seq_len * hidden_size     # learned absolute position embeddings
    N = num_layers * per_layer + embeddings
    return N
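
For reference, a minimal usage sketch; the hyperparameters below are illustrative GPT-2-small-like values assumed for this example, not taken from this file:

    # assumed hyperparameters, for illustration only
    n = params(hidden_size=768, num_heads=12, num_layers=12, seq_len=1024, vocab_size=50257)
    print(f"{n / 1e6:.1f}M parameters")  # ~124.5M under this counting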