in tensorflow_ranking/python/estimator.py [0:0]
def _model_fn(self):
  """Wraps the parent model_fn with additional signatures of subscores.

  Returns:
    A `model_fn` taking `(features, labels, mode, params, config)` that
    delegates to the parent builder's `_model_fn()`. In PREDICT mode it also
    adds one export signature per graph op whose name ends with
    `_SUBSCORE_POSTFIX` (exported as a `RegressionOutput`) or
    `_SUBWEIGHT_POSTFIX` (exported as a `PredictOutput`).
  """

  def _subscore_export_outputs(graph):
    """Builds export outputs for every subscore/subweight op in `graph`.

    Args:
      graph: The `tf.Graph` to scan.

    Returns:
      A dict mapping signature name (the last `/`-separated component of the
      op name) to an export output wrapping the op's first output tensor.
      If two ops share the same last component, the later one wins.
    """
    # Checked in order, so an op name is classified by the first postfix it
    # matches (if/elif precedence). Example feature sub-scores become
    # regression signatures; context feature weighting vectors become
    # prediction signatures.
    postfix_to_output_cls = (
        (_SUBSCORE_POSTFIX, tf.estimator.export.RegressionOutput),
        (_SUBWEIGHT_POSTFIX, tf.estimator.export.PredictOutput),
    )
    outputs = {}
    # Iterate the live graph's ops instead of serializing it to a GraphDef;
    # only op names and their output tensors are needed.
    for op in graph.get_operations():
      for postfix, output_cls in postfix_to_output_cls:
        if op.name.endswith(postfix):
          signature_name = op.name.rpartition("/")[2]
          tensor = graph.get_tensor_by_name("{}:0".format(op.name))
          outputs[signature_name] = output_cls(tensor)
          break
    return outputs

  def _gam_model_fn(features, labels, mode, params, config):
    """Redefines the model_fn for GAM to include subscore signatures."""
    estimator_spec = super(GAMEstimatorBuilder,
                           self)._model_fn()(features, labels, mode, params,
                                             config)
    if mode == tf.estimator.ModeKeys.PREDICT:
      # Mutates the spec's export_outputs dict in place so the extra
      # signatures are exported alongside the parent's ones.
      estimator_spec.export_outputs.update(
          _subscore_export_outputs(tf.compat.v1.get_default_graph()))
    return estimator_spec

  return _gam_model_fn