in mesh_tensorflow/bert/run_pretraining.py
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
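
    # On TPU, build a SIMD mesh over the device assignment and balance
    # variable placement across hosts; otherwise fall back to simple
    # per-device placement.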
    if FLAGS.use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      tf.logging.info("device_list = %s" % device_list)
      replica_cache_size = 300 * 1000000  # 300M per replica.
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      physical_shape = list(ctx.device_assignment.topology.mesh_shape)
      logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
          mesh_shape.to_integer_list, physical_shape)
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape,
          layout_rules,
          mesh_devices,
          ctx.device_assignment,
          logical_to_physical=logical_to_physical)
    else:
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, [""] * mesh_shape.size)
      var_placer = None

    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)
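
    # Unpack the pre-training features produced by the input pipeline.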
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)
    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)
    max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
    max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                max_predictions_per_seq)
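
    # Import the TF input tensors into the mesh, tagging each with named mtf
    # Dimensions so the layout rules can decide how to split them.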
    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])
    mtf_masked_lm_positions = mtf.import_tf_tensor(
        mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_ids = mtf.import_tf_tensor(
        mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_weights = mtf.import_tf_tensor(
        mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
    mtf_next_sentence_labels = mtf.import_tf_tensor(
        mesh, next_sentence_labels, [batch_dim])
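
    # Build the Mesh TensorFlow BERT model and its two pre-training heads:
    # masked-LM prediction and next-sentence classification.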
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = bert_lib.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        token_type_ids=mtf_segment_ids,
        layout=layout_rules,
        mesh_shape=mesh_shape)

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_logits) = model.get_masked_lm_output(
         mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_logits) = model.get_next_sentence_output(
         mtf_next_sentence_labels)

    extra_loss = model.get_extra_loss()
    total_loss = masked_lm_loss + next_sentence_loss
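    # mtf.anonymize renames a tensor's dimensions to anonymous ones that the
    # layout rules never split, so these tensors are not sharded when they are
    # exported back to TF below.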
    total_loss = mtf.anonymize(total_loss)
    masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
    masked_lm_logits = mtf.anonymize(masked_lm_logits)
    next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
    next_sentence_logits = mtf.anonymize(next_sentence_logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      _, update_ops = optimization_lib.create_optimizer(
          total_loss + extra_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)
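
    # Lowering compiles the MTF graph into per-device TF operations; from here
    # on we work with ordinary TF tensors and ops.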
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_global_step()
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)
    elif mode == tf.estimator.ModeKeys.EVAL:
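
      # metric_fn runs on the host as part of TPUEstimator's eval_metrics
      # tuple; it receives the exported TF tensors listed below it.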
      def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                    masked_lm_weights, next_sentence_example_loss,
                    next_sentence_logits, next_sentence_labels):
        """Computes the loss and accuracy of the model."""
        masked_lm_logits = tf.reshape(masked_lm_logits,
                                      [-1, masked_lm_logits.shape[-1]])
        masked_lm_predictions = tf.argmax(
            masked_lm_logits, axis=-1, output_type=tf.int32)
        masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.metrics.accuracy(
            labels=masked_lm_ids,
            predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.metrics.mean(
            values=masked_lm_example_loss, weights=masked_lm_weights)
        next_sentence_logits = tf.reshape(
            next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
        next_sentence_predictions = tf.argmax(
            next_sentence_logits, axis=-1, output_type=tf.int32)
        next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
        next_sentence_accuracy = tf.metrics.accuracy(
            labels=next_sentence_labels, predictions=next_sentence_predictions)
        next_sentence_mean_loss = tf.metrics.mean(
            values=next_sentence_example_loss)
        return {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
            "next_sentence_accuracy": next_sentence_accuracy,
            "next_sentence_loss": next_sentence_mean_loss,
        }

      eval_metrics = (metric_fn, [
          lowering.export_to_tf_tensor(masked_lm_example_loss),
          lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
          masked_lm_weights,
          lowering.export_to_tf_tensor(next_sentence_example_loss),
          lowering.export_to_tf_tensor(next_sentence_logits),
          next_sentence_labels
      ])
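
    # The restore and saver hooks operate on the TF master variables, so they
    # are created outside of any TPU graph rewrites.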
    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      if mode == tf.estimator.ModeKeys.TRAIN:
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.output_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])
        return tf.estimator.tpu.TPUEstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      elif mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.tpu.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics)

  return model_fn
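
# A minimal usage sketch (not part of this excerpt): the closure returned
# above is handed to a TPUEstimator together with a run config. The flag
# names below follow the surrounding script and are assumptions.
#
#   model_fn = model_fn_builder(bert_config, FLAGS.init_checkpoint,
#                               FLAGS.learning_rate, FLAGS.num_train_steps,
#                               FLAGS.num_warmup_steps)
#   estimator = tf.estimator.tpu.TPUEstimator(
#       use_tpu=FLAGS.use_tpu,
#       model_fn=model_fn,
#       config=run_config,
#       train_batch_size=FLAGS.train_batch_size,
#       eval_batch_size=FLAGS.eval_batch_size)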