in mesh_tensorflow/transformer/utils.py [0:0]
def tpu_estimator_model_fn(model_type,
transformer_model,
vocabulary,
model_dir,
use_tpu,
mesh_shape,
layout_rules,
batch_size,
sequence_length,
autostack,
keep_checkpoint_max,
save_checkpoints_steps,
learning_rate_schedule=None,
optimizer=None,
outer_batch_size=1,
tpu_summaries=False,
predict_fn=None,
score_in_predict_mode=False,
variable_filter=None,
init_checkpoint=None,
init_variable_filter="",
ensemble_inputs=None,
mesh_devices=None,
model_info_file=None,
hierarchical_tiling_spec=None,
                           weight_decay_checkpoint=None,  # GOOGLE-INTERNAL
                           ):
"""Create a TPUEstimator model function.
Args:
    model_type: a string. One of "bitransformer", "lm", "delimited_lm",
      "aligned", or "bi_student_teacher"
transformer_model: a transformer.Unitransformer or transformer.Bitransformer
vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
targets_vocabulary) tuple. Used for decoding in predict mode.
model_dir: a string, directory to save the model to.
use_tpu: a boolean
mesh_shape: a mtf.Shape
layout_rules: a mtf.LayoutRules
batch_size: an integer
    sequence_length: an integer or a dict from feature-key to integer giving
      the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
autostack: a boolean
keep_checkpoint_max: an integer, maximum number of checkpoints to keep
    save_checkpoints_steps: an integer, save a checkpoint every this many
      steps
learning_rate_schedule: a constant or a function from step to learning rate
optimizer: a class extending optimize.Optimizer, required for training
    outer_batch_size: an integer, the outer batch dimension; can be used to
      mix data-parallel and model-parallel training for Mixture-of-Experts
      (MoE) models
tpu_summaries: a boolean, use rewrites to make summaries work on TPU. This
may be slow, since it uses a host call hack.
predict_fn: an optional function, see docs for `run` for more information.
    score_in_predict_mode: a boolean, compute log-likelihood scores instead
      of predictions when in predict mode
variable_filter: controls which variables are trained.
If None (default), train all trainable variables.
If a string regex, train all variables that match this regex.
If a function (mtf.Variable -> boolean), then train variables for which
the function returns True.
init_checkpoint: a string, if not None then read in variables from this
checkpoint path when initializing variables. Will only initialize
variables that appear both in the current graph and the checkpoint.
    init_variable_filter: a string, used only when init_checkpoint is set.
      Controls which variables are loaded from the checkpoint using a regex.
      If the empty string (default), all checkpoint variables are loaded.
ensemble_inputs: an optional integer - pass the size of the ensemble to
train an ensemble where each model gets different inputs.
You also need to configure Unitransformer.ensemble to the right size.
If None, then all models are trained on the same inputs.
mesh_devices: a list of strings, the device names to use for each mesh
slice. Only required for GPU.
model_info_file: an optional string, information about variables and
operations will be logged to this file during the TRAIN mode.
hierarchical_tiling_spec: an optional list that can be passed as the
spec argument to simd_mesh_impl.HierarchicalTiling
    weight_decay_checkpoint: an optional checkpoint dir to weight decay from.
      # GOOGLE-INTERNAL

  Returns:
a function to be passed to TPUEstimator
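
  Example:
    A minimal sketch of building a model function for a language model.
    `my_transformer` and `my_vocabulary` are placeholders that must be
    constructed elsewhere, and the hyperparameter values are illustrative
    only:

      from mesh_tensorflow import optimize

      model_fn = tpu_estimator_model_fn(
          model_type="lm",
          transformer_model=my_transformer,
          vocabulary=my_vocabulary,
          model_dir="/tmp/model",
          use_tpu=False,
          mesh_shape=mtf.convert_to_shape("batch:8"),
          layout_rules=mtf.convert_to_layout_rules("batch:batch"),
          batch_size=32,
          sequence_length={"inputs": 512, "targets": 128},
          autostack=True,
          keep_checkpoint_max=5,
          save_checkpoints_steps=1000,
          learning_rate_schedule=0.003,
          optimizer=optimize.AdafactorOptimizer)

    The result can then be passed as `model_fn` to a TPUEstimator.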
"""
mesh_devices = mesh_devices or [""] * mesh_shape.size
def my_model_fn(features, labels, mode, params=None, config=None):
"""Estimator model function.
Args:
features: dictionary where keys are strings like "inputs" and "targets"
and the values are the actual values of "inputs". See TPUEstimator's
docs for more information
labels: ignored argument
mode: a tf.estimator.ModeKeys
params: dictionary containing the key "context"
      config: ignored argument

    Returns:
a TPUEstimatorSpec
"""
del labels, config
if mode == tf.estimator.ModeKeys.PREDICT and score_in_predict_mode:
mode = "score"
global_step = tf.train.get_global_step()
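    # On TPU, build a SIMD mesh implementation and balance variable placement
    # across hosts; otherwise fall back to a placement-based mesh
    # implementation.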
if use_tpu and "context" in params:
ctx = params["context"]
num_hosts = ctx.num_hosts
host_placement_fn = ctx.tpu_host_placement_function
device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
# TODO(ylc): Better estimation of replica cache size?
replica_cache_size = 300 * 1000000 # 300M per replica
# Worker 0 caches all the TPU binaries.
worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
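      # A 4-entry TPU topology (x, y, z, cores) is converted to the 3-d
      # physical shape expected by the logical-to-physical mapping utilities;
      # a true 3-d torus (z > 1) forces the automatic mapping below.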
physical_shape = [int(i) for i in
params["context"].device_assignment.topology.mesh_shape]
      mesh_4d = False
      if len(physical_shape) == 4:
        mesh_4d = physical_shape[2] > 1
physical_shape = (
mtf.simd_mesh_impl.physical_shape_3d_from_topology_proto_4d(
physical_shape))
if mesh_4d or hierarchical_tiling_spec is None:
logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
mesh_shape.to_integer_list,
physical_shape,
device_assignment=params["context"].device_assignment)
else:
logical_to_physical = mtf.simd_mesh_impl.HierarchicalTiling(
hierarchical_tiling_spec,
physical_shape).logical_to_physical
mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
logical_to_physical=logical_to_physical)
else:
var_placer = None
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
mesh_shape, layout_rules, mesh_devices)
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh", var_placer)
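    # In training/eval, the batch dimension may be factored into
    # [outer_batch, batch] so that data-parallelism and model-parallelism can
    # be mixed, e.g. for Mixture-of-Experts models.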
if (outer_batch_size and
mode not in [tf.estimator.ModeKeys.PREDICT, "score"]):
outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
batch_dims = [outer_batch_dim, batch_dim]
else:
batch_dim = mtf.Dimension("batch", batch_size)
batch_dims = [batch_dim]
ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
if ensemble_inputs else [])
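    # "predict_batch_size", if present, records how many real examples the
    # (possibly padded) predict batch contains; it is used below to trim the
    # exported predictions.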
predict_batch_size = features.pop("predict_batch_size", None)
mtf_features = {}
for key, x in features.items():
# Some auxiliary features may have been generated in packing.
# The names of these new features are of the form
# "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
# We look up the lengths based on the original feature name, without
# the "_<suffix>".
feature_length = sequence_length[key.split("_")[0]]
length_dim = mtf.Dimension("length", feature_length)
feature_shape = mtf.Shape(
ensemble_dims + batch_dims + [length_dim])
x = tf.cast(features[key], tf.int32)
x = tf.reshape(x, feature_shape.to_integer_list)
if not use_tpu:
tf.logging.info("feature %s : %s", key, x)
mtf_features[key] = mtf.import_fully_replicated(
mesh, x, feature_shape, name=key)
def _verify_feature_exists(feature_name, should_exist):
if should_exist != (feature_name in mtf_features):
message = (
"mode=%s model_type=%s should%s have feature %s" %
(mode, model_type, "" if should_exist else " not", feature_name))
if "lm" in model_type:
message += (
"\nA common mistake is that model_type=\"delimited_lm\" should "
"be used with tasks that produce inputs and targets, while "
"model_type=\"lm\" should be used with tasks that produce "
"targets only.")
raise ValueError(message)
# Verify that the right features exist, and transform them if necessary
if mode == tf.estimator.ModeKeys.PREDICT:
_verify_feature_exists("inputs", True)
# "targets" may or may not exist depending on whether we are doing
# evaluation or open-ended inference.
elif model_type in ("lm", "delimited_lm") and mode == "score":
      # In scoring mode, the inputs and targets may already be combined.
if "inputs" in mtf_features:
if model_type == "lm":
tf.logging.warning(
"Scoring of lm models will include loss from the 'inputs'.")
mtf_features = _dynamic_text2self(mtf_features)
else:
_verify_feature_exists("targets", True)
_verify_feature_exists("inputs", model_type != "lm")
if model_type == "delimited_lm":
mtf_features = _dynamic_text2self(mtf_features)
# Detokenize in the graph if supported by vocabulary and accelerator.
def _maybe_detokenize(ids, vocab):
if not use_tpu and hasattr(vocab, "decode_tf"):
return vocab.decode_tf(ids)
return ids
if mode == "score":
# compute log-likelihoods per sequence
targets = mtf_features["targets"]
if predict_fn:
# predict_fn contains a custom scoring function
scores = predict_fn(
model=transformer_model,
features=mtf_features,
variable_dtype=get_variable_dtype())
else:
if isinstance(transformer_model, transformer.Unitransformer):
length_dim = targets.shape.dims[-1]
inputs = transformer.autoregressive_inputs(
mtf_features["targets"])
elif isinstance(transformer_model,
(transformer.Bitransformer,
transformer.StudentTeacher)):
inputs = mtf_features["inputs"]
else:
raise ValueError("unrecognized class")
logits, _ = transformer_model.call_simple(
inputs=inputs,
targets=targets,
compute_loss=False,
mode=mode,
variable_dtype=get_variable_dtype())
logits = mtf.cast(logits, tf.float32)
_, length_dim, vocab_dim = logits.shape.dims
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
logits, mtf_features["targets"], vocab_dim)
# 0=padding and negative targets are a hack to indicate no loss
cross_entropy *= mtf.cast(
mtf.greater(targets, 0), cross_entropy.dtype)
if model_type == "delimited_lm":
cross_entropy *= mtf.cast(mtf.logical_not(
transformer.delimited_lm_inputs_mask(targets)),
cross_entropy.dtype)
scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
scores = mtf.anonymize(scores)
targets = mtf.anonymize(targets)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
targets = clean_decodes(lowering.export_to_tf_tensor(targets))
targets = _maybe_detokenize(targets, targets_vocabulary(vocabulary))
predictions = {
"targets": targets,
"scores": lowering.export_to_tf_tensor(scores)
}
elif mode == tf.estimator.ModeKeys.PREDICT:
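      # Decode predictions, either with a user-provided predict_fn or with
      # the model's own autoregressive sampling / decoding method.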
inputs = mtf_features["inputs"]
if predict_fn:
mtf_samples = predict_fn(
model=transformer_model,
features=mtf_features,
variable_dtype=get_variable_dtype())
elif isinstance(transformer_model, transformer.Unitransformer):
# pad so that there is enough room for the targets
inputs = mtf.pad(
inputs, [0, sequence_length["targets"]], length_dim.name)
mtf_samples = transformer_model.sample_autoregressive(
inputs, variable_dtype=get_variable_dtype(),
remove_partial_sequences=True)
elif isinstance(
transformer_model,
(transformer.Bitransformer, transformer.StudentTeacher)):
mtf_samples = transformer_model.decode(
inputs, variable_dtype=get_variable_dtype())
else:
raise ValueError("unrecognized class")
mtf_samples = mtf.anonymize(mtf_samples)
inputs = mtf.anonymize(inputs)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
inputs = _maybe_detokenize(inputs, inputs_vocabulary(vocabulary))
outputs = _maybe_detokenize(outputs, targets_vocabulary(vocabulary))
if predict_batch_size is not None:
inputs = inputs[:predict_batch_size]
outputs = outputs[:predict_batch_size]
predictions = {
"inputs": inputs,
"outputs": outputs}
if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
# When exporting a model, we need to communicate to TF-Serving that
# master variables need to be copied to their slave slice variables.
# Estimator uses a Scaffold's "local_init_op" for this purpose, so we
# augment the default "local_init_op" here.
#
# The "ready_op" is also constructed here to ensure the variables
# initialized by "local_init_op" are the same ones checked by "ready_op".
#
# WARNING: Any variables created outside of this model_fn()
# (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
# checked by these ops.
def scaffold_fn():
return tf.train.Scaffold(
local_init_op=tf.group(
tf.train.Scaffold.default_local_init_op(),
lowering.copy_masters_to_slices(),
name="mtf_local_init_op"),
ready_op=tf.concat(
[tf.report_uninitialized_variables(),
resources.report_uninitialized_resources()],
axis=0,
name="mtf_ready_op"))
return tpu_estimator.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
scaffold_fn=scaffold_fn,
prediction_hooks=[mtf.MtfRestoreHook(lowering)])
assert (mode == tf.estimator.ModeKeys.TRAIN or
mode == tf.estimator.ModeKeys.EVAL)
def logits_and_loss(mtf_features, num_microbatches=1):
"""Compute logits and loss.
Args:
mtf_features: a dictionary
num_microbatches: integer
Returns:
logits: a mtf.Tensor
loss: a mtf.Tensor
"""
if model_type in ["lm", "delimited_lm"]:
inputs = transformer.autoregressive_inputs(
mtf_features["targets"],
sequence_id=mtf_features.get("targets_segmentation", None))
else:
inputs = mtf_features["inputs"]
if isinstance(transformer_model, transformer.Unitransformer):
position_kwargs = dict(
sequence_id=mtf_features.get("targets_segmentation", None),
position=mtf_features.get("targets_position", None),
)
elif isinstance(
transformer_model,
transformer.Bitransformer) or model_type == "bi_student_teacher":
position_kwargs = dict(
encoder_sequence_id=mtf_features.get("inputs_segmentation", None),
decoder_sequence_id=mtf_features.get("targets_segmentation",
None),
decoder_subsequence_id=mtf_features.get("targets_subsegmentation",
None),
encoder_position=mtf_features.get("inputs_position", None),
decoder_position=mtf_features.get("targets_position", None),
)
else:
raise ValueError("unrecognized class")
return transformer_model.call_simple(
inputs=inputs,
targets=mtf_features["targets"],
compute_loss=True,
mode=mode,
variable_dtype=get_variable_dtype(),
num_microbatches=num_microbatches,
**position_kwargs)
if mode == tf.estimator.ModeKeys.TRAIN:
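      # Optionally split each batch into microbatches (gradient accumulation)
      # so that large batches fit in memory; gradients are accumulated across
      # microbatches before the optimizer update.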
num_microbatches = serialize_num_microbatches(batch_dim,
sequence_length,
mesh_shape,
layout_rules)
if num_microbatches > 1:
def serialized_fn(mtf_features):
return {"loss": logits_and_loss(mtf_features, num_microbatches)[1]}
var_grads, loss_dict = mtf.serialize_training_step(
mtf_features, serialized_fn, batch_dim, num_microbatches)
loss = loss_dict["loss"]
else:
loss = logits_and_loss(mtf_features)[1]
var_grads = mtf.gradients(
[loss], [v.outputs[0] for v in graph.trainable_variables])
if tpu_summaries:
mtf.scalar_summary("loss", loss)
if callable(learning_rate_schedule):
# the following happens on CPU since TPU can't handle summaries.
with mtf.utils.outside_all_rewrites():
learning_rate = learning_rate_schedule(
step=tf.train.get_global_step())
tf.summary.scalar("learning_rate", learning_rate)
else:
learning_rate = learning_rate_schedule
if isinstance(variable_filter, str):
pattern = re.compile(variable_filter)
variable_filter_fn = lambda v: pattern.search(v.name)
elif variable_filter is None:
variable_filter_fn = lambda v: True
elif callable(variable_filter):
variable_filter_fn = variable_filter
else:
raise ValueError(
"variable_filter must be None, a string, or a callable function")
trainable_vars = [
v for v in graph.trainable_variables if variable_filter_fn(v)]
trainable_var_grads = [
g for g, v in zip(var_grads, graph.trainable_variables)
if variable_filter_fn(v)]
if len(trainable_vars) != len(graph.trainable_variables):
tf.logging.info("Variables being trained:")
tf.logging.info([v.name for v in trainable_vars])
tf.logging.info("Variables not being trained:")
tf.logging.info([v.name for v in graph.trainable_variables
if not variable_filter_fn(v)])
opt = optimizer(learning_rate=learning_rate)
update_ops = opt.apply_grads(trainable_var_grads, trainable_vars)
lowering = mtf.Lowering(
graph, {mesh: mesh_impl},
autostack=autostack,
log_file=model_info_file)
tf_loss = lowering.export_to_tf_tensor(loss)
tf_loss = tf.cast(tf_loss, tf.float32)
if not use_tpu:
      tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                         "tf_loss, step")
tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
tf_update_ops.append(tf.assign_add(global_step, 1))
train_op = tf.group(tf_update_ops)
if hasattr(transformer_model, "initialize"):
with mtf.utils.outside_all_rewrites():
transformer_model.initialize()
if tpu_summaries:
        # create_host_call must be called outside of the
        # mtf.utils.outside_all_rewrites() context below.
host_call = mtf.utils.create_host_call(model_dir)
mtf.utils.remove_summaries()
else:
host_call = None
with mtf.utils.outside_all_rewrites():
if init_checkpoint:
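        # Restore only variables that appear both in the graph and in the
        # checkpoint (optionally filtered by init_variable_filter), and log
        # the differences to aid debugging.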
ckpt_vars = {v for v, _ in tf.train.list_variables(init_checkpoint)}
if init_variable_filter:
pattern = re.compile(init_variable_filter)
ckpt_vars = {v for v in ckpt_vars if pattern.search(v)}
global_vars = {v.op.name for v in tf.global_variables()}
filtered_global_vars = {
v for v in global_vars if should_load_variable(v)}
restore_vars = {
v for v in filtered_global_vars
if init_checkpoint_variable_mapping(v) in ckpt_vars}
tf.logging.info("Initializing variables from %s:", init_checkpoint)
tf.logging.debug("\n".join(sorted(restore_vars)))
tf.logging.info("Variables in %s but not in graph:", init_checkpoint)
tf.logging.info("\n".join(sorted(
ckpt_vars -
{init_checkpoint_variable_mapping(v)
for v in filtered_global_vars})))
tf.logging.info("Variables in graph but not in %s:", init_checkpoint)
tf.logging.info("\n".join(sorted(global_vars - restore_vars)))
tf.train.init_from_checkpoint(
init_checkpoint,
{init_checkpoint_variable_mapping(v): v for v in restore_vars}
)
# Copy master variables to slices. Must be called first.
restore_hook = mtf.MtfRestoreHook(lowering)
saver = tf.train.Saver(
tf.global_variables(),
sharded=True,
max_to_keep=keep_checkpoint_max,
keep_checkpoint_every_n_hours=2,
defer_build=False,
save_relative_paths=True)
tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
saver_listener = mtf.MtfCheckpointSaverListener(lowering)
saver_hook = tf.train.CheckpointSaverHook(
model_dir,
save_steps=save_checkpoints_steps,
saver=saver,
listeners=[saver_listener])
gin_config_saver_hook = gin.tf.GinConfigSaverHook(
model_dir, summarize_config=True, include_step_in_filename=False)
training_hooks = [
restore_hook,
saver_hook,
gin_config_saver_hook,
]
if use_tpu:
return tpu_estimator.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=tf_loss,
train_op=train_op,
host_call=host_call,
training_hooks=training_hooks)
else:
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.TRAIN,
loss=tf_loss,
train_op=train_op,
training_chief_hooks=training_hooks)
elif mode == tf.estimator.ModeKeys.EVAL:
# perplexity eval
logits, loss = logits_and_loss(mtf_features)
# compute cross-entropy while still on TPU to avoid having to outfeed the
# logits, which might be big.
logits = mtf.cast(logits, tf.float32)
vocab_dim = logits.shape.dims[-1]
targets = mtf_features["targets"]
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
logits, targets, vocab_dim)
anon_cross_entropy = mtf.anonymize(cross_entropy)
predictions = mtf.cast(mtf.argmax(logits, vocab_dim), targets.dtype)
anon_predictions = mtf.anonymize(predictions)
anon_targets = mtf.anonymize(targets)
# 0=padding and negative targets are a hack to indicate no loss
anon_weights = mtf.cast(mtf.greater(anon_targets, 0), tf.float32)
if model_type == "delimited_lm":
anon_weights *= mtf.cast(
mtf.logical_not(transformer.delimited_lm_inputs_mask(anon_targets)),
dtype=tf.float32)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
      tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
tf_predictions = lowering.export_to_tf_tensor(anon_predictions)
tf_cross_entropy = lowering.export_to_tf_tensor(anon_cross_entropy)
def simple_metrics(xent, predictions, labels, weights):
"""Simple metrics for teacher-forced eval."""
token_correct = tf.cast(
tf.equal(predictions, labels), tf.float32) * weights
sequence_correct = tf.cast(
tf.equal(tf.reduce_sum(token_correct, -1),
tf.reduce_sum(weights, -1)),
tf.float32)
sequence_weights = tf.cast(
tf.not_equal(tf.reduce_sum(weights, -1), 0),
tf.float32)
# the purpose of "mean_label" is as a checksum to ensure that
# models were evaluated on the same data.
return {"neg_log_perplexity": tf.metrics.mean(-xent, weights),
"token_accuracy": tf.metrics.mean(token_correct, weights),
"sequence_accuracy": tf.metrics.mean(
sequence_correct, sequence_weights),
"mean_label": tf.metrics.mean(
tf.cast(labels, tf.float32), weights),
"num_eval_tokens": metric_sum(weights, name="num_eval_tokens"),
"max_targets_length": metric_max(tf.reduce_sum(
weights, axis=-1), name="max_targets_length"),
}
labels = lowering.export_to_tf_tensor(anon_targets)
weights = lowering.export_to_tf_tensor(anon_weights)
eval_metrics = (simple_metrics, [
tf_cross_entropy, tf_predictions, labels, weights])
with mtf.utils.outside_all_rewrites():
restore_hook = mtf.MtfRestoreHook(lowering)
return tpu_estimator.TPUEstimatorSpec(
tf.estimator.ModeKeys.EVAL,
evaluation_hooks=[restore_hook],
loss=tf_loss,
eval_metrics=eval_metrics)
return my_model_fn