in tensorflow_ranking/python/utils.py [0:0]
def sort_by_scores(scores,
                   features_list,
                   topn=None,
                   shuffle_ties=True,
                   seed=None,
                   mask=None):
  """Reorders every tensor in `features_list` by descending `scores`.

  Args:
    scores: A rank-2 `Tensor` of shape [batch_size, list_size] holding the
      per-example scores. It is cast to `tf.float32` before ranking.
    features_list: A list of `Tensor`s to reorder. Each one is shaped either
      [batch_size, list_size] or [batch_size, list_size, feature_dims]; the
      latter form applies to example features.
    topn: An integer cutoff on how many entries of the sorted list to keep.
      When None, or when larger than list_size, all list_size entries are
      kept.
    shuffle_ties: A boolean. If True, entries are randomly shuffled before the
      sort.
    seed: The ops-level random seed used when `shuffle_ties` is True.
    mask: An optional `Tensor` of shape [batch_size, list_size] marking which
      entries are valid for sorting. Invalid entries are pushed to the end.

  Returns:
    A list of `Tensor`s, one per entry of `features_list`, each gathered
    along axis 1 in the order induced by `scores`.
  """
  with tf.compat.v1.name_scope(name='sort_by_scores'):
    scores = tf.cast(scores, tf.float32)
    scores.get_shape().assert_has_rank(2)
    num_items = tf.shape(input=scores)[1]
    # Clamp the cutoff so top_k is never asked for more than list_size items.
    cutoff = tf.minimum(num_items if topn is None else topn, num_items)

    has_mask = mask is not None
    if has_mask:
      # Masked-out entries take the smallest score found anywhere in `scores`
      # so that the sort below places them last.
      lowest_score = tf.reduce_min(scores)
      scores = tf.where(mask, scores, lowest_score)

    # A permutation is applied ahead of the sort to break ties and/or to move
    # masked-out entries to the end.
    permutation = None
    if shuffle_ties or has_mask:
      score_shape = tf.shape(input=scores)
      permutation = _get_shuffle_indices(
          score_shape, mask, shuffle_ties=shuffle_ties, seed=seed)
      scores = tf.gather(scores, permutation, batch_dims=1, axis=1)

    ranked_idx = tf.math.top_k(scores, k=cutoff, sorted=True).indices
    if permutation is not None:
      # The ranks index into the permuted scores; map them back to positions
      # in the caller's original ordering.
      ranked_idx = tf.gather(permutation, ranked_idx, batch_dims=1, axis=1)

    sorted_features = []
    for feature in features_list:
      sorted_features.append(
          tf.gather(feature, ranked_idx, batch_dims=1, axis=1))
    return sorted_features