In mesh_tensorflow/bert/run_classifier.py:

def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
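
    # Enumerate the TPU hosts and spread the master copies of the variables
    # across them; pre-charging worker 0 below makes the balanced placer
    # favor the other hosts.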
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info("device_list = %s" % device_list)

    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
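
    # One logical mesh spans every TPU core. With SimdMeshImpl the per-core
    # device names can stay empty: ctx.device_assignment already pins the
    # physical placement.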
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                mesh_devices,
                                                ctx.device_assignment)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)
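
    # Unpack the feature tensors produced by the input_fn.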
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    if "is_real_example" in features:
      is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
      is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
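
    # TPUEstimator provides static shapes, which Mesh TensorFlow requires:
    # every mtf.Dimension is a (name, size) pair with a fixed size.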
    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)
    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)
    # Give the label dimension its own name so it cannot collide with the
    # sequence dimension.
    num_labels_dim = mtf.Dimension("num_labels", num_labels)
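
    # Import the TF tensors into the mesh as mtf.Tensors with named
    # dimensions; the layout rules decide how each dimension is split.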
    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])
    mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])
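
    # Build the BERT classifier graph in Mesh TensorFlow.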
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    (total_loss, per_example_loss, logits,
     probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                   mtf_input_mask, mtf_segment_ids,
                                   mtf_label_ids, num_labels_dim,
                                   layout_rules, mesh_shape)
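
    # anonymize() reshapes to "anonymous" dimension names that the layout
    # rules never split, so these outputs end up fully replicated and are
    # safe to export back to TensorFlow.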
    total_loss = mtf.anonymize(total_loss)
    per_example_loss = mtf.anonymize(per_example_loss)
    logits = mtf.anonymize(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      _, update_ops = optimization_lib.create_optimizer(
          total_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          max_optimized_variable_size=FLAGS.max_optimized_variable_size,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)
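
    # Lower the Mesh TensorFlow graph to ordinary TensorFlow ops on the mesh
    # implementation; everything below is plain TF.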
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_global_step()
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)
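
    # For eval, export per-example losses and logits back to TF and compute
    # metrics on the host, weighting out the padding examples that fill the
    # final batch.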
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions,
            weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss,
                               weights=is_real_example)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

      eval_metrics = (metric_fn, [
          lowering.export_to_tf_tensor(per_example_loss), label_ids,
          lowering.export_to_tf_tensor(logits), is_real_example
      ])
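
    # Optionally warm-start from a standard BERT checkpoint. On TPU the
    # restore has to happen inside the Scaffold so that it runs when the
    # session is created on the worker.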
    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = bert_lib.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)
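
    # MtfRestoreHook copies the master variables to their mesh slices and
    # must run before any step; at save time the MTF listener gathers the
    # slices back into the master variables before the sharded checkpoint
    # is written.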
    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      if mode == tf.estimator.ModeKeys.TRAIN:
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.output_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook],
            scaffold_fn=scaffold_fn)
      elif mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
      else:
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            prediction_hooks=[restore_hook],
            predictions={
                "probabilities": lowering.export_to_tf_tensor(probabilities)
            },
            scaffold_fn=scaffold_fn)

  return model_fn
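

# A minimal usage sketch (not part of the original file): wiring model_fn
# into TPUEstimator, assuming the surrounding script defines `run_config`,
# `label_list`, `num_train_steps`/`num_warmup_steps`, and a `train_input_fn`
# as in BERT's reference run_classifier.
#
#   model_fn = model_fn_builder(
#       bert_config=bert_config,
#       num_labels=len(label_list),
#       init_checkpoint=FLAGS.init_checkpoint,
#       learning_rate=FLAGS.learning_rate,
#       num_train_steps=num_train_steps,
#       num_warmup_steps=num_warmup_steps,
#       use_tpu=FLAGS.use_tpu)
#   estimator = tf.estimator.tpu.TPUEstimator(
#       use_tpu=FLAGS.use_tpu,
#       model_fn=model_fn,
#       config=run_config,
#       train_batch_size=FLAGS.train_batch_size,
#       eval_batch_size=FLAGS.eval_batch_size)
#   estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)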