in tensorflow_ranking/python/losses.py [0:0]
def make_loss_fn(loss_keys,
                 loss_weights=None,
                 weights_feature_name=None,
                 lambda_weight=None,
                 reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                 name=None,
                 params=None,
                 gumbel_params=None):
  """Makes a loss function using a single loss or multiple losses.

  Args:
    loss_keys: A string, or a list or tuple of strings, representing loss keys
      defined in `RankingLossKey`. Listed loss functions will be combined in a
      weighted manner, with weights specified by `loss_weights`. If
      `loss_weights` is None, default weight of 1 will be used.
    loss_weights: List of weights, same length as `loss_keys`. Used when merging
      losses to calculate the weighted sum of losses. If `None` or empty, all
      losses are weighted equally with weight being 1. Any iterable (tuple,
      numpy array, generator) is accepted and copied into a list here.
    weights_feature_name: A string specifying the name of the weights feature in
      `features` dict.
    lambda_weight: A `_LambdaWeight` object created by factory methods like
      `create_ndcg_lambda_weight()`. Only forwarded to the pairwise, circle,
      softmax, unique-softmax and ListMLE losses.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    params: A string-keyed dictionary that contains any other loss-specific
      arguments. Its entries are merged into (and may override) the keyword
      arguments passed to every selected loss function.
    gumbel_params: A string-keyed dictionary that contains other
      `gumbel_softmax_sample` arguments. Passed as keyword arguments to
      `losses_impl.GumbelSampler`.

  Returns:
    A function _loss_fn(). See `_loss_fn()` for its signature.

  Raises:
    ValueError: If `reduction` is invalid.
    ValueError: If `loss_keys` is None or empty.
    ValueError: If `loss_keys` and `loss_weights` have different sizes.
  """
  if (reduction not in tf.compat.v1.losses.Reduction.all() or
      reduction == tf.compat.v1.losses.Reduction.NONE):
    raise ValueError('Invalid reduction: {}'.format(reduction))
  if not loss_keys:
    raise ValueError('loss_keys cannot be None or empty.')
  # Copy list/tuple input so later mutation of the caller's sequence cannot
  # change the returned loss function. A tuple is treated as a sequence of keys
  # rather than being wrapped as one (invalid) key.
  if isinstance(loss_keys, (list, tuple)):
    loss_keys = list(loss_keys)
  else:
    loss_keys = [loss_keys]
  if loss_weights is not None:
    # Materialize any iterable so that `len()` and the truthiness checks below
    # are well defined (e.g. for numpy arrays or generators).
    loss_weights = list(loss_weights)
  if loss_weights:
    if len(loss_keys) != len(loss_weights):
      raise ValueError('loss_keys and loss_weights must have the same size.')
  params = params or {}
  gumbel_params = gumbel_params or {}
  gumbel_sampler = losses_impl.GumbelSampler(**gumbel_params)

  # Loss keys that are evaluated on the outputs of `gumbel_sampler.sample`,
  # mapped to the underlying loss function.
  gumbel_key_to_fn = {
      RankingLossKey.GUMBEL_APPROX_NDCG_LOSS:
          _approx_ndcg_loss,
      RankingLossKey.GUMBEL_NEURAL_SORT_CROSS_ENTROPY_LOSS:
          _neural_sort_cross_entropy_loss,
      RankingLossKey.GUMBEL_NEURAL_SORT_NDCG_LOSS:
          _neural_sort_ndcg_loss,
  }

  def _loss_fn(labels, logits, features):
    """Computes a single loss or weighted combination of losses.

    Args:
      labels: A `Tensor` of the same shape as `logits` representing relevance.
      logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
        ranking score of the corresponding item.
      features: Dict of Tensors of shape [batch_size, list_size, ...] for
        per-example features and shape [batch_size, ...] for non-example context
        features.

    Returns:
      An op for a single loss or weighted combination of multiple losses.

    Raises:
      ValueError: If `loss_keys` is invalid.
    """
    weights = None
    if weights_feature_name:
      weights = tf.convert_to_tensor(value=features[weights_feature_name])
      # Convert weights to a 2-D Tensor.
      weights = utils.reshape_to_2d(weights)

    loss_kwargs = {
        'labels': labels,
        'logits': logits,
        'weights': weights,
        'reduction': reduction,
        'name': name,
    }
    loss_kwargs.update(params)
    loss_kwargs_with_lambda_weight = loss_kwargs.copy()
    loss_kwargs_with_lambda_weight['lambda_weight'] = lambda_weight

    key_to_fn = {
        RankingLossKey.PAIRWISE_HINGE_LOSS:
            (_pairwise_hinge_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.PAIRWISE_LOGISTIC_LOSS:
            (_pairwise_logistic_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS:
            (_pairwise_soft_zero_one_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.CIRCLE_LOSS:
            (_circle_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.SOFTMAX_LOSS:
            (_softmax_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.UNIQUE_SOFTMAX_LOSS:
            (_unique_softmax_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS:
            (_sigmoid_cross_entropy_loss, loss_kwargs),
        RankingLossKey.MEAN_SQUARED_LOSS: (_mean_squared_loss, loss_kwargs),
        RankingLossKey.LIST_MLE_LOSS:
            (_list_mle_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.APPROX_NDCG_LOSS: (_approx_ndcg_loss, loss_kwargs),
        RankingLossKey.APPROX_MRR_LOSS: (_approx_mrr_loss, loss_kwargs),
        RankingLossKey.NEURAL_SORT_CROSS_ENTROPY_LOSS:
            (_neural_sort_cross_entropy_loss, loss_kwargs),
        RankingLossKey.NEURAL_SORT_NDCG_LOSS:
            (_neural_sort_ndcg_loss, loss_kwargs),
    }

    # Only run the Gumbel sampler when a Gumbel-based loss was requested, so
    # configurations without one do not pay for the sampling work.
    if any(loss_key in gumbel_key_to_fn for loss_key in loss_keys):
      gbl_labels, gbl_logits, gbl_weights = gumbel_sampler.sample(
          labels, logits, weights=weights)
      gbl_loss_kwargs = {
          'labels': gbl_labels,
          'logits': gbl_logits,
          'weights': gbl_weights,
          'reduction': reduction,
          'name': name,
      }
      gbl_loss_kwargs.update(params)
      for gumbel_key, gumbel_fn in gumbel_key_to_fn.items():
        key_to_fn[gumbel_key] = (gumbel_fn, gbl_loss_kwargs)

    # Obtain the list of loss ops.
    loss_ops = []
    for loss_key in loss_keys:
      if loss_key not in key_to_fn:
        raise ValueError('Invalid loss_key: {}.'.format(loss_key))
      loss_fn, kwargs = key_to_fn[loss_key]
      loss_ops.append(loss_fn(**kwargs))

    # Compute weighted combination of losses; without weights every loss
    # contributes with weight 1.
    if loss_weights:
      weighted_losses = [
          tf.multiply(loss_op, loss_weight)
          for loss_op, loss_weight in zip(loss_ops, loss_weights)
      ]
    else:
      weighted_losses = loss_ops
    return tf.add_n(weighted_losses)

  return _loss_fn