in utils/flops-params_py.py [0:0]
def simple_flops(
    dataset_size,
    hidden_size,
    num_heads,
    num_layers,
    seq_len=2048,
    vocab_size=32000,
    ffw_size=None,
    relative_attention=False,
):
    """Estimate compute as ``6 * params(...) * dataset_size``.

    The model configuration is forwarded to ``params`` (defined elsewhere in
    this file) to obtain a parameter count. That count is multiplied by 6 and
    by ``dataset_size``. The constant 6 presumably corresponds to the common
    "6 * N * D" training-compute approximation. TODO confirm this is the
    intended interpretation.

    Args:
        dataset_size: Final multiplier. Presumably the number of training
            tokens; verify the units against callers.
        hidden_size: Model width, forwarded to ``params``.
        num_heads: Attention head count, forwarded to ``params``.
        num_layers: Layer count, forwarded to ``params``.
        seq_len: Sequence length, forwarded to ``params``.
        vocab_size: Vocabulary size, forwarded to ``params``.
        ffw_size: Feed-forward width, forwarded to ``params``. When ``None``,
            ``4 * hidden_size`` is substituted before the call.
        relative_attention: Flag forwarded unchanged to ``params``.

    Returns:
        ``6 * params(...) * dataset_size``.
    """
    # Resolve the feed-forward width here so ``params`` always receives a
    # concrete value rather than ``None``.
    model_config = dict(
        hidden_size=hidden_size,
        num_heads=num_heads,
        num_layers=num_layers,
        seq_len=seq_len,
        vocab_size=vocab_size,
        ffw_size=4 * hidden_size if ffw_size is None else ffw_size,
        relative_attention=relative_attention,
    )
    n_params = params(**model_config)
    return 6 * n_params * dataset_size