in tensorflow_recommenders/layers/loss.py [0:0]
def call(self, logits: tf.Tensor,
         labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """Keeps the positive candidate plus the hardest negatives per query.

  For each row (query), selects the columns holding the true candidate and
  the `num_hard_negatives` highest-scoring negative candidates, and returns
  logits and labels restricted to those columns.

  Args:
    logits: [batch_size, number_of_candidates] tensor of logits.
    labels: [batch_size, number_of_candidates] one-hot tensor of labels.

  Returns:
    logits: [batch_size, num_hard_negatives + 1] tensor of logits.
    labels: [batch_size, num_hard_negatives + 1] one-hot tensor of labels.
  """
  # Columns to keep per query: k hard negatives plus the one true
  # candidate, but never more than the number of candidates available.
  num_kept = tf.minimum(self._num_hard_negatives + 1, tf.shape(logits)[1])
  # Boost the true logit (where the one-hot label is 1) by a huge constant
  # so that a single top_k pass is guaranteed to pick the positive column
  # together with the k highest-scoring negative columns. This sidesteps
  # the inefficient tf.boolean_mask() that explicitly excluding the true
  # logit would otherwise require.
  boosted_logits = logits + labels * MAX_FLOAT
  _, kept_columns = tf.nn.top_k(boosted_logits, k=num_kept, sorted=False)
  # Gather the selected columns from the *original* (un-boosted) logits
  # and from the labels.
  return (_gather_elements_along_row(logits, kept_columns),
          _gather_elements_along_row(labels, kept_columns))