# multiply_layer
#
# Excerpted from dags/multipod/legacy_tests/gpt1-like.py [0:0]

def multiply_layer(in_act, in_layer):
  """Apply one simplified GPT-like transformer layer to `in_act`.

  Args:
    in_act: activations of shape BATCH x SEQUENCE_LENGTH x D_MODEL.
    in_layer: dict of weights — "WQ"/"WK"/"WV" (D_MODEL x D_HIDDEN) and
      "FF" (D_HIDDEN x D_MODEL).

  Returns:
    Tensor of shape BATCH x SEQUENCE_LENGTH x D_MODEL: the feed-forward
    output plus a constant 1 plus the residual `in_act`.

  Note: uses relu in place of softmax/layernorm on purpose (this is a
  flops-accounting benchmark, not a faithful transformer).
  """
  # QKV projections, each BATCH x SEQUENCE_LENGTH x D_HIDDEN.
  # Flops per projection: 2 * BATCH * SEQUENCE_LENGTH * D_MODEL * D_HIDDEN.
  queries = in_act @ in_layer["WQ"]
  keys = in_act @ in_layer["WK"]
  values = in_act @ in_layer["WV"]

  # Attention scores: contract over the hidden dim, yielding
  # BATCH x SEQUENCE_LENGTH x SEQUENCE_LENGTH.
  # Flops: 2 * BATCH * SEQUENCE_LENGTH^2 * D_HIDDEN.
  scores = jax.numpy.einsum("bsd,btd->bst", queries, keys)
  scores = jax.nn.relu(scores)  # TODO(correct low arithmetic intensity manips)

  # Weight the values by the scores: BATCH x SEQUENCE_LENGTH x D_HIDDEN.
  # Flops: 2 * BATCH * SEQUENCE_LENGTH^2 * D_HIDDEN.
  attended = scores @ values

  # Feed-forward back to model width: BATCH x SEQUENCE_LENGTH x D_MODEL.
  # Flops: 2 * BATCH * SEQUENCE_LENGTH * D_HIDDEN * D_MODEL.
  projected = attended @ in_layer["FF"]
  projected = jax.nn.relu(
      projected
  )  # TODO(correct low arithmetic intensity manips)

  # Residual-style combination (the +1 constant is intentional here).
  return projected + 1 + in_act