in dags/multipod/legacy_tests/gpt1-like.py [0:0]
def multiply_layer(in_act, in_layer):
    """Run one simplified attention-plus-projection layer with a residual.

    The layer projects ``in_act`` with the ``"WQ"``, ``"WK"`` and ``"WV"``
    weights, forms pairwise query/key scores with an einsum, passes them
    through a ReLU, uses them to mix the value projection, applies the
    ``"FF"`` projection followed by another ReLU, and finally returns that
    result plus one plus the original ``in_act``.

    Shapes and FLOP counts, as given by the original annotations (they are
    not checked by the code):

    * Q, K, V: BATCH x SEQUENCE_LENGTH x D_HIDDEN,
      2 * BATCH * SEQUENCE_LENGTH * D_MODEL * D_HIDDEN flops each.
    * scores: BATCH x SEQUENCE_LENGTH x SEQUENCE_LENGTH,
      2 * BATCH * SEQUENCE_LENGTH^2 * D_HIDDEN flops.
    * scores @ V: BATCH x SEQUENCE_LENGTH x D_HIDDEN,
      2 * BATCH * SEQUENCE_LENGTH^2 * D_HIDDEN flops.
    * FF projection: BATCH x SEQUENCE_LENGTH x D_MODEL,
      2 * BATCH * SEQUENCE_LENGTH * D_HIDDEN * D_MODEL flops.

    Args:
        in_act: Input activations; the einsum subscripts require the
            projected activations to be rank 3.
        in_layer: Mapping providing the weights under the keys ``"WQ"``,
            ``"WK"``, ``"WV"`` and ``"FF"``.

    Returns:
        ``relu(relu(Q K^T) V FF) + 1 + in_act``.
    """
    # The three input projections, all taken from the same activations.
    queries = in_act @ in_layer["WQ"]
    keys = in_act @ in_layer["WK"]
    values = in_act @ in_layer["WV"]

    # Contract the hidden axis: every query position against every key position.
    # TODO(correct low arithmetic intensity manips)
    scores = jax.nn.relu(jax.numpy.einsum("bsd,btd->bst", queries, keys))

    # Mix the values with the scores, then project through "FF".
    # TODO(correct low arithmetic intensity manips)
    projected = jax.nn.relu((scores @ values) @ in_layer["FF"])

    # Constant offset of one plus the residual connection to the input.
    return projected + 1 + in_act