def _compute_logits_impl()

in tensorflow_ranking/python/model.py [0:0]


  def _compute_logits_impl(self, context_features, example_features, labels,
                           mode, params, config):
    """Computes per-example logits by scoring groups of examples.

    Each list in the mini-batch is split into `num_groups` groups of
    `self._group_size` examples using the gather/scatter indices prepared by
    `self._update_scatter_gather_indices` (which sets `self._indices_mask`,
    `self._feature_gather_indices` and `self._score_scatter_indices`). Every
    group is scored by `self._score_fn`, and the scores an example receives
    from the valid groups it belongs to are averaged into its logit.

    Args:
      context_features: A dict of context feature tensors. Each value gets a
        new axis 1 that is repeated `num_groups` times before the first two dims
        are flattened, so it is expected to have a leading `batch_size` dim.
      example_features: A dict of example feature tensors, assumed to be of
        shape [batch_size, list_size, ...].
      labels: Passed with `example_features` to `_infer_sizes` to obtain
        `batch_size`, `list_size` and the validity mask.
      mode: Forwarded to `_update_scatter_gather_indices` and `_score_fn`.
      params: Forwarded to `_update_scatter_gather_indices` and `_score_fn`.
      config: Forwarded to `_score_fn`.

    Returns:
      A [batch_size, list_size] logits tensor, or, when `_score_fn` returns a
      dict, a dict with the same keys mapping to such tensors. Examples that are
      not covered by any valid group get a logit of 0.
    """
    # Scatter/Gather per-example scores through groupwise comparison. Each
    # instance in a mini-batch will form a number of groups. Each group of
    # examples is scored by `_score_fn` and scores for individual examples are
    # accumulated into logits.
    with tf.compat.v1.name_scope('groupwise_dnn_v2'):
      batch_size, list_size, is_valid = _infer_sizes(example_features, labels)
      # For each example feature, assuming the shape is [batch_size, list_size,
      # feature_size], the groups are formed along the 2nd dim. Each group has a
      # 'group_size' number of indices in [0, list_size). Based on these
      # indices, we can gather the example feature into a sub-tensor for each
      # group. The total number of groups we have for a mini-batch is batch_size
      # * num_groups. Inside each group, we have a 'group_size' number of
      # examples.
      self._update_scatter_gather_indices(is_valid, mode, params)
      num_groups = tf.shape(input=self._indices_mask)[1]

      with tf.compat.v1.name_scope('group_features'):
        # For context features, we have shape [batch_size * num_groups, ...].
        large_batch_context_features = {}
        for name, value in six.iteritems(context_features):
          # [batch_size, num_groups, ...].
          value = tf.repeat(
              tf.expand_dims(value, axis=1), repeats=[num_groups], axis=1)
          # [batch_size * num_groups, ...]
          large_batch_context_features[name] = utils.reshape_first_ndims(
              value, 2, [batch_size * num_groups])

        # For example features, we have shape [batch_size * num_groups,
        # group_size, ...].
        large_batch_group_features = {}
        for name, value in six.iteritems(example_features):
          # [batch_size, num_groups, group_size, ...].
          value = tf.gather_nd(value, self._feature_gather_indices)
          # [batch_size * num_groups, group_size, ...].
          large_batch_group_features[name] = utils.reshape_first_ndims(
              value, 3, [batch_size * num_groups, self._group_size])

      # Score the large batch of batch_size * num_groups groups. The result
      # (or each value, if a dict is returned) must hold `group_size` scores
      # per group so it can be reshaped to [batch_size, num_groups, group_size]
      # below.
      with tf.compat.v1.variable_scope('group_score'):
        scores = self._score_fn(large_batch_context_features,
                                large_batch_group_features, mode, params,
                                config)

      with tf.compat.v1.name_scope('accumulate_scores'):
        # Shape of the scatter indices: [batch_size, num_groups, group_size, 2],
        # i.e. the (batch, list) position of every grouped score. Computed once
        # and reused for every task below.
        scatter_shape = tf.shape(input=self._score_scatter_indices)
        # Reset invalid scores to 0 based on mask: broadcast the per-group
        # validity mask to every example slot within the group.
        scores_mask = tf.tile(
            tf.expand_dims(self._indices_mask, 2),
            multiples=[1, 1, scatter_shape[2]],
            name='tile_scores_mask')
        # Number of valid groups each (batch, list) position appears in; this
        # is the divisor used to average the accumulated scores.
        counts = tf.scatter_nd(self._score_scatter_indices,
                               tf.cast(scores_mask, tf.float32),
                               [batch_size, list_size])

        def _accumulate_scores(task_scores):
          """Averages one task's group scores into [batch_size, list_size]."""
          # [batch_size, num_groups, group_size].
          task_scores = tf.reshape(task_scores, scatter_shape[0:3])
          task_scores = tf.compat.v1.where(scores_mask, task_scores,
                                           tf.zeros_like(task_scores))
          # Scatter scores from [batch_size, num_groups, group_size] to
          # [batch_size, list_size], summing scores that land on the same
          # example.
          task_logits = tf.scatter_nd(self._score_scatter_indices, task_scores,
                                      [batch_size, list_size])
          # Use average. `div_no_nan` converts its divisor to the dividend's
          # dtype and rejects a mismatch, so cast the float32 counts to the
          # score dtype (a no-op for float32 scores; enables float16/float64).
          # Positions with a count of 0 get a logit of 0.
          return tf.compat.v1.div_no_nan(
              task_logits, tf.cast(counts, task_logits.dtype))

        if isinstance(scores, dict):
          logits = {}
          for name, task_scores in six.iteritems(scores):
            logits[name] = _accumulate_scores(task_scores)
        else:
          logits = _accumulate_scores(scores)

    return logits