in lingvo/jax/train.py [0:0]
def train_and_evaluate_spmd_model(
model_p: InstantiableParams, train_input_p: InstantiableParams,
job_log_dir: Optional[str],
checkpoint_manager: checkpoint_managers.CheckpointManager,
checkpoint_type: CheckpointType, restore_checkpoint_dir: Optional[str],
restore_checkpoint_step: Optional[int],
eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
"""Runs the training and evaluation loop.
Args:
model_p: Params for the SPMD model.
train_input_p: Params for the train data pipeline.
job_log_dir: Directory for the job logs.
checkpoint_manager: A checkpoint manager controlling how often to save and
delete checkpoints.
checkpoint_type: The type of checkpoint to use.
restore_checkpoint_dir: If set, the directory from which to restore a
checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
instead.
restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
try to restore from the latest checkpoint, if any.
eval_input_p: Optional list of params for the eval input pipelines.
"""
logging.info('Using SPMD sharding for model parallelism.')
train_input_pipeline = train_input_p.Instantiate()
if eval_input_p is not None:
eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
# Do not mutate eval_input_pipelines itself. Instantiate a new one
# to get sample input.
sample_eval_model_inputs = eval_input_p[0].Instantiate().get_next()
eval_test_inputs_shape = tf.nest.map_structure(
py_utils.get_global_input_shape_dtype, sample_eval_model_inputs)
eval_test_inputs_pspecs = trainer_lib.get_input_partition_specs(
model_p.mesh_axis_names, eval_test_inputs_shape)
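# These global shapes and partition specs are reused below when running
# eval over the test splits.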
# TODO(bf-jax): Retrieve the seeds from the model definition instead.
prng_key = jax.random.PRNGKey(1234)
prng_key, init_key = jax.random.split(prng_key)
checkpoint_dir = _checkpoint_dir(job_log_dir)
restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
# Note that GDA checkpointing requires all processes to participate in
# checkpointing, but it does not require a separate checkpoint_dir per
# process.
if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
checkpoint_task_dir = os.path.join(checkpoint_dir,
f'{jax.process_index():03d}')
restore_checkpoint_task_dir = os.path.join(restore_checkpoint_dir,
f'{jax.process_index():03d}')
else:
checkpoint_task_dir = checkpoint_dir
restore_checkpoint_task_dir = restore_checkpoint_dir
multi_host_checkpointing = bool(checkpoint_type in {
CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
})
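# Both multi-host Flax and GDA checkpoints need every host to take part in
# saving/restoring, hence the sync_global_devices() barriers below.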
if jax.process_index() == 0:
tf.io.gfile.makedirs(checkpoint_dir)
if multi_host_checkpointing:
# Block all hosts until directory is ready.
py_utils.sync_global_devices(f'checkpointer:makedirs:{checkpoint_dir}')
logging.info('Retrieving model inputs for shape info.')
model_inputs_for_shape = train_input_pipeline.get_next()
inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
model_inputs_for_shape)
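# inputs_shape is used both to partition the model below and to validate and
# wrap per-step batches inside the train loop.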
mesh_shape = model_p.device_mesh.shape
device_mesh = mesh_utils.create_device_mesh(mesh_shape)
logging.info('device_mesh: %s', device_mesh)
# TODO(zhangqiaorjc): maps.mesh should yield Mesh.
global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
with maps.mesh(device_mesh, model_p.mesh_axis_names):
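# Partition the model for SPMD execution. This returns the partitioned train
# state, its partition specs, the input partition specs, the pjit-ed
# train/eval step functions and the total parameter count.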
(partitioned_train_state, train_state_pspecs, inputs_pspecs, train_step,
eval_step, total_num_params) = trainer_lib.partition_spmd_model(
model_p, init_key, inputs_shape)
partitioned_train_state = checkpoints.restore_checkpoint(
partitioned_train_state,
restore_checkpoint_task_dir,
global_mesh=global_mesh,
checkpoint_type=checkpoint_type,
state_specs=train_state_pspecs,
step=restore_checkpoint_step)
logging.info(
'partitioned_train_state shapes '
'(global shape for GDA, host-local shape for non-GDA): %s',
jax.tree_map(lambda x: x.shape, partitioned_train_state))
if multi_host_checkpointing:
py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')
# In contrast to the pmap version, we do not fold in jax.process_index here;
# we use a single global key and rely on pjit to split it for the different
# replicas.
logging.info('root prng_key: %s', prng_key)
prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
logging.info('train prng_key: %s', train_key)
logging.info('eval prng_key: %s', eval_key)
train_p = model_p.train
logging.info('Training loop starting...')
summary_base_dir = os.path.join(job_log_dir, 'summaries')
summary_train_dir = os.path.join(summary_base_dir, 'train')
summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
summary_writer = summary_utils.get_summary_writer
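# get_summary_writer is a context-manager factory; the writers are entered
# through the ExitStack below so they are cleaned up when the loop exits.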
if eval_input_p is not None:
summary_eval_test_dirs = [
os.path.join(summary_base_dir, f'eval_test_{split}')
for split, _ in enumerate(eval_input_p)
]
# We either run p.eval_loop_num_batches steps or one epoch (when supported
# by a resettable input) per eval loop during training. When
# p.reset_for_eval is set to True, we run the eval loop until
# tf.errors.OutOfRangeError (or StopIteration) is raised, which is
# triggered either when the input pipeline reaches the end of the input
# sequence or when a pre-determined number of batches has been reached.
eval_num_steps = [
-1 if p.reset_for_eval else p.eval_loop_num_batches
for p in eval_input_p
]
else:
summary_eval_test_dirs = []
with contextlib.ExitStack() as exit_stack:
train_summary_writer = exit_stack.enter_context(
summary_writer(summary_train_dir))
eval_summary_writer = exit_stack.enter_context(
summary_writer(summary_eval_dir))
eval_test_summary_writers = [
exit_stack.enter_context(summary_writer(d))
for d in summary_eval_test_dirs
]
# This only prints the view from the first host machine.
summary_utils.write_model_structure(
train_summary_writer,
partitioned_train_state,
is_vars_replicated=False)
summary_utils.write_total_num_params(train_summary_writer,
total_num_params)
summary_last_time = time.time()
summary_last_step = None
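# Read the starting step back from the (possibly just restored) train state.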
step_i = int(
jax.device_get(
py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))
# Start the train loop. Make sure all hosts start at the same step.
py_utils.sync_global_devices(f'Start training loop from step: {step_i}')
while True:
logging.debug('step=`%d`: Beginning', step_i)
if step_i >= train_p.num_train_steps:
logging.info(
'Training loop completed (step `%d` >= num_train_steps `%d`).',
step_i, train_p.num_train_steps)
break
if summary_last_step is None:
summary_last_step = step_i - 1
if checkpoint_manager.should_save(step_i):
logging.info('Saving a ckpt at step: %d', step_i)
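# For multi-host checkpointing every process participates in the save;
# otherwise only process 0 writes the checkpoint.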
if multi_host_checkpointing:
py_utils.sync_global_devices(
f'checkpointer:saving:{checkpoint_dir}:step-{step_i}')
if multi_host_checkpointing or jax.process_index() == 0:
checkpoints.save_checkpoint(
partitioned_train_state,
checkpoint_task_dir,
checkpoint_type=checkpoint_type,
state_specs=train_state_pspecs,
unreplicate=False)
checkpoint_manager.save_metadata(global_step_id=step_i)
if multi_host_checkpointing:
py_utils.sync_global_devices(
f'checkpointer:saved:{checkpoint_dir}:step-{step_i}')
# Get new model inputs.
if step_i <= 5:
logging.info('step=`%d`: Retrieving model inputs.', step_i)
logging.debug(' Retrieving inputs.')
model_inputs = train_input_pipeline.get_next()
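# When pjit is configured to output GDAs, wrap the host-local batch into
# GlobalDeviceArrays matching the expected global input shapes and specs.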
if jax.config.jax_parallel_functions_output_gda:
start = time.time()
py_utils.assert_same_shape_and_dtype(
inputs_shape,
tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
model_inputs))
model_inputs = py_utils.create_gda(model_inputs, inputs_shape,
global_mesh, inputs_pspecs)
logging.info('GDA train batch input creation time %s',
time.time() - start)
logging.debug(' Retrieved inputs.')
logging.debug(' Performing train_step().')
with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
(partitioned_train_state, loss, metrics, per_example_out,
summary_tensors) = train_step(partitioned_train_state, train_key,
model_inputs)
logging.debug(' Completed train_step().')
logging.debug(' Writing summaries (attempt).')
if summary_utils.write_summary_every_n_steps(
partitioned_train_state,
train_summary_writer,
step_i,
train_p.summary_interval_steps,
loss,
metrics,
per_example_out,
summary_tensors,
train_p.norm_summary_interval_steps,
summary_last_time,
summary_last_step,
unreplicate_mdl_vars=False,
unreplicate_metrics=False):
summary_last_time = time.time()
summary_last_step = step_i
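# Only read the step back from the device when summaries were written, so
# that the common path avoids an extra device sync.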
step_i = int(
py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
else:
# Increment train step locally to avoid an explicit device sync.
step_i += 1
logging.debug(' Wrote summaries (attempted).')
# Run eval at regular step interval.
if step_i % train_p.eval_interval_steps == 0:
logging.debug(' Starting eval_step().')
logging.debug(' Retrieving eval model_inputs.')
eval_inputs = train_input_pipeline.get_next()
if jax.config.jax_parallel_functions_output_gda:
eval_inputs = py_utils.create_gda(eval_inputs, inputs_shape,
global_mesh, inputs_pspecs)
logging.debug(' Retrieved eval model_inputs.')
logging.debug(' Performing eval_step() runs on training split.')
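# Bind the model variables, the eval RNG and the current step so the eval
# helper below only needs to feed input batches.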
eval_step_fn = functools.partial(eval_step,
partitioned_train_state.mdl_vars,
eval_key,
partitioned_train_state.step)
loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
eval_inputs, eval_step_fn, reshard_inputs=False)
logging.debug(' Completed eval_step() runs on training split.')
logging.info('step=`%d`', step_i)
logging.info(' eval loss: %s', loss)
logging.info(' mean_metrics: %s', mean_metrics)
logging.info(' summary_tensors: %s', summary_tensors)
if step_i % train_p.summary_interval_steps == 0:
logging.debug(' Writing eval summaries.')
summary_utils.write_summary_entry(
eval_summary_writer,
step_i,
loss,
mean_metrics,
summary_tensors,
unreplicate_metrics=False)
logging.debug(' Wrote eval summaries.')
# If we have eval test then also evaluate on test.
if eval_input_p is not None:
logging.debug(' Performing eval_step() runs on test splits.')
model_utils.run_eval_loop_over_test_splits(
eval_num_steps,
eval_step_fn,
eval_test_summary_writers,
step_i,
eval_input_pipelines,
eval_test_inputs_pspecs,
eval_test_inputs_shape,
global_mesh,
reshard_inputs=False)
logging.debug(' Completed eval_step() runs on test splits.')
logging.debug('step=`%d`: End', step_i - 1)