in tensorflow_ranking/python/head.py [0:0]
def create_estimator_spec(self,
                          features,
                          mode,
                          logits,
                          labels=None,
                          regularization_losses=None):
    """See `_AbstractRankingHead`.

    Builds a `tf.estimator.EstimatorSpec` for the given `mode`:

    * PREDICT: `logits` are returned as predictions and exported under the
      default, regression and predict serving signatures. No loss is computed.
    * EVAL: returns the (optionally regularized) loss together with eval
      metric ops built from `self._eval_metric_fns` and
      `self._labels_and_logits_metrics`.
    * TRAIN: returns the (optionally regularized) loss and the train op built
      by `_get_train_op` from `self._train_op_fn` / `self._optimizer`.

    Args:
      features: Passed through to `create_loss` and to every eval metric fn.
      mode: One of `tf.estimator.ModeKeys.{PREDICT, EVAL, TRAIN}`.
      logits: Model scores; converted with `tf.convert_to_tensor`.
      labels: Passed to `create_loss` and the eval metric fns. Not used in
        PREDICT mode.
      regularization_losses: Optional sequence of loss tensors. When
        non-empty they are summed with `tf.add_n` and added to the training
        loss; `None` or empty leaves the training loss unchanged.

    Returns:
      A `tf.estimator.EstimatorSpec` for `mode`.

    Raises:
      ValueError: If `mode` is not PREDICT, EVAL or TRAIN.
    """
    # Validate `mode` before doing any work. This way an unsupported mode fails
    # with this ValueError instead of with whatever `create_loss` might raise
    # for it, and no loss/regularization tensors are built needlessly.
    supported_modes = (tf.estimator.ModeKeys.PREDICT,
                       tf.estimator.ModeKeys.EVAL,
                       tf.estimator.ModeKeys.TRAIN)
    if mode not in supported_modes:
        raise ValueError('mode={} unrecognized'.format(mode))

    logits = tf.convert_to_tensor(value=logits)
    with tf.compat.v1.name_scope(self._name, 'head'):
        # Predict: return before the loss is built, so labels are not needed.
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=logits,
                export_outputs={
                    _DEFAULT_SERVING_KEY:
                        tf.estimator.export.RegressionOutput(logits),
                    _REGRESS_SERVING_KEY:
                        tf.estimator.export.RegressionOutput(logits),
                    _PREDICT_SERVING_KEY:
                        tf.estimator.export.PredictOutput(logits),
                })

        training_loss = self.create_loss(
            features=features, mode=mode, logits=logits, labels=labels)
        # A None or empty `regularization_losses` is falsy and leaves the
        # training loss as-is.
        if regularization_losses:
            regularization_loss = tf.add_n(regularization_losses)
            regularized_training_loss = tf.add(training_loss,
                                               regularization_loss)
        else:
            regularized_training_loss = training_loss

        # Eval.
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = {
                name: metric_fn(
                    labels=labels, predictions=logits, features=features)
                for name, metric_fn in six.iteritems(self._eval_metric_fns)
            }
            eval_metric_ops.update(
                self._labels_and_logits_metrics(labels, logits))
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=logits,
                loss=regularized_training_loss,
                eval_metric_ops=eval_metric_ops)

        # Train: the up-front check guarantees `mode` is TRAIN here.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=regularized_training_loss,
            train_op=_get_train_op(regularized_training_loss,
                                   self._train_op_fn, self._optimizer),
            predictions=logits)