def make_loss_metric_fn()

in tensorflow_ranking/python/losses.py [0:0]


def make_loss_metric_fn(loss_key,
                        weights_feature_name=None,
                        lambda_weight=None,
                        name=None):
  """Factory method to create a metric based on a loss.

  Only the loss implementation selected by `loss_key` is constructed. An
  unsupported `loss_key` is not rejected by this factory; instead the returned
  metric fn raises `ValueError` when it is called.

  Args:
    loss_key: A key in `RankingLossKey`.
    weights_feature_name: A `string` specifying the name of the weights feature
      in `features` dict. When falsy, no weights are passed to the loss.
    lambda_weight: A `_LambdaWeight` object. It is only forwarded to the
      pairwise (hinge, logistic, soft zero-one), softmax, unique-softmax and
      ListMLE losses; it is ignored for every other key.
    name: A `string` used as the name for this metric.

  Returns:
    A metric fn with the following Args:
    * `labels`: A `Tensor` of the same shape as `predictions` representing
    graded relevance.
    * `predictions`: A `Tensor` with shape [batch_size, list_size]. Each value
    is the ranking score of the corresponding example.
    * `features`: A dict of `Tensor`s that contains all features.
    The metric fn returns whatever the selected loss's `eval_metric` returns,
    and raises `ValueError` if `loss_key` is not supported.
  """

  # Maps each supported key to (loss class, whether it takes `lambda_weight`).
  # The Gumbel keys deliberately reuse the plain (non-Gumbel) classes;
  # presumably so the evaluation metric is deterministic -- confirm.
  loss_specs = {
      RankingLossKey.PAIRWISE_HINGE_LOSS:
          (losses_impl.PairwiseHingeLoss, True),
      RankingLossKey.PAIRWISE_LOGISTIC_LOSS:
          (losses_impl.PairwiseLogisticLoss, True),
      RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS:
          (losses_impl.PairwiseSoftZeroOneLoss, True),
      RankingLossKey.CIRCLE_LOSS:
          (losses_impl.CircleLoss, False),
      RankingLossKey.SOFTMAX_LOSS:
          (losses_impl.SoftmaxLoss, True),
      RankingLossKey.UNIQUE_SOFTMAX_LOSS:
          (losses_impl.UniqueSoftmaxLoss, True),
      RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS:
          (losses_impl.SigmoidCrossEntropyLoss, False),
      RankingLossKey.MEAN_SQUARED_LOSS:
          (losses_impl.MeanSquaredLoss, False),
      RankingLossKey.LIST_MLE_LOSS:
          (losses_impl.ListMLELoss, True),
      RankingLossKey.APPROX_NDCG_LOSS:
          (losses_impl.ApproxNDCGLoss, False),
      RankingLossKey.APPROX_MRR_LOSS:
          (losses_impl.ApproxMRRLoss, False),
      RankingLossKey.GUMBEL_APPROX_NDCG_LOSS:
          (losses_impl.ApproxNDCGLoss, False),
      RankingLossKey.NEURAL_SORT_CROSS_ENTROPY_LOSS:
          (losses_impl.NeuralSortCrossEntropyLoss, False),
      RankingLossKey.GUMBEL_NEURAL_SORT_CROSS_ENTROPY_LOSS:
          (losses_impl.NeuralSortCrossEntropyLoss, False),
      RankingLossKey.NEURAL_SORT_NDCG_LOSS:
          (losses_impl.NeuralSortNDCGLoss, False),
      RankingLossKey.GUMBEL_NEURAL_SORT_NDCG_LOSS:
          (losses_impl.NeuralSortNDCGLoss, False),
  }

  # Build only the requested loss. An unsupported key leaves `loss` as None;
  # the error is reported when the metric fn is called, as before.
  loss = None
  spec = loss_specs.get(loss_key)
  if spec is not None:
    loss_cls, accepts_lambda_weight = spec
    if accepts_lambda_weight:
      loss = loss_cls(name, lambda_weight=lambda_weight)
    else:
      loss = loss_cls(name)

  def _get_weights(features):
    """Get weights tensor from features and reshape it to 2-D if necessary."""
    if not weights_feature_name:
      return None
    weights = tf.convert_to_tensor(value=features[weights_feature_name])
    # Convert weights to a 2-D Tensor.
    return utils.reshape_to_2d(weights)

  def metric_fn(labels, predictions, features):
    """Evaluates the selected loss as a metric."""
    # Validate the key before touching `features`, so an unsupported key is
    # reported as this ValueError rather than as a missing-feature error.
    if loss is None:
      raise ValueError('loss_key {} not supported.'.format(loss_key))
    weights = _get_weights(features)
    return loss.eval_metric(labels, predictions, weights)

  return metric_fn