# gce/survival-training/wrapper/train.py
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import glob
import json
import os
import random


def generate_trainer(hyperparameters):
    """Generates a callable that performs a single step of training per call.

    Args:
        hyperparameters: Hyperparameters to train with.

    Returns:
        A trainer callable, which performs a single step of training each time
        it is called and returns a JSON-serializable representation of its
        state at the end of the step.
    """
    def _trainer():
        """Dummy trainer callable.

        Returns:
            A random number between 0 and 1, standing in for model state.
        """
        return random.random()
return _trainer
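
# A hypothetical, more realistic trainer_initializer would close over mutable
# model state instead of returning random numbers; the name and hyperparameter
# usage below are illustrative, not part of this module:
#
#   def generate_sgd_trainer(hyperparameters):
#       state = {"weights": 0.0}
#       def _trainer():
#           # One mock gradient step scaled by a learning-rate hyperparameter.
#           state["weights"] -= 0.1 * hyperparameters["hyperparameter_2"]
#           return state
#       return _trainer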


def runner(
    trainer_initializer,
    job_dir,
    train_steps,
    checkpoint_steps,
    hyperparameters
):
    """Runs a training job.

    Args:
        trainer_initializer: Function which accepts a hyperparameter dictionary
            as its only argument and returns a callable representing a single
            step of training.
        job_dir: Directory in which checkpoints should be stored.
        train_steps: Total number of steps for which training should be
            performed.
        checkpoint_steps: Number of training steps between checkpoints.
        hyperparameters: Dictionary containing the hyperparameter specification
            for the training job.

    Returns:
        None

    Raises:
        ValueError: If hyperparameters are inconsistent with existing
            checkpoints in job_dir.
    """
    current_checkpoint_index = 0
    current_hyperparameters = copy.copy(hyperparameters)

    # If checkpoints already exist in job_dir, resume from the most recent one
    # and check that any hyperparameters passed on the command line agree with
    # the ones recorded in that checkpoint.
    last_path, last_index = latest_checkpoint(get_checkpoints(job_dir))
    if last_index is not None:
        current_checkpoint_index = last_index + 1
        last_data = load_checkpoint(last_path)
        last_hp = last_data.get("hyperparameters")
        for hyperparameter in current_hyperparameters:
            if (current_hyperparameters[hyperparameter] is not None and
                    current_hyperparameters[hyperparameter] !=
                    last_hp[hyperparameter]):
                raise ValueError(
                    "Inconsistent values for {}: ".format(hyperparameter) +
                    "command line -- {}, checkpoint -- {}".format(
                        current_hyperparameters[hyperparameter],
                        last_hp[hyperparameter]
                    )
                )
        current_hyperparameters = last_hp

    train_step = trainer_initializer(current_hyperparameters)

def finished(step):
"""Returns True if job is complete and False otherwise."""
if train_steps is None:
return False
else:
return step > train_steps

    result = None
    # TODO(nkashy1): Add test for "up to N steps" rather than "additional N steps"
    current_step = current_checkpoint_index * checkpoint_steps + 1
    while not finished(current_step):
        result = train_step()
        if current_step % checkpoint_steps == 0:
            checkpoint_data = generate_checkpoint(
                current_checkpoint_index,
                current_hyperparameters,
                result
            )
            save_checkpoint(job_dir, current_checkpoint_index, checkpoint_data)
            current_checkpoint_index += 1
        current_step += 1

    # Write a final checkpoint when a finite training run completes so that
    # the terminal state of the job is also recoverable.
    checkpoint_data = generate_checkpoint(
        current_checkpoint_index,
        current_hyperparameters,
        result
    )
    save_checkpoint(job_dir, current_checkpoint_index, checkpoint_data)
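
# Illustrative call (paths and values are hypothetical): resuming a job means
# invoking runner again with the same job_dir, e.g.
#
#   runner(generate_trainer, "/tmp/dummy-job", 100, 10,
#          {"hyperparameter_1": 3, "hyperparameter_2": 0.5})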


def generate_checkpoint(step, hyperparameters, model_state):
    """Generates checkpoint contents.

    Args:
        step: Checkpoint index at which this checkpoint was generated; the
            runner passes its running checkpoint index here rather than the
            raw training step.
        hyperparameters: Dictionary specifying the model hyperparameters.
        model_state: A JSON-serializable representation of the model state.

    Returns:
        Dictionary representing the content to be checkpointed.
    """
checkpoint_data = {
"steps": step,
"hyperparameters": hyperparameters,
"model": model_state
}
return checkpoint_data
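
# For example, generate_checkpoint(3, {"hyperparameter_1": 7}, 0.42) returns
# {"steps": 3, "hyperparameters": {"hyperparameter_1": 7}, "model": 0.42}.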


def get_checkpoints(job_dir):
    """Gets all the checkpoints in a given directory.

    Args:
        job_dir: Directory containing checkpoints.

    Returns:
        List of paths to checkpoint files in the given directory.
    """
checkpoint_glob = os.path.join(job_dir, "dummy-checkpoint-*.json")
checkpoints = glob.glob(checkpoint_glob)
return checkpoints
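
# For example, with job_dir "/tmp/dummy-job" this matches paths like
# /tmp/dummy-job/dummy-checkpoint-0.json.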


def latest_checkpoint(checkpoint_paths):
    """Finds the most recently stored checkpoint in a list of checkpoints.

    Args:
        checkpoint_paths: List of paths to checkpoint files.

    Returns:
        Tuple of (path, index) for the checkpoint with the highest index in
        the provided list, or (None, None) if the list is empty.
    """
    if not checkpoint_paths:
        return (None, None)
checkpoint_indices = map(checkpoint_index, checkpoint_paths)
indexed_checkpoints = zip(checkpoint_paths, checkpoint_indices)
sorted_indexed_checkpoints = sorted(indexed_checkpoints, key=lambda p: p[1])
return sorted_indexed_checkpoints[-1]
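
# For example, latest_checkpoint(["a/dummy-checkpoint-0.json",
# "a/dummy-checkpoint-2.json"]) returns ("a/dummy-checkpoint-2.json", 2).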


def checkpoint_index(checkpoint_path):
    """Returns the index of the checkpoint stored at a given path.

    Args:
        checkpoint_path: Path to a checkpoint file.

    Returns:
        Integer specifying which checkpoint the path represents. For example,
        dummy-checkpoint-173.json has checkpoint index 173, so this function
        would return the integer 173.
    """
checkpoint_file = os.path.basename(checkpoint_path)
prefix = "dummy-checkpoint-"
suffix = ".json"
return int(checkpoint_file[len(prefix):-len(suffix)])


def load_checkpoint(checkpoint_path):
    """Loads the checkpoint object stored at a given path.

    Args:
        checkpoint_path: Path at which the checkpoint is stored.

    Returns:
        Python dictionary representing the data serialized in the checkpoint.
    """
with open(checkpoint_path, "r") as fp:
checkpoint_data = json.load(fp)
return checkpoint_data


def save_checkpoint(job_dir, index, checkpoint_data):
    """Serializes checkpoint data and stores it in a given directory.

    Args:
        job_dir: Directory in which to store checkpoint data.
        index: Ordinal index of the checkpoint.
        checkpoint_data: Data to be stored in the checkpoint. (Note: currently
            assumed to be JSON serializable.)

    Returns:
        The path to the saved checkpoint file.
    """
checkpoint_file = "dummy-checkpoint-{}.json".format(index)
checkpoint_path = os.path.join(job_dir, checkpoint_file)
with open(checkpoint_path, "w") as fp:
json.dump(checkpoint_data, fp)
return checkpoint_path
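
# Round trip: save_checkpoint("/tmp/dummy-job", 5, data) writes
# /tmp/dummy-job/dummy-checkpoint-5.json, and load_checkpoint on that path
# returns an equal dictionary (the contents are assumed JSON-serializable).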


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Dummy trainer")
    parser.add_argument(
        "--job-dir",
        required=True,
        help="Directory where checkpoints and checkpoint metadata will be written"
    )
    parser.add_argument(
        "--checkpoint-steps",
        type=int,
        required=True,
        help="Number of steps per checkpointing operation"
    )
parser.add_argument(
"--train-steps",
type=int,
default=None,
help=("Total number of steps that you would like to train for -- "
"trains forever if this argument is not specified")
)
parser.add_argument(
"--hyperparameter-1",
type=int,
required=False,
help="Generic integer hyperparameter for dummy model"
)
parser.add_argument(
"--hyperparameter-2",
type=float,
required=False,
help="Generic floating point hyperparameter for dummy model"
)
args = parser.parse_args()
hparams = {
"hyperparameter_1": args.hyperparameter_1,
"hyperparameter_2": args.hyperparameter_2
}
runner(
generate_trainer,
args.job_dir,
args.train_steps,
args.checkpoint_steps,
hparams
)
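
# Example invocation (paths and values are illustrative):
#
#   python train.py --job-dir /tmp/dummy-job --checkpoint-steps 10 \
#       --train-steps 100 --hyperparameter-1 3 --hyperparameter-2 0.5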