def gen_layers()

in dags/multipod/legacy_tests/gpt1-like.py [0:0]


def gen_layers(random_key):
  """Build NUM_LAYERS layers, each seeded by its own PRNG subkey.

  The key is threaded through the loop: every step splits the current key
  into a carry key (used for the next step) and a subkey handed to
  `gen_layer`. No subkey is reused.

  Args:
    random_key: JAX PRNG key to derive the per-layer subkeys from.

  Returns:
    A tuple of length NUM_LAYERS holding the value `gen_layer` returned for
    each subkey, in generation order.
  """

  def _layer_keys(carry):
    # Lazily yield one fresh subkey per layer. Because the tuple() below
    # consumes this generator one item at a time, each split happens right
    # before the corresponding gen_layer call, matching a plain loop.
    for _layer_idx in range(NUM_LAYERS):
      carry, layer_key = jax.random.split(carry)
      yield layer_key

  return tuple(gen_layer(layer_key) for layer_key in _layer_keys(random_key))