def model_fn_builder()

in mesh_tensorflow/bert/run_pretraining.py [0:0]


def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps):
  """Returns `model_fn` closure for TPUEstimator.

  The returned `model_fn` builds a Mesh-TensorFlow BERT pre-training graph
  (masked-LM and next-sentence heads), lowers it to TensorFlow and returns a
  `tf.estimator.tpu.TPUEstimatorSpec` for TRAIN or EVAL mode.

  Args:
    bert_config: configuration passed to `bert_lib.BertModel`.
    init_checkpoint: NOTE(review): accepted but not referenced by this
      function or by the returned `model_fn`; no variables are initialized
      from it here. Confirm whether warm-starting is handled elsewhere.
    learning_rate: passed to `optimization_lib.create_optimizer`.
    num_train_steps: passed to `optimization_lib.create_optimizer`.
    num_warmup_steps: passed to `optimization_lib.create_optimizer`.

  Returns:
    A `model_fn(features, labels, mode, params)` callable.
  """

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Args:
      features: dict of input tensors. Reads "input_ids", "input_mask" and
        "segment_ids" (imported as [batch, seq]); "masked_lm_positions",
        "masked_lm_ids" and "masked_lm_weights" (imported as
        [batch, max_pred_seq]); and "next_sentence_labels" (squeezed on
        axis 1 to [batch]). Batch, sequence and max-prediction sizes are
        read from static shapes via `.value`, so those dimensions presumably
        must be fully defined.
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value; only TRAIN and EVAL are
        supported.
      params: estimator params. When `FLAGS.use_tpu` is set,
        `params["context"]` supplies the TPU context (hosts, replicas,
        device assignment).

    Returns:
      A `tf.estimator.tpu.TPUEstimatorSpec` for `mode`.

    Raises:
      ValueError: if `mode` is neither TRAIN nor EVAL.
    """
    # Fail fast. Without this guard an unsupported mode (e.g. PREDICT) would
    # build and lower the whole graph and then implicitly return None.
    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      raise ValueError(
          "model_fn only supports TRAIN and EVAL modes, got: %s" % mode)

    tf.logging.info("*** Features ***")
    for name in sorted(features):
      tf.logging.info("  name = %s, shape = %s", name, features[name].shape)

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    if FLAGS.use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      tf.logging.info("device_list = %s", device_list)
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries, so it starts with that much
      # memory already used; the other hosts start at 0. The placer uses
      # these initial usages when balancing variables across hosts.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      physical_shape = list(ctx.device_assignment.topology.mesh_shape)
      logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
          mesh_shape.to_integer_list, physical_shape)
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape,
          layout_rules,
          mesh_devices,
          ctx.device_assignment,
          logical_to_physical=logical_to_physical)
    else:
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, [""] * mesh_shape.size)
      var_placer = None

    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)

    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)
    max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
    max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                max_predictions_per_seq)

    # Import the TF input tensors into the MTF mesh with named dimensions.
    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])
    mtf_masked_lm_positions = mtf.import_tf_tensor(
        mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_ids = mtf.import_tf_tensor(
        mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

    mtf_masked_lm_weights = mtf.import_tf_tensor(
        mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
    mtf_next_sentence_labels = mtf.import_tf_tensor(
        mesh, next_sentence_labels, [batch_dim])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = bert_lib.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        token_type_ids=mtf_segment_ids,
        layout=layout_rules,
        mesh_shape=mesh_shape)

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_logits) = model.get_masked_lm_output(
         mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_logits) = model.get_next_sentence_output(
         mtf_next_sentence_labels)

    extra_loss = model.get_extra_loss()

    # Reported loss is MLM + NSP only; `extra_loss` enters the optimizer
    # objective below but is not part of `tf_loss`.
    total_loss = masked_lm_loss + next_sentence_loss
    # NOTE(review): presumably anonymized so these tensors carry no
    # mesh-split named dims before being exported back to TF — confirm.
    total_loss = mtf.anonymize(total_loss)
    masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
    masked_lm_logits = mtf.anonymize(masked_lm_logits)
    next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
    next_sentence_logits = mtf.anonymize(next_sentence_logits)

    # TRAIN mode: optimizer ops must be added to the MTF graph before it is
    # lowered.
    if is_training:
      _, update_ops = optimization_lib.create_optimizer(
          total_loss + extra_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

    if is_training:
      global_step = tf.train.get_global_step()
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      # The global step is advanced explicitly as part of the train op.
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: %s", tf_update_ops)
      train_op = tf.group(tf_update_ops)
    else:  # EVAL (other modes are rejected at the top).

      def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                    masked_lm_weights, next_sentence_example_loss,
                    next_sentence_logits, next_sentence_labels):
        """Computes the loss and accuracy of the model.

        Masked-LM tensors are flattened so that each masked position is one
        example; accuracy and mean loss are weighted by `masked_lm_weights`.
        Next-sentence accuracy and mean loss are unweighted.

        Returns:
          Dict mapping metric name to the (value, update_op) pair returned
          by the `tf.metrics` function.
        """
        masked_lm_logits = tf.reshape(masked_lm_logits,
                                      [-1, masked_lm_logits.shape[-1]])
        masked_lm_predictions = tf.argmax(
            masked_lm_logits, axis=-1, output_type=tf.int32)
        masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.metrics.accuracy(
            labels=masked_lm_ids,
            predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.metrics.mean(
            values=masked_lm_example_loss, weights=masked_lm_weights)

        next_sentence_logits = tf.reshape(
            next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
        next_sentence_predictions = tf.argmax(
            next_sentence_logits, axis=-1, output_type=tf.int32)
        next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
        next_sentence_accuracy = tf.metrics.accuracy(
            labels=next_sentence_labels, predictions=next_sentence_predictions)
        next_sentence_mean_loss = tf.metrics.mean(
            values=next_sentence_example_loss)

        return {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
            "next_sentence_accuracy": next_sentence_accuracy,
            "next_sentence_loss": next_sentence_mean_loss,
        }

      # Tensor order must match metric_fn's positional parameters. Model
      # outputs are exported from MTF; ids/weights/labels are the raw TF
      # inputs.
      eval_metrics = (metric_fn, [
          lowering.export_to_tf_tensor(masked_lm_example_loss),
          lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
          masked_lm_weights,
          lowering.export_to_tf_tensor(next_sentence_example_loss),
          lowering.export_to_tf_tensor(next_sentence_logits),
          next_sentence_labels
      ])

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      if is_training:
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.output_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        return tf.estimator.tpu.TPUEstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])

      return tf.estimator.tpu.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)

  return model_fn