in tensorflow_ranking/python/head.py [0:0]
def create_estimator_spec(self,
                          features,
                          mode,
                          logits,
                          labels=None,
                          regularization_losses=None):
  """See `_AbstractRankingHead`.

  Builds one estimator spec per sub-head in `self._heads` and combines them
  into a single `tf.estimator.EstimatorSpec` for the requested `mode`.

  Args:
    features: Forwarded unchanged to every sub-head and to the loss merge.
    mode: One of the `tf.estimator.ModeKeys` values.
    logits: Indexed by each sub-head's `name`; also returned as-is as the
      `predictions` of the resulting spec.
    labels: Indexed by each sub-head's `name` when truthy; otherwise every
      sub-head receives `None` as its labels.
    regularization_losses: Forwarded to `self._merge_loss` only; it is not
      passed to the individual sub-heads.

  Returns:
    A `tf.estimator.EstimatorSpec` for PREDICT, EVAL or TRAIN.

  Raises:
    ValueError: If `mode` is not PREDICT, EVAL or TRAIN.
  """
  with tf.compat.v1.name_scope(self.name, 'multi_head'):
    self._check_logits_and_labels(logits, labels)

    # One spec per sub-head, each fed its own slice of logits and labels.
    per_head_specs = [
        sub_head.create_estimator_spec(
            features=features,
            mode=mode,
            logits=logits[sub_head.name],
            labels=labels[sub_head.name] if labels else None)
        for sub_head in self._heads
    ]

    # Prediction needs neither a loss nor metrics, so it exits first.
    if mode == tf.estimator.ModeKeys.PREDICT:
      merged_exports = self._merge_predict_export_outputs(per_head_specs)
      return tf.estimator.EstimatorSpec(
          mode=mode, predictions=logits, export_outputs=merged_exports)

    # The merged loss and metrics are common to evaluation and training.
    merged_loss = self._merge_loss(labels, logits, features, mode,
                                   regularization_losses)
    merged_metrics = self._merge_metrics(per_head_specs)
    shared_spec_args = dict(
        mode=mode,
        loss=merged_loss,
        predictions=logits,
        eval_metric_ops=merged_metrics)

    if mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.EstimatorSpec(**shared_spec_args)

    if mode == tf.estimator.ModeKeys.TRAIN:
      # Training is the only mode that additionally needs a train op.
      train_op = _get_train_op(merged_loss, self._train_op_fn,
                               self._optimizer)
      return tf.estimator.EstimatorSpec(train_op=train_op, **shared_spec_args)

    raise ValueError('mode={} unrecognized'.format(mode))