def train_and_evaluate_spmd_model()

in lingvo/jax/train.py


def train_and_evaluate_spmd_model(
    model_p: InstantiableParams, train_input_p: InstantiableParams,
    job_log_dir: Optional[str],
    checkpoint_manager: checkpoint_managers.CheckpointManager,
    checkpoint_type: CheckpointType, restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
  """Runs the training and evaluation loop.

  Args:
    model_p: Params for the SPMD model.
    train_input_p: Params for the train data pipeline.
    job_log_dir: Directory for the job logs.
    checkpoint_manager: A checkpoint manager controlling how often to save and
      delete checkpoints.
    checkpoint_type: The type of checkpoint to use.
    restore_checkpoint_dir: If set, the directory from which to restore the
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint, if any.
    eval_input_p: Optional list of params for the eval input pipelines.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  train_input_pipeline = train_input_p.Instantiate()

  if eval_input_p is not None:
    eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
    # Do not mutate eval_input_pipelines itself; instantiate a fresh pipeline
    # just to fetch a sample input for shape inference.
    sample_eval_model_inputs = eval_input_p[0].Instantiate().get_next()
    eval_test_inputs_shape = tf.nest.map_structure(
        py_utils.get_global_input_shape_dtype, sample_eval_model_inputs)
    eval_test_inputs_pspecs = trainer_lib.get_input_partition_specs(
        model_p.mesh_axis_names, eval_test_inputs_shape)

  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)
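  # jax.random.split derives independent keys: init_key seeds variable
  # initialization, while prng_key is split again below into train/eval keys.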

  checkpoint_dir = _checkpoint_dir(job_log_dir)
  restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
  # Note that GDA checkpointing requires all processes to participate, but it
  # does not require a separate checkpoint_dir per process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
    restore_checkpoint_task_dir = os.path.join(restore_checkpoint_dir,
                                               f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir
    restore_checkpoint_task_dir = restore_checkpoint_dir

  multi_host_checkpointing = checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  }
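  # Both multi-host Flax and GDA checkpoints are written collectively by all
  # hosts; any other checkpoint type is written by process 0 alone (see the
  # save logic in the train loop below).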

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(checkpoint_dir)
  if multi_host_checkpointing:
    # Block all hosts until directory is ready.
    py_utils.sync_global_devices(f'checkpointer:makedirs:{checkpoint_dir}')

  logging.info('Retrieving model inputs for shape info.')
  model_inputs_for_shape = train_input_pipeline.get_next()
  inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                       model_inputs_for_shape)
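  # inputs_shape mirrors the nested input structure but holds only
  # (shape, dtype) descriptors, which is all the partitioner below needs.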

  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  # TODO(zhangqiaorjc): maps.mesh should yield Mesh.
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
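  # Illustrative example (not from the source): with mesh_shape [2, 4] on 8
  # devices, device_mesh is a 2x4 ndarray of Device objects, and the two
  # mesh_axis_names (e.g. ('replica', 'mdl')) label its axes for the
  # PartitionSpecs used below.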
  with maps.mesh(device_mesh, model_p.mesh_axis_names):
    (partitioned_train_state, train_state_pspecs, inputs_pspecs, train_step,
     eval_step, total_num_params) = trainer_lib.partition_spmd_model(
         model_p, init_key, inputs_shape)

    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        restore_checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=train_state_pspecs,
        step=restore_checkpoint_step)
    logging.info(
        'partitioned_train_state shapes '
        '(global shape for GDA, host-local shape for non-GDA): %s',
        jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    # Unlike the pmap version, we do not fold in jax.process_index; we use a
    # single global key and rely on pjit to split it across replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
    logging.info('train prng_key: %s', train_key)
    logging.info('eval prng_key: %s', eval_key)

    train_p = model_p.train

    logging.info('Training loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_train_dir = os.path.join(summary_base_dir, 'train')
    summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
    summary_writer = summary_utils.get_summary_writer
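    # get_summary_writer is a context-manager factory; all writers are entered
    # below through a single ExitStack so that they are closed together.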
    if eval_input_p is not None:
      summary_eval_test_dirs = [
          os.path.join(summary_base_dir, f'eval_test_{split}')
          for split, _ in enumerate(eval_input_p)
      ]
      # We either run p.eval_loop_num_batches steps or one epoch (when
      # supported by a resettable input) per eval loop during training. When
      # p.reset_for_eval is set to True, we run the eval loop until
      # tf.errors.OutOfRangeError (or StopIteration) is raised, which happens
      # either when the input pipeline reaches the end of the input sequence
      # or when a pre-determined num_batches has been reached.
      eval_num_steps = [
          -1 if p.reset_for_eval else p.eval_loop_num_batches
          for p in eval_input_p
      ]
    else:
      summary_eval_test_dirs = []

    with contextlib.ExitStack() as exit_stack:
      train_summary_writer = exit_stack.enter_context(
          summary_writer(summary_train_dir))
      eval_summary_writer = exit_stack.enter_context(
          summary_writer(summary_eval_dir))
      eval_test_summary_writers = [
          exit_stack.enter_context(summary_writer(d))
          for d in summary_eval_test_dirs
      ]

      # This only prints the view from the first host machine.
      summary_utils.write_model_structure(
          train_summary_writer,
          partitioned_train_state,
          is_vars_replicated=False)
      summary_utils.write_total_num_params(train_summary_writer,
                                           total_num_params)

      summary_last_time = time.time()
      summary_last_step = None

      step_i = int(
          jax.device_get(
              py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))
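      # The step counter lives in the train state and may be a GDA, hence the
      # unreplicate + device_get dance before converting to a Python int.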

      # Start the train loop. Make sure all hosts are at the same step.
      py_utils.sync_global_devices(f'Start training loop from step: {step_i}')
      while True:
        logging.debug('step=`%d`: Beginning', step_i)
        if step_i >= train_p.num_train_steps:
          logging.info(
              'Training loop completed (step (`%d`) reached num_train_steps '
              '(`%d`)).', step_i, train_p.num_train_steps)
          break

        if summary_last_step is None:
          summary_last_step = step_i - 1

        if checkpoint_manager.should_save(step_i):
          logging.info('Saving a ckpt at step: %d', step_i)
          if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:saving:{checkpoint_dir}:step-{step_i}')
          if multi_host_checkpointing or jax.process_index() == 0:
            checkpoints.save_checkpoint(
                partitioned_train_state,
                checkpoint_task_dir,
                checkpoint_type=checkpoint_type,
                state_specs=train_state_pspecs,
                unreplicate=False)
          checkpoint_manager.save_metadata(global_step_id=step_i)
          if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:saved:{checkpoint_dir}:step-{step_i}')

        # Get new model inputs
        if step_i <= 5:
          logging.info('step=`%d`: Retrieving model inputs.', step_i)
        logging.debug('  Retrieving inputs.')
        model_inputs = train_input_pipeline.get_next()

        if jax.config.jax_parallel_functions_output_gda:
          start = time.time()
          py_utils.assert_same_shape_and_dtype(
              inputs_shape,
              tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                    model_inputs))
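          # create_gda wraps the host-local batch into GlobalDeviceArrays
          # sharded over global_mesh per inputs_pspecs, so pjit sees one
          # logical global batch rather than per-host arrays.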
          model_inputs = py_utils.create_gda(model_inputs, inputs_shape,
                                             global_mesh, inputs_pspecs)
          logging.info('GDA train batch input creation time %s',
                       time.time() - start)

        logging.debug('  Retrieved inputs.')

        logging.debug('  Performing train_step().')
        with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
          (partitioned_train_state, loss, metrics, per_example_out,
           summary_tensors) = train_step(partitioned_train_state, train_key,
                                         model_inputs)
        logging.debug('  Completed train_step().')

        logging.debug('  Writing summaries (attempt).')
        if summary_utils.write_summary_every_n_steps(
            partitioned_train_state,
            train_summary_writer,
            step_i,
            train_p.summary_interval_steps,
            loss,
            metrics,
            per_example_out,
            summary_tensors,
            train_p.norm_summary_interval_steps,
            summary_last_time,
            summary_last_step,
            unreplicate_mdl_vars=False,
            unreplicate_metrics=False):
          summary_last_time = time.time()
          summary_last_step = step_i
          step_i = int(
              py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
        else:
          # Increment train step locally to avoid an explicit device sync.
          step_i += 1
        logging.debug('  Wrote summaries (attempted).')

        # Run eval at regular step interval.
        if step_i % train_p.eval_interval_steps == 0:
          logging.debug('  Starting eval_step().')
          logging.debug('  Retrieving eval model_inputs.')
          eval_inputs = train_input_pipeline.get_next()

          if jax.config.jax_parallel_functions_output_gda:
            eval_inputs = py_utils.create_gda(eval_inputs, inputs_shape,
                                              global_mesh, inputs_pspecs)

          logging.debug('  Retrieved eval model_inputs.')
          logging.debug('  Performing eval_step() runs on training split.')

          eval_step_fn = functools.partial(eval_step,
                                           partitioned_train_state.mdl_vars,
                                           eval_key,
                                           partitioned_train_state.step)
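          # Bind the current variables, eval key and step once; the resulting
          # closure takes only the per-batch inputs inside the eval loops.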
          loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
              eval_inputs, eval_step_fn, reshard_inputs=False)
          logging.debug('  Completed eval_step() runs on training split.')

          logging.info('step=`%d`', step_i)
          logging.info('  eval loss: %s', loss)
          logging.info('  mean_metrics: %s', mean_metrics)
          logging.info('  summary_tensors: %s', summary_tensors)
          if step_i % train_p.summary_interval_steps == 0:
            logging.debug('  Writing eval summaries.')
            summary_utils.write_summary_entry(
                eval_summary_writer,
                step_i,
                loss,
                mean_metrics,
                summary_tensors,
                unreplicate_metrics=False)
            logging.debug('  Wrote eval summaries.')
          # If eval test splits are configured, also evaluate on them.
          if eval_input_p is not None:
            logging.debug('  Performing eval_step() runs on test splits.')
            model_utils.run_eval_loop_over_test_splits(
                eval_num_steps,
                eval_step_fn,
                eval_test_summary_writers,
                step_i,
                eval_input_pipelines,
                eval_test_inputs_pspecs,
                eval_test_inputs_shape,
                global_mesh,
                reshard_inputs=False)
            logging.debug('  Completed eval_step() runs on test splits.')

        logging.debug('step=`%d`: End', step_i - 1)
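
A minimal, hypothetical invocation sketch. All `my_*` names below are
placeholders rather than lingvo definitions, and CHECKPOINT_GDA is just one of
the checkpoint types handled above; consult this function's callers in
lingvo/jax for the real wiring.

  train_and_evaluate_spmd_model(
      model_p=my_model_p,                  # params for an SPMD model
      train_input_p=my_train_input_p,      # params for the train pipeline
      job_log_dir='/tmp/my_job',
      checkpoint_manager=my_checkpoint_manager,
      checkpoint_type=CheckpointType.CHECKPOINT_GDA,
      restore_checkpoint_dir=None,         # default: job_log_dir/checkpoints
      restore_checkpoint_step=None,        # default: latest checkpoint, if any
      eval_input_p=[my_eval_input_p])      # or None to skip test-split evals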