def make_gam_ranking_estimator()

in tensorflow_ranking/python/estimator.py [0:0]


def make_gam_ranking_estimator(
    example_feature_columns,
    example_hidden_units,
    context_feature_columns=None,
    context_hidden_units=None,
    optimizer=None,
    learning_rate=0.05,
    loss="approx_ndcg_loss",
    loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE,
    activation_fn=tf.nn.relu,
    dropout=None,
    use_batch_norm=False,
    batch_norm_moment=0.999,
    model_dir=None,
    checkpoint_secs=120,
    num_checkpoints=1000,
    listwise_inference=False):
  """Builds an `Estimator` instance with GAM scoring function.

  See the comment of `GAMEstimatorBuilder` class for more details.

  Args:
    example_feature_columns: (dict) A dict containing all the example feature
      columns used by the model. Keys are feature names, and values are
      instances of classes derived from `_FeatureColumn`.
    example_hidden_units: (list) Iterable of the number of hidden units per
      layer for example features. All layers are fully connected. Ex.
      `[64, 32]` means the first layer has 64 nodes and the second one has 32.
    context_feature_columns: (dict) A dict containing all the context feature
      columns used by the model. See `example_feature_columns`.
    context_hidden_units: (list) Iterable of the number of hidden units per
      layer for context features. See `example_hidden_units`.
    optimizer: (`tf.Optimizer`) An `Optimizer` object for model optimization.
      If `None`, an Adagrad optimizer with `learning_rate` will be created.
    learning_rate: (float) Learning rate for the default optimizer; only used
      if `optimizer` is `None`. Defaults to 0.05.
    loss: (str) A string to decide the loss function used in training. See
      `RankingLossKey` class for possible values.
    loss_reduction: (str) An enum of strings indicating the loss reduction type.
      See type definition in the `tf.compat.v1.losses.Reduction`.
    activation_fn: Activation function applied to each layer. Defaults to
      `tf.nn.relu`.
    dropout: (float) When not `None`, the probability we will drop out a given
      coordinate.
    use_batch_norm: (bool) Whether to use batch normalization after each hidden
      layer.
    batch_norm_moment: (float) Momentum for the moving average in batch
      normalization.
    model_dir: (str) Directory to save model parameters, graph, etc. This can
      also be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    checkpoint_secs: (int) Time interval (in seconds) to save checkpoints.
    num_checkpoints: (int) Number of checkpoints to keep.
    listwise_inference: (bool) Whether the inference will be performed with the
      listwise data format such as `ExampleListWithContext`.

  Returns:
    An `Estimator` with GAM scoring function.
  """
  # The score function takes the context hidden units first and the example
  # hidden units second (positional), mirroring the builder's
  # (context, example) feature-column argument order below.
  scoring_function = _make_gam_score_fn(
      context_hidden_units,
      example_hidden_units,
      activation_fn=activation_fn,
      dropout=dropout,
      batch_norm=use_batch_norm,
      batch_norm_moment=batch_norm_moment)

  # Training/run configuration is handed to the builder as a plain dict;
  # `learning_rate` here feeds the default optimizer when `optimizer` is None.
  hparams = dict(
      model_dir=model_dir,
      learning_rate=learning_rate,
      listwise_inference=listwise_inference,
      loss=loss,
      checkpoint_secs=checkpoint_secs,
      num_checkpoints=num_checkpoints)

  return GAMEstimatorBuilder(
      context_feature_columns,
      example_feature_columns,
      optimizer=optimizer,
      scoring_function=scoring_function,
      loss_reduction=loss_reduction,
      hparams=hparams).make_estimator()