in dags/multipod/legacy_tests/gpt1-like.py [0:0]
def gen_layers(random_key):
    """Create ``NUM_LAYERS`` layers, giving each its own PRNG subkey.

    The incoming key is split once per layer. One half of each split is
    carried forward to seed the next split. The other half is passed to
    ``gen_layer``. This way no two layers consume the same key.

    Args:
      random_key: a JAX PRNG key (it is passed to ``jax.random.split``).

    Returns:
      A tuple holding the ``NUM_LAYERS`` values returned by ``gen_layer``,
      in the order they were generated.
    """

    def _layer_stream(carry):
        # Lazily yield one layer per iteration. The key is rethreaded each
        # time so the split sequence matches a plain accumulate-and-append loop.
        for _layer_idx in range(NUM_LAYERS):
            carry, layer_key = jax.random.split(carry)
            yield gen_layer(layer_key)

    return tuple(_layer_stream(random_key))