def call()

in tensorflow_recommenders/tasks/retrieval.py [0:0]


  def call(self,
           query_embeddings: tf.Tensor,
           candidate_embeddings: tf.Tensor,
           sample_weight: Optional[tf.Tensor] = None,
           candidate_sampling_probability: Optional[tf.Tensor] = None,
           candidate_ids: Optional[tf.Tensor] = None,
           compute_metrics: bool = True,
           compute_batch_metrics: bool = True) -> tf.Tensor:
    """Computes the retrieval loss and updates the task's metrics.

    The inputs are aligned row by row: row `i` of `query_embeddings` is a query
    and row `i` of `candidate_embeddings` is the candidate that was selected
    for it. Every query is scored against every candidate with a dot product;
    the diagonal of the resulting score matrix holds the positive pairs and
    all remaining entries act as in-batch negatives.

    Before the loss is evaluated the scores are, depending on how the task is
    configured and which arguments are passed, divided by the temperature,
    corrected for candidate sampling probability, stripped of accidental hits,
    and reduced by hard negative mining, in that order.

    Args:
      query_embeddings: [num_queries, embedding_dim] tensor of query
        representations.
      candidate_embeddings: [num_candidates, embedding_dim] tensor of candidate
        representations. The first `num_queries` rows are the candidates paired
        with the queries; any rows after that are treated as extra negatives.
      sample_weight: [num_queries] tensor of sample weights, forwarded to the
        loss and to the loss metrics.
      candidate_sampling_probability: Optional tensor of candidate sampling
        probabilities. When given, it is used to correct the scores to reflect
        the sampling probability of negative candidates.
      candidate_ids: Optional tensor containing candidate ids. When given,
        factorized top-K evaluation will be id-based rather than score-based.
        Must be supplied when accidental hit removal is enabled.
      compute_metrics: Whether to update the factorized metrics. Set this to
        False during training for faster training.
      compute_batch_metrics: Whether to update the batch level metrics, which
        include both batch_metrics and loss_metrics.

    Returns:
      loss: Tensor of loss values.

    Raises:
      ValueError: If accidental hit removal is enabled and `candidate_ids` is
        not supplied.
    """

    logits = tf.linalg.matmul(
        query_embeddings, candidate_embeddings, transpose_b=True)

    # Query i is paired with candidate i, so the targets form a (possibly
    # rectangular) identity matrix with the same shape as the scores.
    logits_shape = tf.shape(logits)
    targets = tf.eye(logits_shape[0], logits_shape[1])

    temperature = self._temperature
    if temperature is not None:
      logits = logits / temperature

    if candidate_sampling_probability is not None:
      correct_for_sampling = layers.loss.SamplingProbablityCorrection()
      logits = correct_for_sampling(logits, candidate_sampling_probability)

    if self._remove_accidental_hits:
      if candidate_ids is None:
        raise ValueError(
            "When accidental hit removal is enabled, candidate ids "
            "must be supplied."
        )
      drop_accidental_hits = layers.loss.RemoveAccidentalHits()
      logits = drop_accidental_hits(targets, logits, candidate_ids)

    hard_negative_count = self._num_hard_negatives
    if hard_negative_count is not None:
      mine_hard_negatives = layers.loss.HardNegativeMining(hard_negative_count)
      logits, targets = mine_hard_negatives(logits, targets)

    loss = self._loss(
        y_true=targets, y_pred=logits, sample_weight=sample_weight)

    # Nothing to update: hand back the raw loss without wrapping it.
    if not (compute_metrics or compute_batch_metrics):
      return loss

    metric_updates = []

    if self._factorized_metrics is not None and compute_metrics:
      # `candidate_embeddings` may hold extra negatives past the first
      # `num_queries` rows; keep only the rows that are paired with a query.
      paired_candidates = candidate_embeddings[:tf.shape(query_embeddings)[0]]
      metric_updates.append(
          self._factorized_metrics.update_state(
              query_embeddings,
              paired_candidates,
              true_candidate_ids=candidate_ids))

    if compute_batch_metrics:
      metric_updates += [
          metric.update_state(targets, logits)
          for metric in self._batch_metrics
      ]
      metric_updates += [
          metric.update_state(loss, sample_weight=sample_weight)
          for metric in self._loss_metrics
      ]

    # Tie the metric updates to the returned tensor so that they run whenever
    # the loss is evaluated.
    with tf.control_dependencies(metric_updates):
      return tf.identity(loss)