# _make_gam_score_fn()
#
# From tensorflow_ranking/python/estimator.py



def _make_gam_score_fn(context_hidden_units,
                       example_hidden_units,
                       activation_fn=tf.nn.relu,
                       dropout=None,
                       batch_norm=False,
                       batch_norm_moment=0.999):
  """Returns a scoring fn that outputs a score per example.

  The returned fn implements a Generalized Additive Model (GAM): one small
  feed-forward tower per example feature produces a scalar sub-score, and
  (optionally) one tower per context feature produces softmax weights over
  the example-feature sub-scores.

  Args:
    context_hidden_units: Iterable of hidden-layer sizes for each
      context-feature tower; values are coerced to `int`.
    example_hidden_units: Iterable of hidden-layer sizes for each
      example-feature tower; values are coerced to `int`.
    activation_fn: Activation function for hidden layers. `None` falls back
      to `tf.nn.relu`.
    dropout: (float) Dropout rate for the towers, or `None` for no dropout.
    batch_norm: (bool) Whether to apply batch normalization to tower inputs
      and hidden layers.
    batch_norm_moment: (float) Momentum for the batch-normalization moving
      averages.

  Returns:
    A scoring fn `(context_features, example_features, mode) -> logits`.
  """
  activation_fn = activation_fn or tf.nn.relu
  # Materialize the hidden-unit specs as lists of ints once, outside the
  # per-feature tower loops. A one-shot `map` iterator would break if the
  # network builder iterates its argument more than once, and there is no
  # need to re-coerce on every tower.
  example_hidden_units = [int(u) for u in example_hidden_units]
  context_hidden_units = [int(u) for u in context_hidden_units]

  def _scoring_fn(context_features, example_features, mode):
    """Defines the scoring fn for GAM.

    Args:
      context_features: (dict) A mapping from context feature names to dense 2-D
        Tensors of shape [batch_size, ...].
      example_features: (dict) A mapping from example feature names to dense 3-D
        Tensors of shape [batch_size, list_size, ...].
      mode: (`tf.estimator.ModeKeys`) TRAIN, EVAL, or PREDICT.

    Returns:
      A Tensor of shape [batch_size, 1] containing per-example scores.
    """

    # Input layer.  Feature names are sorted so that tower (and hence
    # variable) creation order is deterministic across runs.
    example_feature_names = sorted(example_features)
    context_feature_names = sorted(context_features)
    with tf.compat.v1.name_scope("input_layer"):
      example_input = [(name,
                        tf.compat.v1.layers.flatten(example_features[name]))
                       for name in example_feature_names]
      context_input = [(name,
                        tf.compat.v1.layers.flatten(context_features[name]))
                       for name in context_feature_names]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # Construct a tower for each example feature.  Each tower outputs a
    # scalar value as the sub-score.  All sub-scores are
    # [batch_size * list_size, 1]-shaped tensors and are stored in
    # `sub_logits_list` as a `feature_num`-sized list.
    with tf.compat.v1.name_scope("example_feature_towers"):
      sub_logits_list = []
      for name, input_layer in example_input:
        with tf.compat.v1.name_scope("{}_tower".format(name)):
          cur_layer = input_layer
          if batch_norm:
            cur_layer = tf.compat.v1.layers.batch_normalization(
                cur_layer, training=is_training, momentum=batch_norm_moment)
          sub_logits = _feed_forward_network(
              cur_layer,
              example_hidden_units,
              output_units=1,
              activation_fn=activation_fn,
              batch_norm=batch_norm,
              batch_norm_moment=batch_norm_moment,
              dropout=dropout,
              is_training=is_training)
          # Named identity so sub-scores can be fetched for interpretability.
          sub_logits = tf.identity(
              sub_logits, name="{}_{}".format(name, _SUBSCORE_POSTFIX))
          sub_logits_list.append(sub_logits)

    # Construct a tower for each context feature.  Each tower outputs a
    # weighting vector of `feature_num`-dim where `feature_num` is the number
    # of example features.  All the vectors are
    # [batch_size * list_size, feature_num] tensors and are stored in
    # `sub_weights_list` with length of number of context feature.
    sub_weights_list = []
    if context_input:
      # Construct a tower per context features.
      with tf.compat.v1.name_scope("context_feature_towers"):
        feature_num = len(sub_logits_list)
        for name, input_layer in context_input:
          with tf.compat.v1.name_scope("{}_tower".format(name)):
            cur_layer = input_layer
            if batch_norm:
              cur_layer = tf.compat.v1.layers.batch_normalization(
                  cur_layer, training=is_training, momentum=batch_norm_moment)
            sub_weights = _feed_forward_network(
                cur_layer,
                context_hidden_units,
                output_units=feature_num,
                activation_fn=activation_fn,
                batch_norm=batch_norm,
                batch_norm_moment=batch_norm_moment,
                dropout=dropout,
                is_training=is_training)
            # Softmax so each context tower emits a convex weighting over
            # the example-feature sub-scores.
            sub_weights = tf.math.softmax(
                sub_weights, name="{}_{}".format(name, _SUBWEIGHT_POSTFIX))
            sub_weights_list.append(sub_weights)

    # Construct an additive model from the outputs of all example feature towers
    # `sub_logits_list` weighted by outputs of all context feature towers
    # `sub_weights_list`.  If no context features are provided, the outputs will
    # simply be the sum of `sub_logits_list`.
    if sub_weights_list:
      sub_logits = tf.concat(sub_logits_list, axis=-1)
      feature_weights = tf.math.add_n(sub_weights_list)
      # `keepdims=True` preserves the trailing unit dimension so both
      # branches return a [batch_size * list_size, 1] tensor, matching the
      # `add_n` branch below and the documented output shape.
      logits = tf.math.reduce_sum(
          input_tensor=sub_logits * feature_weights, axis=-1, keepdims=True)
    else:
      logits = tf.math.add_n(sub_logits_list)

    tf.compat.v1.summary.scalar("logits_mean",
                                tf.reduce_mean(input_tensor=logits))
    return logits

  return _scoring_fn