in tensorflow_ranking/python/keras/network.py [0:0]
def listwise_scoring(scorer,
                     context_features,
                     example_features,
                     training=None,
                     mask=None):
  """Listwise scoring op for context and example features.

  Each example in a list is scored independently by `scorer`. First, context
  features are tiled across the list dimension. Then both feature dicts are
  flattened into one large batch of `batch_size * list_size` rows and passed
  to `scorer`. The resulting scores are reshaped back to the list layout, and
  entries that the padding mask marks as invalid are set to zero.

  Args:
    scorer: A callable (e.g., a Keras layer instance or a function) used for
      scoring, with the following signature:
      * Args:
        `context_features`: (dict) A dict of Tensors with the shape
          [batch_size * list_size, ...].
        `example_features`: (dict) A dict of Tensors with the shape
          [batch_size * list_size, ...].
        `training`: (bool) Whether in training or inference mode.
      * Returns: The computed logits, a Tensor of shape
        [batch_size * list_size, output_size].
    context_features: (dict) Context feature names to dense 2D tensors of shape
      [batch_size, ...]. May be None or empty when there are no context
      features.
    example_features: (dict) Example feature names to dense 3D tensors of shape
      [batch_size, list_size, ...].
    training: (bool) Whether in training or inference mode.
    mask: (tf.Tensor) A tensor of shape [batch_size, list_size] that is True
      for valid examples and False for invalid ones. If None, every example is
      treated as valid.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size, output_size].

  Raises:
    ValueError: If `example_features` is None or an empty dict.
  """
  # Raise error if example features is None or empty dict.
  if not example_features:
    raise ValueError('Need a valid example feature.')
  # A missing context dict means "no context features"; without this,
  # six.iteritems(None) below would raise.
  context_features = context_features or {}

  # Example features are documented as [batch_size, list_size, ...], so any one
  # of them supplies the dynamic batch and list sizes.
  first_example_tensor = next(six.itervalues(example_features))
  batch_size = tf.shape(first_example_tensor)[0]
  list_size = tf.shape(first_example_tensor)[1]
  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)
  # Number of rows in the flattened "large batch" handed to the scorer.
  flat_size = batch_size * list_size

  # Expand context features to be of [batch_size, list_size, ...]: add a list
  # axis, gather its only entry list_size times, then merge the first two dims.
  large_batch_context_features = {}
  for name, context_tensor in six.iteritems(context_features):
    x = tf.expand_dims(input=context_tensor, axis=1)
    x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
    large_batch_context_features[name] = utils.reshape_first_ndims(
        x, 2, [flat_size])

  large_batch_example_features = {}
  for name, example_tensor in six.iteritems(example_features):
    # Replace invalid example features with valid ones, as selected by
    # nd_indices from utils.padded_nd_indices.
    padded_tensor = tf.gather_nd(example_tensor, nd_indices)
    large_batch_example_features[name] = utils.reshape_first_ndims(
        padded_tensor, 2, [flat_size])

  # Get scores for the large batch, then restore the list layout. The trailing
  # -1 absorbs output_size (and yields 1 if the scorer returns a 1D tensor).
  scores = scorer(
      large_batch_context_features,
      large_batch_example_features,
      training=training)
  scores = tf.reshape(scores, shape=[batch_size, list_size, -1])

  # Apply nd_mask to zero out invalid entries. Expanding the mask to
  # [batch_size, list_size, 1] lets tf.where broadcast it over output_size.
  # NOTE(review): scores stay in nd_indices order and are only masked, not
  # scattered back. They line up with the input positions only if
  # utils.padded_nd_indices leaves valid entries in place (e.g. trailing
  # padding). Confirm for masks with interleaved invalid entries.
  expanded_nd_mask = tf.expand_dims(nd_mask, axis=2)
  return tf.where(expanded_nd_mask, scores, tf.zeros_like(scores))