in gce/survival-training/wrapper/train.py [0:0]
def _reconcile_hyperparameters(hyperparameters, checkpoint_hyperparameters):
    """Merges requested hyperparameters with those stored in a checkpoint.

    Args:
      hyperparameters: Dictionary of hyperparameters requested for this run. A
        value of None means the hyperparameter was left unspecified.
      checkpoint_hyperparameters: Dictionary of hyperparameters recorded in the
        checkpoint being resumed from.

    Returns:
      A new dictionary containing every requested key, with values taken from
      the checkpoint wherever the checkpoint defines them. The inputs are not
      modified.

    Raises:
      ValueError: If a hyperparameter that was explicitly requested (not None)
        differs from, or is absent in, the checkpoint.
    """
    for name, requested in hyperparameters.items():
        # .get() so that a key missing from the checkpoint surfaces as the
        # documented ValueError rather than an unrelated KeyError.
        stored = checkpoint_hyperparameters.get(name)
        if requested is not None and requested != stored:
            raise ValueError(
                "Inconsistent values for {}: ".format(name) +
                "command line -- {}, checkpoint -- {}".format(requested, stored)
            )

    # Every explicitly requested value has been verified equal to the stored
    # one, so letting the checkpoint win only fills in unspecified entries.
    reconciled = copy.copy(hyperparameters)
    reconciled.update(checkpoint_hyperparameters)
    return reconciled


def runner(
        trainer_initializer,
        job_dir,
        train_steps,
        checkpoint_steps,
        hyperparameters
):
    """Runs a training job, resuming from the latest checkpoint if one exists.

    Args:
      trainer_initializer: Function which accepts hyperparameter dictionary as
        its only argument and returns a callable representing a single step of
        training.
      job_dir: Directory in which checkpoints should be stored.
      train_steps: Total number of steps for which training should be
        performed, counting steps already covered by existing checkpoints. If
        None, training never terminates on its own.
      checkpoint_steps: Training steps between checkpoints.
      hyperparameters: Dictionary containing hyperparameter specification for
        the training job. Entries set to None are filled in from the latest
        checkpoint when resuming.

    Returns:
      None

    Raises:
      ValueError: If hyperparameters are inconsistent with existing checkpoints
        in job_dir.
    """
    current_checkpoint_index = 0
    current_hyperparameters = copy.copy(hyperparameters)

    last_path, last_index = latest_checkpoint(get_checkpoints(job_dir))
    if last_index is not None:
        current_checkpoint_index = last_index + 1
        last_data = load_checkpoint(last_path)
        # A checkpoint without a "hyperparameters" entry is treated as having
        # recorded none, instead of failing on a None subscript.
        current_hyperparameters = _reconcile_hyperparameters(
            hyperparameters,
            last_data.get("hyperparameters") or {}
        )

    # Use the reconciled hyperparameters so that values restored from the
    # checkpoint actually reach the trainer and are carried into new
    # checkpoints.
    train_step = trainer_initializer(current_hyperparameters)

    def finished(step):
        """Returns True if job is complete and False otherwise."""
        if train_steps is None:
            return False
        return step > train_steps

    result = None

    # TODO(nkashy1): Add test for "up to N steps" rather than "additional N steps"
    # Steps are 1-based; when resuming, each existing checkpoint is assumed to
    # account for checkpoint_steps completed steps.
    current_step = current_checkpoint_index*checkpoint_steps + 1
    while not finished(current_step):
        result = train_step()

        if current_step % checkpoint_steps == 0:
            checkpoint_data = generate_checkpoint(
                current_checkpoint_index,
                current_hyperparameters,
                result
            )
            save_checkpoint(job_dir, current_checkpoint_index, checkpoint_data)
            current_checkpoint_index += 1

        current_step += 1

    # Always write a final checkpoint holding the last result once the loop
    # ends (result is None if no step ran during this invocation).
    checkpoint_data = generate_checkpoint(
        current_checkpoint_index,
        current_hyperparameters,
        result
    )
    save_checkpoint(job_dir, current_checkpoint_index, checkpoint_data)