def _create_keras_model_fn()

in tensorflow_estimator/python/estimator/keras_lib.py [0:0]


def _create_keras_model_fn(keras_model,
                           custom_objects=None,
                           save_object_ckpt=False,
                           metric_names_map=None,
                           export_outputs=None):
  """Creates model_fn for keras Estimator.

  Args:
    keras_model: an instance of a compiled Keras model.
    custom_objects: Dictionary for custom objects.
    save_object_ckpt: Whether to save an object-based checkpoint.
    metric_names_map: Optional dictionary mapping Keras model output metric
      names to custom names.
    export_outputs: Optional dictionary mapping Keras model output names to a
      subclass of `tf.estimator.export.ExportOutput`; each class is
      instantiated with the corresponding prediction tensor.

  Returns:
    The model_fn for a keras Estimator.
  """
  # Get the optimizer config in the current context (model_fn is later called
  # in the estimator's own graph and session). OptimizerV2 objects serialize
  # variable/tensor hyperparameters in their configs, which would otherwise
  # cause wrong-session errors during model cloning.
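  # For example, tf.keras.optimizers.Adam().get_config() returns a plain dict
  # of Python values such as {'name': 'Adam', 'learning_rate': 0.001, ...}
  # (illustrative default values) that can safely be carried into the
  # estimator's graph and session.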
  try:
    if isinstance(keras_model.optimizer, (tuple, list)):
      optimizer_config = [opt.get_config() for opt in keras_model.optimizer]
    else:
      optimizer_config = keras_model.optimizer.get_config()
  except (NotImplementedError, AttributeError):
    # TFOptimizers and other custom optimizers do not have a config.
    optimizer_config = None

  def model_fn(features, labels, mode):
    """model_fn for keras Estimator."""
    model = _clone_and_build_model(
        mode=mode,
        keras_model=keras_model,
        custom_objects=custom_objects,
        features=features,
        labels=labels,
        optimizer_config=optimizer_config)
    model_output_names = []
    # We need to make sure that the output names of the last layer in the
    # model are the same for each of the cloned models. This is required for
    # mirrored strategy when we call regroup.
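    # For example, per-replica clone outputs named like 'dense_1' or 'dense_2'
    # are mapped back to 'dense' by stripping the trailing '_<digit>' suffix
    # (illustrative layer names).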
    if tf.distribute.has_strategy():
      for name in model.output_names:
        name = re.compile(r'_\d$').sub('', name)
        model_output_names.append(name)
    else:
      model_output_names = model.output_names

    # Get inputs to EstimatorSpec
    predictions = dict(zip(model_output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Set loss and metric only during train and evaluate.
    if mode is not ModeKeys.PREDICT:
      if mode is ModeKeys.TRAIN:
        model._make_train_function()  # pylint: disable=protected-access
      else:
        model._make_test_function()  # pylint: disable=protected-access
      loss = model.total_loss

      eval_metric_ops = _convert_keras_metrics_to_estimator(
          model, metric_names_map)
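      # For example, passing metric_names_map={'mae': 'mean_absolute_error'}
      # would report that metric under the custom name (illustrative mapping,
      # assuming 'mae' is one of the compiled metric names).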

    # Set train_op only during train.
    if mode is ModeKeys.TRAIN:
      train_op = model.train_function.updates_op

    if (not model._is_graph_network and
        hasattr(keras_model, '_original_attributes_cache') and
        keras_model._original_attributes_cache is not None):
      # Restore the original attributes so that `model_fn` is not destructive
      # to the initial `keras_model` argument.
      (tf.compat.v2.keras.__internal__.models.
       in_place_subclassed_model_state_restoration(keras_model))

    scaffold = None
    if save_object_ckpt:
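      # Object-based checkpoints are typically requested by calling
      # `model_to_estimator(..., checkpoint_format='checkpoint')` rather than
      # using the name-based 'saver' format.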
      model._track_trackable(tf.compat.v1.train.get_global_step(),
                             'estimator_global_step')
      # Create saver that maps variable names to object-checkpoint keys.
      object_graph = tf.compat.v2.__internal__.tracking.ObjectGraphView(model)
      var_list = object_graph.frozen_saveable_objects()
      saver = tf.compat.v1.train.Saver(var_list=var_list, sharded=True)
      saver._object_restore_saver = trackable_util.frozen_saver(model)
      scaffold = tf.compat.v1.train.Scaffold(saver=saver)

    final_export_outputs = {
        _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
    }
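    # `_DEFAULT_SERVING_KEY` is the default serving signature key (typically
    # 'serving_default'), so all predictions are always exported under it.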
    if export_outputs is not None:
      different_keys = set(export_outputs.keys()) - set(model.output_names)
      if different_keys:
        raise FormattedKeyError(
            'The dictionary passed as `export_outputs` contains keys that are '
            'not output keys defined in the keras model.'
            '\n\tModel output keys: {output_keys}'
            '\n\t`export_outputs` keys: {export_keys}'
            '\n\tUnmatched keys: {different_keys}'.format(
                output_keys=set(model.output_names),
                export_keys=set(export_outputs.keys()),
                different_keys=different_keys))
      for key, export_output_cls in export_outputs.items():
        final_export_outputs[key] = export_output_cls(predictions[key])

    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=final_export_outputs,
        scaffold=scaffold)

  return model_fn
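
A minimal usage sketch (not part of the source): `_create_keras_model_fn` is normally reached through the public wrapper `tf.keras.estimator.model_to_estimator`, which (in recent TF versions) also forwards `custom_objects`, `metric_names_map` and `export_outputs`, and requests an object-based checkpoint when `checkpoint_format='checkpoint'`. The model below and the output/export key 'score' are illustrative assumptions, not taken from this file.

import tensorflow as tf

# Build and compile a small single-output Keras model; naming the output layer
# 'score' makes 'score' the model output name, which can then be used as an
# `export_outputs` key.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1, name='score'),
])
keras_model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# Convert to an Estimator; internally this builds the model_fn above and hands
# it to `tf.estimator.Estimator`.
estimator = tf.keras.estimator.model_to_estimator(
    keras_model=keras_model,
    checkpoint_format='checkpoint',
    export_outputs={'score': tf.estimator.export.RegressionOutput})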