def sample()

in tensorflow_ranking/python/losses_impl.py [0:0]


  def sample(self, labels, logits, weights=None):
    """Samples scores from Concrete(logits).

    If the sampler was constructed with `ragged=True` this method expects
    `labels`, `logits` and item-wise `weights` to be a `RaggedTensor`.

    Args:
      labels: A `Tensor` or `RaggedTensor` with shape [batch_size, list_size]
        same as `logits`, representing graded relevance. Or in the diversity
        tasks, a `Tensor` (or `RaggedTensor`) with shape [batch_size, list_size,
        subtopic_size]. Each value represents relevance to a subtopic, 1 for
        relevant subtopic, 0 for irrelevant, and -1 for paddings. When the
        actual subtopic number of a query is smaller than the `subtopic_size`,
        `labels` will be padded to `subtopic_size` with -1.
      logits: A `Tensor` or `RaggedTensor` with shape [batch_size, list_size].
        Each value is the ranking score of the corresponding item.
      weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
        weights, or a `Tensor` or `RaggedTensor` with shape [batch_size,
        list_size] for item-wise weights. If None, the weight of a list in the
        mini-batch is set to the sum of the labels of the items in that list.
        This method itself passes `None` through unchanged. A scalar weight
        applies identically to every sample, so it is also returned unchanged.

    Returns:
      A tuple of expanded labels, logits, and weights where the first dimension
      is now batch_size * sample_size. Logit Tensors are sampled from
      Concrete(logits) while labels and weights are simply tiled so the
      resulting Tensor has the updated dimensions. Row `b * sample_size + s`
      holds sample `s` of list `b`.
    """
    with tf.compat.v1.name_scope(self._name, 'gumbel_softmax_sample',
                                 (labels, logits, weights)):
      # Convert ragged tensors to dense and construct a mask.
      if self._ragged:
        is_weights_ragged = isinstance(weights, tf.RaggedTensor)
        labels, logits, weights, mask = utils.ragged_to_dense(
            labels, logits, weights)

      batch_size = tf.shape(input=labels)[0]
      list_size = tf.shape(input=labels)[1]

      # Expand labels: [batch_size, list_size, ...] ->
      # [batch_size, sample_size, list_size, ...] -> leading two dims merged
      # into batch_size * sample_size.
      expanded_labels = tf.expand_dims(labels, 1)
      expanded_labels = tf.repeat(expanded_labels, [self._sample_size], axis=1)
      expanded_labels = utils.reshape_first_ndims(
          expanded_labels, 2, [batch_size * self._sample_size])

      # Sample logits from Concrete(logits): each of the sample_size copies of
      # every list receives its own Gumbel noise draw.
      sampled_logits = tf.expand_dims(logits, 1)
      sampled_logits = tf.tile(sampled_logits, [1, self._sample_size, 1])
      sampled_logits += _sample_gumbel(
          [batch_size, self._sample_size, list_size], seed=self._seed)
      sampled_logits = tf.reshape(sampled_logits,
                                  [batch_size * self._sample_size, list_size])

      # For diversity labels with a trailing subtopic axis, an item counts as
      # valid if any of its subtopic labels is valid.
      is_label_valid = utils.is_label_valid(expanded_labels)
      if is_label_valid.shape.rank > 2:
        is_label_valid = tf.reduce_any(is_label_valid, axis=-1)
      # Invalid items get a logit of log(1e-20) so they receive negligible
      # softmax mass; valid items are divided by the temperature.
      sampled_logits = tf.compat.v1.where(
          is_label_valid, sampled_logits / self._temperature,
          tf.math.log(1e-20) * tf.ones_like(sampled_logits))
      # The 1e-20 offset keeps the log finite if a softmax entry underflows.
      sampled_logits = tf.math.log(tf.nn.softmax(sampled_logits) + 1e-20)

      expanded_weights = weights
      # A scalar (rank-0) weight is identical for every sample, so there is
      # nothing to tile; `tf.expand_dims(..., 1)` below is invalid for rank 0.
      # Unknown static rank (None) still takes the dynamic path below.
      if (expanded_weights is not None and
          tf.convert_to_tensor(expanded_weights).shape.rank != 0):
        # Bind the input separately so the branch closures never observe the
        # rebinding of `expanded_weights` to the tf.cond result.
        base_weights = expanded_weights
        # Rank-1 [batch_size] -> [batch_size, 1, 1];
        # rank-2 [batch_size, k] -> [batch_size, 1, k].
        true_fn = lambda: tf.expand_dims(tf.expand_dims(base_weights, 1), 1)
        false_fn = lambda: tf.expand_dims(base_weights, 1)
        expanded_weights = tf.cond(
            pred=tf.math.equal(tf.rank(base_weights), 1),
            true_fn=true_fn,
            false_fn=false_fn)
        expanded_weights = tf.tile(expanded_weights, [1, self._sample_size, 1])
        expanded_weights = tf.reshape(expanded_weights,
                                      [batch_size * self._sample_size, -1])

      # Convert dense tensors back to ragged.
      if self._ragged:
        # Construct expanded mask for the number of samples.
        expanded_mask = tf.expand_dims(mask, 1)
        expanded_mask = tf.repeat(expanded_mask, [self._sample_size], axis=1)
        expanded_mask = tf.reshape(
            expanded_mask, [batch_size * self._sample_size, list_size])
        # Convert labels and sampled logits to ragged tensors.
        expanded_labels = tf.ragged.boolean_mask(expanded_labels, expanded_mask)
        sampled_logits = tf.ragged.boolean_mask(sampled_logits, expanded_mask)
        # If ragged weights were provided, convert dense weights back to ragged.
        if is_weights_ragged:
          expanded_weights = tf.ragged.boolean_mask(
              expanded_weights, expanded_mask)

      return expanded_labels, sampled_logits, expanded_weights