in lingvo/jax/train.py [0:0]
def train_and_evaluate_pmap(
model_p: InstantiableParams, train_input_p: InstantiableParams,
job_log_dir: Optional[str],
checkpoint_manager: checkpoint_managers.CheckpointManager,
restore_checkpoint_dir: Optional[str],
restore_checkpoint_step: Optional[int],
eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
"""Runs the training and evaluation loop.
Args:
model_p: Params for the data parallel model.
train_input_p: Params for the train data input pipeline.
job_log_dir: Directory for the job logs.
checkpoint_manager: A checkpoint manager controlling how often to save and
delete checkpoints.
    restore_checkpoint_dir: If set, the directory from which to restore the
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
try to restore from the latest checkpoint if any.
eval_input_p: Optional list of params for the eval input pipelines.
"""
logging.info('Using pmap for data parallelism.')
if jax.config.jax_parallel_functions_output_gda:
raise NotImplementedError(
'jax.pmap does not yet support GlobalDeviceArray.')
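  # Instantiate the model and the training input pipeline from their params.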
jax_model = model_p.Instantiate()
train_input_pipeline = train_input_p.Instantiate()
if eval_input_p is not None:
eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
# TODO(shafey): Retrieve the seeds from the model definition instead.
prng_key = jax.random.PRNGKey(1234)
prng_key, init_key = jax.random.split(prng_key)
checkpoint_dir = _checkpoint_dir(job_log_dir)
restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
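  # Initialize the model state from the init PRNG key, then overwrite it with
  # values restored from restore_checkpoint_dir (restore_checkpoint presumably
  # returns the passed-in states unchanged when no checkpoint is found there).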
model_states = trainer_lib.initialize_model_state(jax_model, init_key)
model_states = checkpoints.restore_checkpoint(
model_states,
restore_checkpoint_dir,
step=restore_checkpoint_step)
total_num_params = jax_model.total_num_vars
replicated_model_states = trainer_lib.replicate_model_state(model_states)
  # The unreplicated model states are no longer needed at this point.
  del model_states
logging.info('replicated_model_states shapes: %s',
jax.tree_map(lambda x: x.shape, replicated_model_states))
  # From now on, different replicas should use different random seeds.
  # Here, each process gets its own unique prng_key.
  # prng_key will be further split below so that each local device gets a
  # different prng_key.
prng_key = jax.random.fold_in(prng_key, jax.process_index())
logging.info('root prng_key: %s', prng_key)
fprop_dtype = model_p.fprop_dtype
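  # Thin wrappers around the single-learner step functions so they can be
  # pmapped below. 'batch' names the data-parallel axis, presumably used for
  # cross-replica reductions (e.g. gradient averaging) inside the steps.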
def train_step(states, prng_key, inputs):
return trainer_lib.train_step_single_learner(
jax_model,
states,
prng_key,
inputs,
data_parallel_axis_name='batch',
fprop_dtype=fprop_dtype)
def eval_step(mdl_vars, prng_key, global_step, inputs):
return trainer_lib.eval_step_single_learner(
jax_model,
mdl_vars,
prng_key,
global_step,
inputs,
data_parallel_axis_name='batch',
fprop_dtype=fprop_dtype)
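  # Derive one PRNG seed per local device; jax.pmap maps the leading axis of
  # these seed arrays over the local devices.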
num_devices = jax.local_device_count()
prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
train_prng_seed = jax.random.split(train_key, num=num_devices)
eval_prng_seed = jax.random.split(eval_key, num=num_devices)
logging.info('train prng_seed: %s', train_prng_seed)
logging.info('eval prng_seed: %s', eval_prng_seed)
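  # pmap-compile the step functions. donate_argnums=(0,) lets XLA reuse the
  # buffers of the previous model state when producing the updated state.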
p_train_step = jax.pmap(train_step, donate_argnums=(0,), axis_name='batch')
p_eval_step = jax.pmap(eval_step, axis_name='batch')
train_p = model_p.train
logging.info('Training loop starting...')
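  # Train, eval-on-train, and per-test-split eval summaries go to separate
  # subdirectories under job_log_dir/summaries.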
summary_base_dir = os.path.join(job_log_dir, 'summaries')
summary_train_dir = os.path.join(summary_base_dir, 'train')
summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
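  # get_summary_writer is used as a context manager factory; keep only an
  # alias here and enter the actual writers via the ExitStack below.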
summary_writer = summary_utils.get_summary_writer
if eval_input_p is not None:
summary_test_split_dirs = [
os.path.join(summary_base_dir, f'eval_test_{split}')
for split, _ in enumerate(eval_input_p)
]
    # Per eval loop during training, we run either p.eval_loop_num_batches
    # steps or one epoch (when supported by a resettable input). When
    # p.reset_for_eval is set to True, we run the eval loop until
    # tf.errors.OutOfRangeError (or StopIteration) is raised, which is
    # triggered either because the input pipeline has reached the end of the
    # input sequence or because a pre-determined number of batches has been
    # reached.
eval_num_steps = [
-1 if p.reset_for_eval else p.eval_loop_num_batches
for p in eval_input_p
]
else:
summary_test_split_dirs = []
with contextlib.ExitStack() as exit_stack:
train_summary_writer = exit_stack.enter_context(
summary_writer(summary_train_dir))
eval_summary_writer = exit_stack.enter_context(
summary_writer(summary_eval_dir))
eval_test_summary_writers = [
exit_stack.enter_context(summary_writer(d))
for d in summary_test_split_dirs
]
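    # Record the model variable structure and total parameter count once,
    # before the training loop starts.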
summary_utils.write_model_structure(
train_summary_writer, replicated_model_states, is_vars_replicated=True)
summary_utils.write_total_num_params(train_summary_writer, total_num_params)
summary_last_time = time.time()
summary_last_step = None
step_i = int(jax.device_get(replicated_model_states.step)[0])
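    # Main training loop: save checkpoints, fetch a batch, run one pmapped
    # train step, and periodically write summaries and run eval, until
    # num_train_steps is reached.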
while True:
logging.debug('step=`%d`: Beginning', step_i)
      if step_i >= train_p.num_train_steps:
        logging.info(
            'Training loop completed (step `%d` reached num_train_steps '
            '`%d`).', step_i, train_p.num_train_steps)
        break
if summary_last_step is None:
summary_last_step = step_i - 1
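      # Save a checkpoint whenever the checkpoint manager's policy requests
      # one; only process 0 writes, since the replicated states are identical
      # across processes.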
if checkpoint_manager.should_save(step_i):
if jax.process_index() == 0:
checkpoints.save_checkpoint(replicated_model_states, checkpoint_dir)
checkpoint_manager.save_metadata(global_step_id=step_i)
if step_i <= 5:
logging.info('step=`%d`: Retrieving model inputs.', step_i)
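      # Fetch the next training batch and reshard it so its leading axis
      # matches the local device count expected by pmap (py_utils.reshard is
      # assumed to perform that per-device reshape).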
logging.debug(' Retrieving inputs.')
model_inputs = tf.nest.map_structure(py_utils.reshard,
train_input_pipeline.get_next())
logging.debug(' Retrieved inputs.')
logging.debug(' Performing train_step().')
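      # Run one data-parallel training step; StepTraceAnnotation groups this
      # step's device activity in profiler traces.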
with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
(replicated_model_states, loss, metrics, per_example_out,
summary_tensors) = p_train_step(replicated_model_states,
train_prng_seed, model_inputs)
logging.debug(' Completed train_step().')
logging.debug(' Writing summaries (attempt).')
if summary_utils.write_summary_every_n_steps(
replicated_model_states,
train_summary_writer,
step_i,
train_p.summary_interval_steps,
loss,
metrics,
per_example_out,
summary_tensors,
train_p.norm_summary_interval_steps,
summary_last_time,
summary_last_step,
unreplicate_mdl_vars=True,
unreplicate_metrics=True):
summary_last_time = time.time()
summary_last_step = step_i
        # Synchronize step_i with the global step stored in the model states.
step_i = int(jax.device_get(replicated_model_states.step)[0])
else:
# Increment locally to avoid an explicit sync.
step_i += 1
logging.debug(' Wrote summaries (attempted).')
# Run eval at regular step interval.
if step_i % train_p.eval_interval_steps == 0:
logging.debug(' Starting eval_step().')
logging.debug(' Retrieving eval model_inputs.')
eval_inputs = train_input_pipeline.get_next()
logging.debug(' Retrieved eval model_inputs.')
logging.debug(' Performing eval_step() runs on training split.')
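        # Bind the replicated variables, eval seed and step so that the eval
        # function only needs the per-batch inputs.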
eval_step_fn = functools.partial(p_eval_step,
replicated_model_states.mdl_vars,
eval_prng_seed,
replicated_model_states.step)
loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
eval_inputs, eval_step_fn, reshard_inputs=True)
logging.debug(' Completed eval_step() runs on training split.')
logging.info('step=`%d`', step_i)
logging.info(' eval loss: %s', loss)
logging.info(' mean_metrics: %s', mean_metrics)
logging.info(' summary_tensors: %s', summary_tensors)
if step_i % train_p.summary_interval_steps == 0:
logging.debug(' Writing eval summaries.')
summary_utils.write_summary_entry(
eval_summary_writer,
step_i,
loss,
mean_metrics,
summary_tensors,
unreplicate_metrics=True)
logging.debug(' Wrote eval summaries.')
# Eval on the test sets.
if eval_input_p is not None:
logging.debug(' Performing eval_step() runs on test splits.')
model_utils.run_eval_loop_over_test_splits(
eval_num_steps,
eval_step_fn,
eval_test_summary_writers,
step_i,
eval_input_pipelines,
reshard_inputs=True)
logging.debug(' Completed eval_step() runs on test splits.')
logging.debug('step=`%d`: End', step_i - 1)