in tensorflow_ranking/python/model.py [0:0]
def _compute_logits_impl(self, context_features, example_features, labels,
                         mode, params, config):
  """Builds per-example logits via groupwise scoring.

  Each instance in the mini-batch is expanded into a number of groups of
  `self._group_size` examples, every group is scored by `self._score_fn`,
  and the per-example scores are scattered back and averaged into a
  [batch_size, list_size] logits Tensor (or a dict of such Tensors when
  `self._score_fn` returns a dict).
  """
  with tf.compat.v1.name_scope('groupwise_dnn_v2'):
    batch_size, list_size, is_valid = _infer_sizes(example_features, labels)
    # Build the gather/scatter indices that define group membership; groups
    # are formed along the list (2nd) dimension, each holding `group_size`
    # indices in [0, list_size). A mini-batch yields batch_size * num_groups
    # groups in total.
    self._update_scatter_gather_indices(is_valid, mode, params)
    num_groups = tf.shape(input=self._indices_mask)[1]

    with tf.compat.v1.name_scope('group_features'):
      # Context features are replicated per group:
      # [batch_size, ...] -> [batch_size * num_groups, ...].
      big_context_features = {}
      for feature_name, feature_value in six.iteritems(context_features):
        # [batch_size, num_groups, ...].
        replicated = tf.repeat(
            tf.expand_dims(feature_value, axis=1),
            repeats=[num_groups],
            axis=1)
        # [batch_size * num_groups, ...].
        big_context_features[feature_name] = utils.reshape_first_ndims(
            replicated, 2, [batch_size * num_groups])

      # Example features are gathered per group:
      # [batch_size, list_size, ...] ->
      # [batch_size * num_groups, group_size, ...].
      big_group_features = {}
      for feature_name, feature_value in six.iteritems(example_features):
        # [batch_size, num_groups, group_size, ...].
        gathered = tf.gather_nd(feature_value, self._feature_gather_indices)
        # [batch_size * num_groups, group_size, ...].
        big_group_features[feature_name] = utils.reshape_first_ndims(
            gathered, 3, [batch_size * num_groups, self._group_size])

    # Score the large batch of groups in one call.
    with tf.compat.v1.variable_scope('group_score'):
      scores = self._score_fn(big_context_features, big_group_features, mode,
                              params, config)

    with tf.compat.v1.name_scope('accumulate_scores'):
      # Mask marking which (batch, group, score-slot) entries are valid,
      # broadcast along the trailing score dimension.
      scores_mask = tf.tile(
          tf.expand_dims(self._indices_mask, 2),
          multiples=[1, 1,
                     tf.shape(input=self._score_scatter_indices)[2]],
          name='tile_scores_mask')
      # How many valid group scores land on each list position — the
      # denominator for averaging below.
      counts = tf.scatter_nd(self._score_scatter_indices,
                             tf.cast(scores_mask, tf.float32),
                             [batch_size, list_size])

      def _scatter_and_average(per_task_scores):
        """Scatters one score Tensor into averaged [batch_size, list_size] logits."""
        per_task_scores = tf.reshape(
            per_task_scores,
            tf.shape(input=self._score_scatter_indices)[0:3])
        # Zero out scores at invalid positions before accumulation.
        per_task_scores = tf.compat.v1.where(scores_mask, per_task_scores,
                                             tf.zeros_like(per_task_scores))
        # [batch_size, num_groups, group_size] -> [batch_size, list_size].
        summed = tf.scatter_nd(self._score_scatter_indices, per_task_scores,
                               [batch_size, list_size])
        # Average the accumulated scores; positions with zero count stay 0.
        return tf.compat.v1.div_no_nan(summed, counts)

      if isinstance(scores, dict):
        logits = {
            task_name: _scatter_and_average(task_scores)
            for task_name, task_scores in six.iteritems(scores)
        }
      else:
        logits = _scatter_and_average(scores)
  return logits