in mesh_tensorflow/transformer/utils.py [0:0]
def run(tpu_job_name,
tpu,
gcp_project,
tpu_zone,
model_dir,
model_type="bitransformer",
vocabulary=None,
train_dataset_fn=None,
eval_dataset_fn=None,
dataset_split="train",
autostack=True,
eval_checkpoint_step=None,
export_checkpoint_step=None,
export_path="",
mode="train",
iterations_per_loop=100,
save_checkpoints_steps=5000,
keep_checkpoint_max=None,
eval_summary_dir=None,
batch_size=("tokens_per_replica", 2048),
train_steps=auto_train_steps,
total_run_steps=None,
sequence_length=None,
mesh_shape=gin.REQUIRED,
mesh_devices=None,
layout_rules=gin.REQUIRED,
learning_rate_schedule=None,
optimizer=None,
predict_fn=None,
variable_filter=None,
perplexity_eval_steps=100,
init_checkpoint=None,
ensemble_inputs=None,
train_model_fn=train_model,
skip_seen_data=False,
seen_data_init_step=0,
output_eval_examples=True,
checkpoint_input_pipeline=False,
eval_dir_suffix=None):
"""Run training, eval, or inference depending on `mode`.
Args:
tpu_job_name: string, name of TPU worker binary
tpu: string, the Cloud TPU to use for training
gcp_project: string, project name for the Cloud TPU-enabled project
tpu_zone: string, GCE zone where the Cloud TPU is located
model_dir: string, estimator model_dir
model_type: a string, see `get_estimator` docstring for details.
vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
targets_vocabulary) tuple.
train_dataset_fn: A function returning a tf.data.Dataset, see `train_model`
docstring for details.
eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
See `eval_model` docstring for details.
dataset_split: a string, the dataset split to use (must be "train" when
mode="train")
autostack: boolean, see `get_estimator` docstring for details.
eval_checkpoint_step: int, list of ints, or None, see `eval_model`
docstring for details.
export_checkpoint_step: int or None, see `export_model` docstring for
details.
export_path: a string, path to export the saved model
mode: string, one of
train - train the model
eval - eval the model by decoding predictions
score_eval - eval the model by computing log likelihood scores of targets
perplexity_eval - eval the model by computing perplexity
infer - decode predictions based on inputs
score_from_dataset - compute scores of targets from a dataset
score_from_strings - compute scores of targets from strings or a file
export_score - export a model that scores provided examples
export_infer - export a model that decodes predictions based on inputs
iterations_per_loop: integer, steps per train loop
save_checkpoints_steps: integer, see `get_estimator` docstring.
keep_checkpoint_max: an integer, see `get_estimator` docstring.
eval_summary_dir: str, see `eval_model` docstring for details.
batch_size: An integer or a (method, value) pair to pass to
compute_batch_size(). Note that this is the global batch size and not the
per-shard batch size.
train_steps: An integer or a function with the same signature as
auto_train_steps(). Total number of training steps in this run.
total_run_steps: An integer, used when training is split over multiple
runs. This value is gin-configurable and used to set the total_run_steps
for the learning_rate_schedule.
sequence_length: an integer or a dict from feature-key to integer, the
(packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
May also be set to `None` in "eval" or "score_eval" mode to automatically
compute the maximum length of the examples.
mesh_shape: an input to mtf.convert_to_shape()
mesh_devices: a list of strings, see `get_estimator` docstring.
layout_rules: an input to mtf.convert_to_layout_rules()
learning_rate_schedule: a function which takes the scalar named argument
`step` and the numeric argument `total_train_steps` and returns the
scalar learning rate. Alternatively, a float. Alternatively, a list of
such factors to be multiplied together.
optimizer: a class extending optimize.Optimizer, required for training
predict_fn: an optional function, see `get_estimator` docstring for details.
variable_filter: a string, see `get_estimator` docstring for details.
perplexity_eval_steps: an integer - number of steps for perplexity eval
init_checkpoint: a string, see `get_estimator` docstring for details.
ensemble_inputs: an integer, see `train_model` docstring for details.
train_model_fn: an optional train function, is `train_model` by default.
skip_seen_data: a boolean, is `False` by default. Used when a training run
restarts to skip already seen data. Skipping is only consistent when
every setting on the model (such as batch size and random seed) is the
same between the original run and the new run. May require a significant
amount of time to skip a large number of steps.
seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
steps from this starting point. Useful when finetuning.
output_eval_examples: a boolean, is `True` by default. Used to decide
whether to dump inputs, targets, and predictions of the eval examples in
plaintext to eval_summary_dir.
checkpoint_input_pipeline: a boolean, whether to checkpoint the input
pipeline in order to restart from the previous run. May require a large
amount of disk space for complicated input pipelines.
eval_dir_suffix: a string, if not None then will be appended to the eval
subdirectory name for all three eval modes:
`perplexity_eval`, `eval`, `score_eval`.
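
Example (illustrative only; in practice most of these arguments are bound
via gin, and `my_vocabulary` / `my_train_dataset_fn` below are
hypothetical stand-ins):

run(tpu_job_name=None, tpu=None, gcp_project=None, tpu_zone=None,
model_dir="/tmp/model", vocabulary=my_vocabulary,
train_dataset_fn=my_train_dataset_fn, mode="train",
sequence_length={"inputs": 512, "targets": 128},
batch_size=("tokens_per_replica", 2048), train_steps=10000,
mesh_shape="batch:1", layout_rules="batch:batch",
learning_rate_schedule=0.001, optimizer=optimize.AdafactorOptimizer)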
"""
if isinstance(sequence_length, int):
sequence_length = {"inputs": sequence_length,
"targets": sequence_length}
if not isinstance(batch_size, int):
batch_size = compute_batch_size(
sequence_length, mesh_shape, layout_rules, batch_size)
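# train_steps may be a callable (e.g. auto_train_steps) that derives the
# total step count from the batch size and sequence length.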
if not isinstance(train_steps, int):
train_steps = train_steps(batch_size, sequence_length)
if total_run_steps is None:
total_run_steps = train_steps
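# A list of schedule factors is combined into a single product schedule;
# any callable schedule is then bound to the total number of run steps.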
if isinstance(learning_rate_schedule, list):
learning_rate_schedule = functools.partial(
learning_rate_schedules.product_learning_rate,
total_train_steps=total_run_steps, factors=learning_rate_schedule)
if callable(learning_rate_schedule):
learning_rate_schedule = functools.partial(
learning_rate_schedule, total_train_steps=total_run_steps)
tf.logging.info("model_type=%s", model_type,)
tf.logging.info("mode=%s", mode,)
tf.logging.info("sequence_length=%s", sequence_length,)
tf.logging.info("batch_size=%s", batch_size,)
tf.logging.info("train_steps=%s", train_steps,)
if total_run_steps is not None:
tf.logging.info("total_run_steps=%s", total_run_steps,)
tf.logging.info("mesh_shape=%s", mesh_shape,)
tf.logging.info("layout_rules=%s", layout_rules,)
if mode == "train" and dataset_split != "train":
raise ValueError("mode==\"train\" requires dataset_split==\"train\"")
if mode != "train":
ensemble_inputs = None
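# Normalize the mesh/layout specs and, if a TPU was requested, resolve
# its cluster.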
mesh_shape = mtf.convert_to_shape(mesh_shape)
layout_rules = mtf.convert_to_layout_rules(layout_rules)
cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu, zone=tpu_zone, project=gcp_project) if tpu else None
tf.logging.info("Building TPUConfig with tpu_job_name=%s", tpu_job_name)
score_in_predict_mode = "score" in mode
estimator_fn = functools.partial(
get_estimator,
model_type=model_type,
vocabulary=vocabulary,
layout_rules=layout_rules,
mesh_shape=mesh_shape,
model_dir=model_dir,
batch_size=batch_size,
sequence_length=sequence_length,
autostack=autostack,
learning_rate_schedule=learning_rate_schedule,
keep_checkpoint_max=keep_checkpoint_max,
save_checkpoints_steps=save_checkpoints_steps,
optimizer=optimizer,
predict_fn=predict_fn,
score_in_predict_mode=score_in_predict_mode,
variable_filter=variable_filter,
init_checkpoint=init_checkpoint,
ensemble_inputs=ensemble_inputs,
use_tpu=tpu,
tpu_job_name=tpu_job_name,
iterations_per_loop=iterations_per_loop,
cluster=cluster,
mesh_devices=mesh_devices)
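# "eval" and "score_eval" can infer sequence_length from the data; every
# other mode requires it up front.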
if mode not in ("eval", "score_eval"):
if sequence_length is None:
raise ValueError(f"`sequence_length` must be specified in '{mode}' mode.")
estimator = estimator_fn()
if mode == "train":
# train_dataset_fn could be None if train_model_fn is not equal to
# train_model
if train_dataset_fn is None:
raise ValueError("Must provide train_dataset_fn through gin")
train_model_fn(estimator, vocabulary, sequence_length, batch_size,
train_dataset_fn, train_steps, ensemble_inputs,
skip_seen_data=skip_seen_data,
seen_data_init_step=seen_data_init_step,
checkpoint_input_pipeline=checkpoint_input_pipeline)
elif mode == "perplexity_eval":
if eval_dataset_fn is None:
if train_dataset_fn is not None:
tf.logging.warning("Using train_dataset_fn for perplexity eval")
eval_datasets = [transformer_dataset.EvalDataset(
name="eval",
dataset_fn=functools.partial(train_dataset_fn,
sequence_length=sequence_length,
vocabulary=vocabulary,
dataset_split=dataset_split),
postprocess_fn=None,
metric_fns=None)]
else:
raise ValueError(
"for perplexity_eval, "
"must provide one of eval_dataset_fn and train_dataset_fn")
else:
eval_datasets = eval_dataset_fn(
sequence_length=sequence_length,
vocabulary=vocabulary,
dataset_split=dataset_split,
)
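# Per-dataset input_fn: keep only the features the model expects, pad the
# tail of the dataset with zeroed-out examples so the final partial batch
# is not lost, and batch with drop_remainder so every eval step sees a
# full (possibly ensembled) batch.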
def _input_fn(params, eval_dataset):
del params
ds = eval_dataset.dataset_fn().map(
_filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = transformer_dataset.pad_dataset_with_zeroed_out_examples(ds)
ds = (ds.batch(batch_size * (ensemble_inputs or 1), drop_remainder=True)
.prefetch(tf.data.experimental.AUTOTUNE))
return ds
checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
for checkpoint_path in checkpoint_paths:
for eval_dataset in eval_datasets:
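# Fix the seeds so that repeated evaluations of the same dataset see
# identical examples.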
tf.random.set_random_seed(12345)
random.seed(12345)
num_examples = batch_size * perplexity_eval_steps
# include the number of examples in the evaluation name so as to
# make sure we are comparing apples to apples.
name = "%s_%s_%d" % (eval_dataset.name, dataset_split, num_examples)
if eval_dir_suffix is not None:
name += "_%s" % eval_dir_suffix
_ = estimator.evaluate(
input_fn=functools.partial(_input_fn, eval_dataset=eval_dataset),
steps=perplexity_eval_steps,
checkpoint_path=checkpoint_path,
name=name)
elif mode in ("eval", "score_eval"):
eval_model(
estimator_fn,
vocabulary,
sequence_length,
batch_size,
dataset_split,
model_dir,
eval_dataset_fn,
eval_summary_dir,
eval_checkpoint_step,
eval_with_score=(mode == "score_eval"),
output_eval_examples=output_eval_examples,
eval_dir_suffix=eval_dir_suffix)
elif mode == "infer":
infer_model(estimator, vocabulary, sequence_length, batch_size, model_type,
model_dir, eval_checkpoint_step)
elif mode == "score_from_strings":
score_from_strings(estimator=estimator,
vocabulary=vocabulary,
model_type=model_type,
batch_size=batch_size,
sequence_length=sequence_length,
model_dir=model_dir,
eval_checkpoint_step=eval_checkpoint_step)
elif mode == "score_from_dataset":
score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
model_dir, eval_checkpoint_step, dataset_split)
elif mode in ["export_score", "export_infer", "export"]:
if mode == "export":
tf.logging.warning("Mode 'export' is deprecated. "
"Defaulting to 'export_infer'.")
if export_checkpoint_step:
checkpoint_path = get_checkpoint_iterator(
export_checkpoint_step, model_dir)
if isinstance(checkpoint_path, list):
checkpoint_path = checkpoint_path[0]
else:
checkpoint_path = next(checkpoint_path)
else:
# Use the latest checkpoint in the model directory.
checkpoint_path = None
export_model(estimator, export_path, vocabulary, sequence_length,
model_type, score_in_predict_mode, batch_size, checkpoint_path)
else:
raise ValueError(
"unknown mode %s - must be one of train/perplexity_eval/eval/"
"score_eval/infer/score_from_strings/score_from_dataset/"
"export_score/export_infer" % mode)