def train_and_evaluate_pmap()

in lingvo/jax/train.py


def train_and_evaluate_pmap(
    model_p: InstantiableParams, train_input_p: InstantiableParams,
    job_log_dir: Optional[str],
    checkpoint_manager: checkpoint_managers.CheckpointManager,
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
  """Runs the training and evaluation loop.

  Args:
    model_p: Params for the data parallel model.
    train_input_p: Params for the train data input pipeline.
    job_log_dir: Directory for the job logs.
    checkpoint_manager: A checkpoint manager controlling how often to save and
      delete checkpoints.
    restore_checkpoint_dir: If set, the directory from which to restore a
      checkpoint. If unset, the `checkpoints` subdirectory of job_log_dir is
      used instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      restore from the latest checkpoint, if any.
    eval_input_p: Optional list of params for the eval input pipelines.
  """
  logging.info('Using pmap for data parallelism.')
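  # GlobalDeviceArray outputs are tied to pjit-based SPMD partitioning; the
  # pmap code path below operates on regular per-device arrays, so bail out
  # early if the GDA output flag is enabled.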
  if jax.config.jax_parallel_functions_output_gda:
    raise NotImplementedError(
        'jax.pmap does not yet support GlobalDeviceArray.')
  jax_model = model_p.Instantiate()

  train_input_pipeline = train_input_p.Instantiate()
  if eval_input_p is not None:
    eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]

  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = _checkpoint_dir(job_log_dir)
  restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
  model_states = trainer_lib.initialize_model_state(jax_model, init_key)
  model_states = checkpoints.restore_checkpoint(
      model_states,
      restore_checkpoint_dir,
      step=restore_checkpoint_step)
  total_num_params = jax_model.total_num_vars
  replicated_model_states = trainer_lib.replicate_model_state(model_states)
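  # Replication copies the host-local model state onto every local device,
  # adding a leading per-device axis that jax.pmap maps over.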
  # The unreplicated model states are no longer needed at this point.
  del model_states

  logging.info('replicated_model_states shapes: %s',
               jax.tree_map(lambda x: x.shape, replicated_model_states))
  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)

  fprop_dtype = model_p.fprop_dtype

  def train_step(states, prng_key, inputs):
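    # Closes over jax_model and fprop_dtype so that jax.pmap maps only over
    # the array arguments. axis_name='batch' (forwarded as
    # data_parallel_axis_name) is the axis that cross-device collectives in
    # the learner step reduce over.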
    return trainer_lib.train_step_single_learner(
        jax_model,
        states,
        prng_key,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=fprop_dtype)

  def eval_step(mdl_vars, prng_key, global_step, inputs):
    return trainer_lib.eval_step_single_learner(
        jax_model,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=fprop_dtype)

  num_devices = jax.local_device_count()
  prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
  train_prng_seed = jax.random.split(train_key, num=num_devices)
  eval_prng_seed = jax.random.split(eval_key, num=num_devices)
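  # Each split yields a [num_devices, 2] array of PRNG keys; pmap maps over
  # the leading axis so every local device gets its own independent key.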
  logging.info('train prng_seed: %s', train_prng_seed)
  logging.info('eval prng_seed: %s', eval_prng_seed)

  p_train_step = jax.pmap(train_step, donate_argnums=(0,), axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
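  # donate_argnums=(0,) lets XLA reuse the buffers of the incoming model
  # states for the updated states returned by train_step. eval_step does not
  # donate, since the model variables are reused across eval batches.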

  train_p = model_p.train

  logging.info('Training loop starting...')
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  summary_train_dir = os.path.join(summary_base_dir, 'train')
  summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
  summary_writer = summary_utils.get_summary_writer
  if eval_input_p is not None:
    summary_test_split_dirs = [
        os.path.join(summary_base_dir, f'eval_test_{split}')
        for split, _ in enumerate(eval_input_p)
    ]
    # Each eval loop during training runs either p.eval_loop_num_batches steps
    # or one epoch (when supported by a resettable input). When
    # p.reset_for_eval is set to True, the eval loop runs until
    # tf.errors.OutOfRangeError (or StopIteration) is raised, which is
    # triggered either when the input pipeline reaches the end of the input
    # sequence or when a pre-determined number of batches has been reached.
    # The -1 sentinel below requests this run-to-exhaustion behavior.
    eval_num_steps = [
        -1 if p.reset_for_eval else p.eval_loop_num_batches
        for p in eval_input_p
    ]
  else:
    summary_test_split_dirs = []

  with contextlib.ExitStack() as exit_stack:
    train_summary_writer = exit_stack.enter_context(
        summary_writer(summary_train_dir))
    eval_summary_writer = exit_stack.enter_context(
        summary_writer(summary_eval_dir))
    eval_test_summary_writers = [
        exit_stack.enter_context(summary_writer(d))
        for d in summary_test_split_dirs
    ]

    summary_utils.write_model_structure(
        train_summary_writer, replicated_model_states, is_vars_replicated=True)
    summary_utils.write_total_num_params(train_summary_writer, total_num_params)

    summary_last_time = time.time()
    summary_last_step = None

    step_i = int(jax.device_get(replicated_model_states.step)[0])
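    # (the replicated step has a leading device axis; index [0] reads the
    # scalar from the first replica)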
    while True:
      logging.debug('step=`%d`: Beginning', step_i)
      if step_i >= train_p.num_train_steps:
        logging.info(
            'Training loop completed (step (`%d`) greater than or equal to '
            'num_train_steps (`%d`)).', step_i, train_p.num_train_steps)
        break
      if summary_last_step is None:
        summary_last_step = step_i - 1

      if checkpoint_manager.should_save(step_i):
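        # Only the lead process writes the checkpoint, to avoid concurrent
        # writers; checkpoints.save_checkpoint is expected to handle the
        # replicated (per-device) leading axis of the states.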
        if jax.process_index() == 0:
          checkpoints.save_checkpoint(replicated_model_states, checkpoint_dir)
        checkpoint_manager.save_metadata(global_step_id=step_i)

      if step_i <= 5:
        logging.info('step=`%d`: Retrieving model inputs.', step_i)
      logging.debug('  Retrieving inputs.')
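      # py_utils.reshard presumably reshapes each host-local batch to
      # [local_device_count, batch_size // local_device_count, ...] so pmap
      # can split it across the local devices.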
      model_inputs = tf.nest.map_structure(py_utils.reshard,
                                           train_input_pipeline.get_next())
      logging.debug('  Retrieved inputs.')
      logging.debug('  Performing train_step().')
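      # StepTraceAnnotation tags this step in JAX profiler traces, e.g. in
      # TensorBoard's profiler view.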
      with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
        (replicated_model_states, loss, metrics, per_example_out,
         summary_tensors) = p_train_step(replicated_model_states,
                                         train_prng_seed, model_inputs)
      logging.debug('  Completed train_step().')

      logging.debug('  Attempting to write summaries.')
      if summary_utils.write_summary_every_n_steps(
          replicated_model_states,
          train_summary_writer,
          step_i,
          train_p.summary_interval_steps,
          loss,
          metrics,
          per_example_out,
          summary_tensors,
          train_p.norm_summary_interval_steps,
          summary_last_time,
          summary_last_step,
          unreplicate_mdl_vars=True,
          unreplicate_metrics=True):
        summary_last_time = time.time()
        summary_last_step = step_i
        # Synchronize step_i
        step_i = int(jax.device_get(replicated_model_states.step)[0])
      else:
        # Increment locally to avoid an explicit sync.
        step_i += 1
      logging.debug('  Summary write attempt completed.')

      # Run eval at regular step interval.
      if step_i % train_p.eval_interval_steps == 0:
        logging.debug('  Starting eval_step().')
        logging.debug('  Retrieving eval model_inputs.')
        eval_inputs = train_input_pipeline.get_next()
        logging.debug('  Retrieved eval model_inputs.')
        logging.debug('  Performing eval_step() runs on training split.')
        eval_step_fn = functools.partial(p_eval_step,
                                         replicated_model_states.mdl_vars,
                                         eval_prng_seed,
                                         replicated_model_states.step)
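        # Binding the (replicated) vars, seeds, and step here leaves the
        # per-batch inputs as the only argument run_eval_one_step supplies.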
        loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
            eval_inputs, eval_step_fn, reshard_inputs=True)
        logging.debug('  Completed eval_step() runs on training split.')
        logging.info('step=`%d`', step_i)
        logging.info('  eval loss: %s', loss)
        logging.info('  mean_metrics: %s', mean_metrics)
        logging.info('  summary_tensors: %s', summary_tensors)
        if step_i % train_p.summary_interval_steps == 0:
          logging.debug('  Writing eval summaries.')
          summary_utils.write_summary_entry(
              eval_summary_writer,
              step_i,
              loss,
              mean_metrics,
              summary_tensors,
              unreplicate_metrics=True)
          logging.debug('  Wrote eval summaries.')
        # Eval on the test sets.
        if eval_input_p is not None:
          logging.debug('  Performing eval_step() runs on test splits.')
          model_utils.run_eval_loop_over_test_splits(
              eval_num_steps,
              eval_step_fn,
              eval_test_summary_writers,
              step_i,
              eval_input_pipelines,
              reshard_inputs=True)
          logging.debug('  Completed eval_step() runs on test splits.')
      logging.debug('step=`%d`: End', step_i - 1)
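
# A minimal invocation sketch (hypothetical values; in lingvo/jax the params
# and checkpoint manager are normally built from a registered experiment
# configuration rather than by hand):
#
#   train_and_evaluate_pmap(
#       model_p=model_p,                  # data-parallel model params
#       train_input_p=train_input_p,      # train input pipeline params
#       job_log_dir='/tmp/my_job',        # hypothetical log directory
#       checkpoint_manager=ckpt_manager,  # a CheckpointManager instance
#       restore_checkpoint_dir=None,      # default: job_log_dir/checkpoints
#       restore_checkpoint_step=None,     # restore the latest, if any
#       eval_input_p=None)                # skip test-split eval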