in tensorflow_ranking/python/head.py [0:0]
def _merge_predict_export_outputs(self, all_estimator_spec):
  """Combines the PREDICT-mode `export_outputs` of every head into one dict.

  The returned dict is assembled in this order (later writes to the same key
  replace earlier ones):
    * `_DEFAULT_SERVING_KEY`: whatever `_default_export_output` returns for
      the first spec's `export_outputs` and the first head's name, so the
      first head is the one served by default.
    * For every (head, spec) pair and every entry `k -> v` of
      `spec.export_outputs`: `v` stored under `head.name` when `k` is
      `_DEFAULT_SERVING_KEY`, otherwise under `'<head.name>/<k>'`.
    * `_PREDICT_SERVING_KEY`: a new `tf.estimator.export.PredictOutput` built
      from the outputs of every per-head `PredictOutput` found under
      `_PREDICT_SERVING_KEY`, each output renamed to `'<head.name>/<name>'`.

  Args:
    all_estimator_spec: list of `EstimatorSpec` for the individual heads.
      Entries are paired with `self._heads` by position via `zip`, so any
      surplus entries in either list are ignored. Both lists must be
      non-empty, since element 0 of each is read unconditionally.

  Returns:
    A dict of merged export_outputs from all heads for PREDICT.
  """
  first_spec, first_head = all_estimator_spec[0], self._heads[0]
  # Serve the first head by default.
  merged = {
      _DEFAULT_SERVING_KEY:
          _default_export_output(first_spec.export_outputs, first_head.name),
  }
  predict_tensors = {}

  for head, spec in zip(self._heads, all_estimator_spec):
    prefix = head.name
    for output_key, output in spec.export_outputs.items():
      # Re-register each per-head output under a head-qualified key; the
      # head's own default output is exposed under the bare head name.
      if output_key == _DEFAULT_SERVING_KEY:
        merged[prefix] = output
      else:
        merged['{}/{}'.format(prefix, output_key)] = output

      # Only PredictOutput objects stored under the predict key contribute
      # to the merged predict signature.
      if not (output_key == _PREDICT_SERVING_KEY and
              isinstance(output, tf.estimator.export.PredictOutput)):
        continue
      for tensor_name, tensor in output.outputs.items():
        predict_tensors['{}/{}'.format(prefix, tensor_name)] = tensor

  # Written last, so it replaces any earlier value stored under this key.
  merged[_PREDICT_SERVING_KEY] = tf.estimator.export.PredictOutput(
      predict_tensors)
  return merged