def create_estimator_spec()

in tensorflow_ranking/python/head.py [0:0]


  def create_estimator_spec(self,
                            features,
                            mode,
                            logits,
                            labels=None,
                            regularization_losses=None):
    """Returns a `tf.estimator.EstimatorSpec` for the given `mode`.

    See `_AbstractRankingHead`.

    Args:
      features: Passed through unchanged to `create_loss` and to every eval
        metric fn.
      mode: A `tf.estimator.ModeKeys` value: `PREDICT`, `EVAL` or `TRAIN`.
      logits: Model scores; converted with `tf.convert_to_tensor` and used as
        the predictions in every mode.
      labels: Ground-truth labels. Unused in `PREDICT` mode; passed to
        `create_loss` and the eval metrics otherwise (presumably required in
        those modes -- depends on `create_loss`, not visible here).
      regularization_losses: Optional list of tensors. When non-empty they are
        summed with `tf.add_n` and added to the loss reported in `EVAL` and
        `TRAIN` modes.

    Returns:
      A `tf.estimator.EstimatorSpec` with:
        - `PREDICT`: predictions and serving `export_outputs`;
        - `EVAL`: predictions, loss and `eval_metric_ops`;
        - `TRAIN`: predictions, loss and `train_op`.

    Raises:
      ValueError: If `mode` is not one of the recognized `ModeKeys`.
    """
    logits = tf.convert_to_tensor(value=logits)
    with tf.compat.v1.name_scope(self._name, 'head'):
      # Predict: no labels or loss are needed, only serving outputs.
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=logits,
            export_outputs={
                _DEFAULT_SERVING_KEY:
                    tf.estimator.export.RegressionOutput(logits),
                _REGRESS_SERVING_KEY:
                    tf.estimator.export.RegressionOutput(logits),
                _PREDICT_SERVING_KEY:
                    tf.estimator.export.PredictOutput(logits),
            })

      # Reject unknown modes before building any loss ops. Callers then get
      # the intended ValueError instead of an unrelated failure from
      # `create_loss`, and no stray ops are added to the graph.
      if mode not in (tf.estimator.ModeKeys.EVAL,
                      tf.estimator.ModeKeys.TRAIN):
        raise ValueError('mode={} unrecognized'.format(mode))

      training_loss = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      if regularization_losses:
        regularization_loss = tf.add_n(regularization_losses)
        regularized_training_loss = tf.add(training_loss, regularization_loss)
      else:
        regularized_training_loss = training_loss

      # Eval.
      if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = {
            name: metric_fn(
                labels=labels, predictions=logits, features=features)
            for name, metric_fn in six.iteritems(self._eval_metric_fns)
        }
        # The built-in label/logit metrics are merged last, so they override
        # any user-supplied metric with the same name.
        eval_metric_ops.update(self._labels_and_logits_metrics(labels, logits))
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=logits,
            loss=regularized_training_loss,
            eval_metric_ops=eval_metric_ops)

      # Train: the only mode remaining after the validation above.
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=regularized_training_loss,
          train_op=_get_train_op(regularized_training_loss, self._train_op_fn,
                                 self._optimizer),
          predictions=logits)