def create_estimator_spec()

in tensorflow_ranking/python/head.py [0:0]


  def create_estimator_spec(self,
                            features,
                            mode,
                            logits,
                            labels=None,
                            regularization_losses=None):
    """Builds one `EstimatorSpec` that combines every sub-head.

    See `_AbstractRankingHead`. Each sub-head in `self._heads` first builds
    its own spec from `logits[head.name]` (and `labels[head.name]` when
    `labels` is truthy, otherwise `None`). The per-head specs are then merged
    according to `mode`:

    * PREDICT: only the export outputs are merged; no loss is built.
    * EVAL: the merged loss and merged eval metrics are attached.
    * TRAIN: as EVAL, plus a train op derived from the merged loss.

    Args:
      features: Passed through unchanged to every sub-head.
      mode: The estimator mode, compared against `tf.estimator.ModeKeys`.
      logits: Indexed by sub-head name to get each head's logits; also used
        as the `predictions` of the returned spec.
      labels: Optional; when truthy, indexed by sub-head name to get each
        head's labels.
      regularization_losses: Forwarded to `self._merge_loss` only; the
        individual sub-heads do not receive it.

    Returns:
      A `tf.estimator.EstimatorSpec` whose `predictions` is `logits`.

    Raises:
      ValueError: If `mode` is none of PREDICT, EVAL or TRAIN.
    """
    with tf.compat.v1.name_scope(self.name, 'multi_head'):
      self._check_logits_and_labels(logits, labels)

      # One spec per sub-head, built in the order of `self._heads`.
      per_head_specs = [
          head.create_estimator_spec(
              features=features,
              mode=mode,
              logits=logits[head.name],
              labels=(labels[head.name] if labels else None))
          for head in self._heads
      ]

      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=logits,
            export_outputs=self._merge_predict_export_outputs(per_head_specs))

      # Loss and metrics are built before the mode is validated below, so an
      # unrecognized mode still creates these ops before raising.
      merged_loss = self._merge_loss(labels, logits, features, mode,
                                     regularization_losses)
      merged_metrics = self._merge_metrics(per_head_specs)

      if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=logits,
            loss=merged_loss,
            eval_metric_ops=merged_metrics)

      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = _get_train_op(merged_loss, self._train_op_fn,
                                 self._optimizer)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=merged_loss,
            train_op=train_op,
            predictions=logits,
            eval_metric_ops=merged_metrics)

      raise ValueError('mode={} unrecognized'.format(mode))