def make_loss_fn()

in tensorflow_ranking/python/losses.py


def make_loss_fn(loss_keys,
                 loss_weights=None,
                 weights_feature_name=None,
                 lambda_weight=None,
                 reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                 name=None,
                 params=None,
                 gumbel_params=None):
  """Makes a loss function using a single loss or multiple losses.

  Args:
    loss_keys: A string or list of strings representing loss keys defined in
      `RankingLossKey`. The listed loss functions are combined in a weighted
      manner, with weights specified by `loss_weights`. If `loss_weights` is
      None, a default weight of 1 is used for each loss.
    loss_weights: A list of weights with the same length as `loss_keys`, used
      to compute the weighted sum of losses. If `None`, all losses are
      weighted equally with a weight of 1.
    weights_feature_name: A string specifying the name of the weights feature
      in the `features` dict.
    lambda_weight: A `_LambdaWeight` object created by factory methods like
      `create_ndcg_lambda_weight()`.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce the training loss over the batch.
    name: A string used as the name for this loss.
    params: A string-keyed dictionary that contains any other loss-specific
      arguments.
    gumbel_params: A string-keyed dictionary that contains any other
      `gumbel_softmax_sample` arguments.

  Returns:
    A function _loss_fn(). See `_loss_fn()` for its signature.

  Raises:
    ValueError: If `reduction` is invalid.
    ValueError: If `loss_keys` is None or empty.
    ValueError: If `loss_keys` and `loss_weights` have different sizes.
  """
  if (reduction not in tf.compat.v1.losses.Reduction.all() or
      reduction == tf.compat.v1.losses.Reduction.NONE):
    raise ValueError('Invalid reduction: {}'.format(reduction))

  if not loss_keys:
    raise ValueError('loss_keys cannot be None or empty.')

  if not isinstance(loss_keys, list):
    loss_keys = [loss_keys]

  if loss_weights:
    if len(loss_keys) != len(loss_weights):
      raise ValueError('loss_keys and loss_weights must have the same size.')

  params = params or {}
  gumbel_params = gumbel_params or {}
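  # The sampler draws Gumbel-perturbed (labels, logits, weights) samples,
  # which are consumed only by the GUMBEL_* loss keys handled below.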
  gumbel_sampler = losses_impl.GumbelSampler(**gumbel_params)

  def _loss_fn(labels, logits, features):
    """Computes a single loss or weighted combination of losses.

    Args:
      labels: A `Tensor` of the same shape as `logits` representing relevance.
      logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
        ranking score of the corresponding item.
      features: A dict of `Tensor`s with shape [batch_size, list_size, ...]
        for per-example features and shape [batch_size, ...] for context
        (non-example) features.

    Returns:
      An op for a single loss or a weighted combination of multiple losses.

    Raises:
      ValueError: If `loss_keys` is invalid.
    """
    weights = None
    if weights_feature_name:
      weights = tf.convert_to_tensor(value=features[weights_feature_name])
      # Convert weights to a 2-D Tensor.
      weights = utils.reshape_to_2d(weights)

    gbl_labels, gbl_logits, gbl_weights = gumbel_sampler.sample(
        labels, logits, weights=weights)

    loss_kwargs = {
        'labels': labels,
        'logits': logits,
        'weights': weights,
        'reduction': reduction,
        'name': name,
    }
    gbl_loss_kwargs = {
        'labels': gbl_labels,
        'logits': gbl_logits,
        'weights': gbl_weights,
        'reduction': reduction,
        'name': name,
    }
    loss_kwargs.update(params)
    gbl_loss_kwargs.update(params)

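    # Losses that support a `_LambdaWeight` receive it via a separate kwargs
    # dict; all other losses take the plain kwargs.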
    loss_kwargs_with_lambda_weight = loss_kwargs.copy()
    loss_kwargs_with_lambda_weight['lambda_weight'] = lambda_weight

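    # Map each supported loss key to its loss function and the kwargs variant
    # it takes: plain, with `lambda_weight`, or with Gumbel-sampled inputs.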
    key_to_fn = {
        RankingLossKey.PAIRWISE_HINGE_LOSS:
            (_pairwise_hinge_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.PAIRWISE_LOGISTIC_LOSS:
            (_pairwise_logistic_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS:
            (_pairwise_soft_zero_one_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.CIRCLE_LOSS:
            (_circle_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.SOFTMAX_LOSS:
            (_softmax_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.UNIQUE_SOFTMAX_LOSS:
            (_unique_softmax_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS:
            (_sigmoid_cross_entropy_loss, loss_kwargs),
        RankingLossKey.MEAN_SQUARED_LOSS: (_mean_squared_loss, loss_kwargs),
        RankingLossKey.LIST_MLE_LOSS:
            (_list_mle_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.APPROX_NDCG_LOSS: (_approx_ndcg_loss, loss_kwargs),
        RankingLossKey.APPROX_MRR_LOSS: (_approx_mrr_loss, loss_kwargs),
        RankingLossKey.GUMBEL_APPROX_NDCG_LOSS:
            (_approx_ndcg_loss, gbl_loss_kwargs),
        RankingLossKey.NEURAL_SORT_CROSS_ENTROPY_LOSS:
            (_neural_sort_cross_entropy_loss, loss_kwargs),
        RankingLossKey.GUMBEL_NEURAL_SORT_CROSS_ENTROPY_LOSS:
            (_neural_sort_cross_entropy_loss, gbl_loss_kwargs),
        RankingLossKey.NEURAL_SORT_NDCG_LOSS:
            (_neural_sort_ndcg_loss, loss_kwargs),
        RankingLossKey.GUMBEL_NEURAL_SORT_NDCG_LOSS:
            (_neural_sort_ndcg_loss, gbl_loss_kwargs),
    }

    # Obtain the list of loss ops.
    loss_ops = []
    for loss_key in loss_keys:
      if loss_key not in key_to_fn:
        raise ValueError('Invalid loss_key: {}.'.format(loss_key))
      loss_fn, kwargs = key_to_fn[loss_key]
      loss_ops.append(loss_fn(**kwargs))

    # Compute weighted combination of losses.
    if loss_weights:
      weighted_losses = []
      for loss_op, loss_weight in zip(loss_ops, loss_weights):
        weighted_losses.append(tf.multiply(loss_op, loss_weight))
    else:
      weighted_losses = loss_ops

    return tf.add_n(weighted_losses)

  return _loss_fn
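
The returned `_loss_fn(labels, logits, features)` is typically passed to a
TF-Ranking estimator head (e.g. `tfr.head.create_ranking_head`). Below is a
minimal graph-mode sketch of combining two losses; the tensor values and the
'example_weight' feature name are hypothetical:

import tensorflow as tf
import tensorflow_ranking as tfr

# Build a loss function combining a listwise and a pairwise loss, weighted
# 1.0 and 0.5 respectively.
loss_fn = tfr.losses.make_loss_fn(
    [tfr.losses.RankingLossKey.SOFTMAX_LOSS,
     tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS],
    loss_weights=[1.0, 0.5],
    weights_feature_name='example_weight')

# Inside a graph-mode model_fn: labels and logits have shape
# [batch_size, list_size], and `features` carries the per-example weight
# column named above.
with tf.Graph().as_default():
  labels = tf.constant([[1.0, 0.0, 2.0]])
  logits = tf.constant([[0.4, 1.2, 0.1]])
  features = {'example_weight': tf.constant([[1.0, 1.0, 2.0]])}
  loss = loss_fn(labels, logits, features)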