in dags/multipod/legacy_tests/gpt1-like.py [0:0]
def gen_layer(random_key):
    """Build one layer's randomly initialised weight matrices.

    The PRNG key is split into four sub-keys, one per matrix. Each matrix is
    sampled with ``jax.random.normal`` in bfloat16 and multiplied by 1e-4.

    Args:
      random_key: JAX PRNG key that seeds this layer's weights.

    Returns:
      dict with entries "WQ", "WK", "WV" (each shaped (D_MODEL, D_HIDDEN))
      and "FF" (shaped (D_HIDDEN, D_MODEL)), in that insertion order.
    """
    subkeys = jax.random.split(random_key, num=4)
    # (name, shape) pairs, listed in the order their sub-keys are consumed:
    # entry i is sampled with subkeys[i].
    layout = (
        ("WQ", (D_MODEL, D_HIDDEN)),
        ("WK", (D_MODEL, D_HIDDEN)),
        ("WV", (D_MODEL, D_HIDDEN)),
        ("FF", (D_HIDDEN, D_MODEL)),
    )
    return {
        name: 1e-4
        * jax.random.normal(subkeys[idx], shape, dtype=jax.numpy.bfloat16)
        for idx, (name, shape) in enumerate(layout)
    }