in tensorflow_ranking/python/estimator.py [0:0]
def make_gam_ranking_estimator(
    example_feature_columns,
    example_hidden_units,
    context_feature_columns=None,
    context_hidden_units=None,
    optimizer=None,
    learning_rate=0.05,
    loss="approx_ndcg_loss",
    loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE,
    activation_fn=tf.nn.relu,
    dropout=None,
    use_batch_norm=False,
    batch_norm_moment=0.999,
    model_dir=None,
    checkpoint_secs=120,
    num_checkpoints=1000,
    listwise_inference=False):
  """Builds an `Estimator` whose scoring function is a GAM.

  See the `GAMEstimatorBuilder` class docstring for details of the
  Generalized Additive Model scoring architecture.

  Args:
    example_feature_columns: (dict) Maps feature names to `_FeatureColumn`
      instances for all example (per-item) features used by the model.
    example_hidden_units: (list) Number of hidden units per layer of the
      example-feature sub-networks, e.g. `[64, 32]` gives a first layer of
      64 nodes followed by one of 32. All layers are fully connected.
    context_feature_columns: (dict) Maps feature names to `_FeatureColumn`
      instances for all context (per-query) features. Same format as
      `example_feature_columns`.
    context_hidden_units: (list) Number of hidden units per layer of the
      context-feature sub-networks. Same format as `example_hidden_units`.
    optimizer: (`tf.Optimizer`) Optimizer for model training. When `None`,
      an Adagrad optimizer is created using `learning_rate`.
    learning_rate: (float) Learning rate used when no `optimizer` is
      supplied. Defaults to 0.05.
    loss: (str) Name of the training loss; see `RankingLossKey` for the
      supported values.
    loss_reduction: (str) A `tf.compat.v1.losses.Reduction` enum value
      selecting how the loss is reduced over examples.
    activation_fn: Activation applied after each hidden layer. `tf.nn.relu`
      by default.
    dropout: (float) If set, the probability of dropping out a given
      coordinate.
    use_batch_norm: (bool) Whether each hidden layer is followed by batch
      normalization.
    batch_norm_moment: (float) Momentum of the batch-normalization moving
      average.
    model_dir: (str) Directory for model parameters, graph, etc. May also
      point at existing checkpoints so training resumes from a previously
      saved model.
    checkpoint_secs: (int) Seconds between checkpoint saves.
    num_checkpoints: (int) Maximum number of checkpoints retained.
    listwise_inference: (bool) Whether inference uses the listwise data
      format such as `ExampleListWithContext`.

  Returns:
    An `Estimator` with a GAM scoring function.
  """
  # Assemble the GAM scoring callable first; the builder below only needs
  # the resulting function, not the individual network knobs.
  gam_score_fn = _make_gam_score_fn(
      context_hidden_units,
      example_hidden_units,
      activation_fn=activation_fn,
      dropout=dropout,
      batch_norm=use_batch_norm,
      batch_norm_moment=batch_norm_moment)
  # Training/serving hyperparameters consumed by the estimator builder.
  builder_hparams = {
      "model_dir": model_dir,
      "learning_rate": learning_rate,
      "listwise_inference": listwise_inference,
      "loss": loss,
      "checkpoint_secs": checkpoint_secs,
      "num_checkpoints": num_checkpoints,
  }
  builder = GAMEstimatorBuilder(
      context_feature_columns,
      example_feature_columns,
      optimizer=optimizer,
      scoring_function=gam_score_fn,
      loss_reduction=loss_reduction,
      hparams=builder_hparams)
  return builder.make_estimator()