def runner()

in gce/survival-training/wrapper/train.py [0:0]


def runner(
    trainer_initializer,
    job_dir,
    train_steps,
    checkpoint_steps,
    hyperparameters
    ):
  """Runs a training job, continuing from the latest checkpoint if one exists.

  When job_dir already contains checkpoints, every hyperparameter the caller
  set to a non-None value must equal the value stored in the latest checkpoint.
  Hyperparameters the caller left as None are filled in from that checkpoint.
  The resulting dictionary is passed to trainer_initializer and written into
  all checkpoints produced by this call. When there is no checkpoint, the
  caller's dictionary is used unchanged.

  Args:
    trainer_initializer: Function which accepts hyperparameter dictionary as its
    only argument and returns a callable representing a single step of training.
    job_dir: Directory in which checkpoints should be stored.
    train_steps: Total number of steps for which training should be performed.
    If None, training never stops on its own.
    checkpoint_steps: Training steps between checkpoints.
    hyperparameters: Dictionary containing hyperparameter specification for the
    training job. A value of None means "not specified".

  Returns:
    None

  Raises:
    ValueError: If hyperparameters are inconsistent with existing checkpoints in
    job_dir. A hyperparameter that is specified by the caller but absent from
    the checkpoint counts as inconsistent.
  """
  current_checkpoint_index = 0
  # Without an existing checkpoint the caller's dictionary is used as-is.
  current_hyperparameters = hyperparameters

  last_path, last_index = latest_checkpoint(get_checkpoints(job_dir))
  if last_index is not None:
    current_checkpoint_index = last_index + 1
    last_data = load_checkpoint(last_path)
    # A checkpoint without a "hyperparameters" entry is treated as storing none.
    last_hp = last_data.get("hyperparameters") or {}

    # Start from the checkpoint's values; the loop below validates the caller's
    # values against them and keeps track of unspecified names.
    current_hyperparameters = copy.copy(last_hp)
    for hyperparameter in hyperparameters:
      requested = hyperparameters[hyperparameter]
      if requested is None:
        # Unspecified by the caller: keep the checkpoint's value if it has one,
        # otherwise keep the key present with a None value.
        current_hyperparameters.setdefault(hyperparameter, None)
        continue

      stored = last_hp.get(hyperparameter)
      if requested != stored:
        raise ValueError(
            "Inconsistent values for {}: ".format(hyperparameter) +
            "command line -- {}, checkpoint -- {}".format(requested, stored)
        )

  train_step = trainer_initializer(current_hyperparameters)

  def finished(step):
    """Returns True if job is complete and False otherwise."""
    return train_steps is not None and step > train_steps

  result = None
  # TODO(nkashy1): Add test for "up to N steps" rather than "additional N steps"
  # Steps are 1-based; the first step after checkpoint k is k*checkpoint_steps+1.
  current_step = current_checkpoint_index*checkpoint_steps + 1
  while not finished(current_step):
    result = train_step()

    if current_step%checkpoint_steps == 0:
      checkpoint_data = generate_checkpoint(
          current_checkpoint_index,
          current_hyperparameters,
          result
      )
      save_checkpoint(job_dir, current_checkpoint_index, checkpoint_data)
      current_checkpoint_index += 1

    current_step += 1

  # NOTE: this final checkpoint is written unconditionally once the loop ends,
  # including when the last step already saved one and when no step was run
  # (in which case result is still None).
  checkpoint_data = generate_checkpoint(
      current_checkpoint_index,
      current_hyperparameters,
      result
  )
  save_checkpoint(job_dir, current_checkpoint_index, checkpoint_data)