def listwise_scoring()

in tensorflow_ranking/python/keras/network.py


import six
import tensorflow as tf

from tensorflow_ranking.python import utils


def listwise_scoring(scorer,
                     context_features,
                     example_features,
                     training=None,
                     mask=None):
  """Listwise scoring op for context and example features.

  Args:
    scorer: A callable (e.g., a Keras layer instance or a function) for scoring
      with the following signature:
      * Args:
        `context_features`: (dict) A dict of Tensors with the shape [batch_size,
          ...].
        `example_features`: (dict) A dict of Tensors with the shape [batch_size,
          ...].
        `training`: (bool) whether in training or inference mode.
      * Returns: The computed logits, a Tensor of shape [batch_size,
        output_size].
    context_features: (dict) context feature names to dense 2D tensors of shape
      [batch_size, ...].
    example_features: (dict) example feature names to dense 3D tensors of shape
      [batch_size, list_size, ...].
    training: (bool) whether in training or inference mode.
    mask: (tf.Tensor) A mask of shape [batch_size, list_size], which is True
      for a valid example and False for an invalid one.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size, output_size].

  Raises:
    ValueError: If example features is None or an empty dict.
  """
  # Raise error if example features is None or empty dict.
  if not example_features:
    raise ValueError('Need a valid example feature.')

  tensor = next(six.itervalues(example_features))
  batch_size = tf.shape(tensor)[0]
  list_size = tf.shape(tensor)[1]
  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)
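  # nd_indices points every slot (valid or not) at a valid example within the
  # same list, so the gather_nd below never reads padding; nd_mask records
  # which slots were originally valid.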

  # Expand context features to be of [batch_size, list_size, ...].
  large_batch_context_features = {}
  for name, tensor in six.iteritems(context_features):
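    # Add a list axis and tile the context tensor list_size times by gathering
    # index 0 repeatedly, then flatten to [batch_size * list_size, ...].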
    x = tf.expand_dims(input=tensor, axis=1)
    x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
    large_batch_context_features[name] = utils.reshape_first_ndims(
        x, 2, [batch_size * list_size])

  large_batch_example_features = {}
  for name, tensor in six.iteritems(example_features):
    # Replace invalid example features with valid ones.
    padded_tensor = tf.gather_nd(tensor, nd_indices)
    large_batch_example_features[name] = utils.reshape_first_ndims(
        padded_tensor, 2, [batch_size * list_size])

  # Get scores for large batch.
  scores = scorer(
      large_batch_context_features,
      large_batch_example_features,
      training=training)
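  # Fold the flat [batch_size * list_size, output_size] logits back into
  # [batch_size, list_size, output_size].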
  scores = tf.reshape(scores, shape=[batch_size, list_size, -1])

  # Apply nd_mask to zero out invalid entries.
  # Expand dimension and use broadcasting for filtering.
  expanded_nd_mask = tf.expand_dims(nd_mask, axis=2)
  scores = tf.where(expanded_nd_mask, scores, tf.zeros_like(scores))

  return scores
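
For reference, a minimal usage sketch. It assumes TensorFlow 2.x in eager mode and that `listwise_scoring` above is available; the `scorer` function and the dummy feature names (`query_emb`, `doc_emb`) are illustrative, not part of the library.

# A toy scorer matching the documented callable signature: it receives the
# flattened [batch_size * list_size, ...] feature dicts and returns one logit
# per row.
dense = tf.keras.layers.Dense(1)


def scorer(context_features, example_features, training=None):
  del training  # Unused in this toy scorer.
  inputs = tf.concat(
      list(context_features.values()) + list(example_features.values()),
      axis=-1)
  return dense(inputs)


# Dummy padded batch: 2 queries, 3 candidate examples each, 4-dim embeddings.
context_features = {'query_emb': tf.random.uniform([2, 4])}
example_features = {'doc_emb': tf.random.uniform([2, 3, 4])}
mask = tf.constant([[True, True, False], [True, False, False]])

scores = listwise_scoring(
    scorer, context_features, example_features, training=False, mask=mask)
print(scores.shape)  # (2, 3, 1); entries at masked-out slots are zero.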