def gen_layer()

in dags/multipod/legacy_tests/gpt1-like.py [0:0]


def gen_layer(random_key):
  """Create the randomly initialised weight matrices for one layer.

  Args:
    random_key: JAX PRNG key. It is split into four sub-keys, one per weight,
      assigned in the order WQ, WK, WV, FF.

  Returns:
    dict with keys "WQ", "WK", "WV" (each of shape (D_MODEL, D_HIDDEN)) and
    "FF" (shape (D_HIDDEN, D_MODEL)). Every entry is a bfloat16 sample from a
    standard normal distribution multiplied by 1e-4. (Presumably WQ/WK/WV are
    the attention query/key/value projections and FF the feed-forward weight;
    confirm against the model code that consumes this dict.)
  """
  # WQ, WK and WV share one shape; FF maps back from D_HIDDEN to D_MODEL.
  projection_shape = (D_MODEL, D_HIDDEN)
  layout = (
      ("WQ", projection_shape),
      ("WK", projection_shape),
      ("WV", projection_shape),
      ("FF", (D_HIDDEN, D_MODEL)),
  )
  # One sub-key per weight. The order of `layout` fixes which sub-key each
  # weight receives, so reordering the table would change the sampled values.
  subkeys = jax.random.split(random_key, num=len(layout))
  return {
      name: 1e-4 * jax.random.normal(subkey, shape, dtype=jax.numpy.bfloat16)
      for (name, shape), subkey in zip(layout, subkeys)
  }