def model_to_estimator()

in tensorflow_ranking/python/keras/estimator.py [0:0]


def model_to_estimator(model,
                       model_dir=None,
                       config=None,
                       custom_objects=None,
                       weights_feature_name=None,
                       warm_start_from=None,
                       serving_default="regress"):
  """Keras ranking model to Estimator.

  This function is based on the custom model_fn in TF2.0 migration guide.
  https://www.tensorflow.org/guide/migrate#custom_model_fn_with_tf_20_symbols

  The returned Estimator re-creates the ranker, loss, metrics and optimizer
  from their configs every time its `model_fn` is invoked, and exports three
  serving signatures (default, regress and predict) built from the logits.

  Args:
    model: (tf.keras.Model) A ranking keras model, which can be created using
      `tfr.keras.model.create_keras_model`. Masking is handled inside this
      function.
    model_dir: (str) Directory to save `Estimator` model graph and checkpoints.
    config: (tf.estimator.RunConfig) Specified config for distributed training
      and checkpointing.
    custom_objects: (dict) mapping names (strings) to custom objects (classes
      and functions) to be considered during deserialization.
    weights_feature_name: (str) A string specifying the name of the per-example
      (of shape [batch_size, list_size]) or per-list (of shape [batch_size, 1])
      weights feature in `features` dict.
    warm_start_from: (`tf.estimator.WarmStartSettings`) settings to warm-start
      the `tf.estimator.Estimator`.
    serving_default: (str) Specifies "regress" or "predict" as the
      serving_default signature.

  Returns:
    (tf.estimator.Estimator) A ranking estimator.

  Raises:
    ValueError: if `serving_default` is neither "regress" nor "predict". This
      is raised immediately, before the Estimator is created.
    ValueError: if `weights_feature_name` is set but is not in `features`.
      This is raised from the Estimator's `model_fn`, in any mode other than
      PREDICT.
  """
  # Fail fast on a bad argument instead of waiting until the model_fn has
  # already cloned the model and built the forward graph.
  if serving_default not in ("regress", "predict"):
    raise ValueError("serving_default should be 'regress' or 'predict', "
                     "but got {}".format(serving_default))

  def _clone_fn(obj):
    """Clones a keras object by round-tripping it through its config."""
    fn_args = function_utils.fn_args(obj.__class__.from_config)

    # Only forward `custom_objects` to `from_config` implementations that
    # declare that argument.
    if "custom_objects" in fn_args:
      return obj.__class__.from_config(
          obj.get_config(), custom_objects=custom_objects)

    return obj.__class__.from_config(obj.get_config())

  def _model_fn(features, labels, mode, params, config):
    """Defines an `Estimator` `model_fn`."""
    del config, params  # Unused.

    # In Estimator, all sub-graphs need to be constructed inside the model_fn.
    # Hence, ranker, losses, metrics and optimizer are cloned inside this
    # function.
    ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)
    training = (mode == tf.compat.v1.estimator.ModeKeys.TRAIN)

    # Weights are only consumed by the loss and metrics, so they are not
    # looked up in PREDICT mode.
    weights = None
    if weights_feature_name and mode != tf.compat.v1.estimator.ModeKeys.PREDICT:
      if weights_feature_name not in features:
        raise ValueError(
            "weights_feature '{0}' can not be found in 'features'.".format(
                weights_feature_name))
      # `pop` removes the weights entry so it is not passed to the ranker
      # below along with the other features.
      weights = utils.reshape_to_2d(features.pop(weights_feature_name))

    logits = ranker(features, training=training)

    # `serving_default` was validated when the Estimator was built, so it is
    # one of "regress" or "predict" here.
    if serving_default == "regress":
      default_export_output = tf.compat.v1.estimator.export.RegressionOutput(
          logits)
    else:
      default_export_output = tf.compat.v1.estimator.export.PredictOutput(
          logits)
    export_outputs = {
        _DEFAULT_SERVING_KEY: default_export_output,
        _REGRESS_SERVING_KEY:
            tf.compat.v1.estimator.export.RegressionOutput(logits),
        _PREDICT_SERVING_KEY:
            tf.compat.v1.estimator.export.PredictOutput(logits)
    }

    if mode == tf.compat.v1.estimator.ModeKeys.PREDICT:
      return tf.compat.v1.estimator.EstimatorSpec(mode=mode, predictions=logits,
                                                  export_outputs=export_outputs)

    loss = _clone_fn(model.loss)
    total_loss = loss(labels, logits, sample_weight=weights)

    keras_metrics = [_clone_fn(metric) for metric in model.metrics]
    # Adding default metrics here as model.metrics does not contain custom
    # metrics.
    keras_metrics += metrics.default_keras_metrics()
    # Metrics are keyed by name, so a later metric replaces an earlier one
    # that has the same name.
    eval_metric_ops = {}
    for keras_metric in keras_metrics:
      keras_metric.update_state(labels, logits, sample_weight=weights)
      eval_metric_ops[keras_metric.name] = keras_metric

    train_op = None
    if training:
      optimizer = _clone_fn(model.optimizer)
      # Make the optimizer count steps with the Estimator's global step.
      optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
      # Get both the unconditional updates (the None part)
      # and the input-conditional updates (the features part).
      # These updates are for layers like BatchNormalization, which have
      # separate update and minimize ops.
      update_ops = ranker.get_updates_for(None) + ranker.get_updates_for(
          features)
      minimize_op = optimizer.get_updates(
          loss=total_loss, params=ranker.trainable_variables)[0]
      train_op = tf.group(minimize_op, *update_ops)

    return tf.compat.v1.estimator.EstimatorSpec(
        mode=mode,
        predictions=logits,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)

  return tf.compat.v1.estimator.Estimator(
      model_fn=_model_fn,
      config=config,
      model_dir=model_dir,
      warm_start_from=warm_start_from)