in tensorflow_ranking/python/head.py [0:0]
def __init__(self, heads, head_weights=None):
    """Builds a multi-head wrapper around the given ranking heads.

    Args:
      heads: A tuple or list of `_RankingHead`. Must be non-empty, and every
        head must have a `name` that is not None.
      head_weights: Optional tuple or list of weights. When given and
        non-empty it must have exactly one entry per head. A falsy value
        (`None` or an empty sequence) is stored as an empty tuple.

    Raises:
      ValueError: If `heads` is falsy, if a non-empty `head_weights` has a
        different length than `heads`, or if any head has `name` set to None.
    """
    if not heads:
        raise ValueError('Must specify heads. Given: {}'.format(heads))

    # Sizes are only compared when weights were actually supplied.
    if head_weights and len(head_weights) != len(heads):
        size_mismatch = (
            'heads and head_weights must have the same size. '
            'Given len(heads): {}. Given len(head_weights): {}.')
        raise ValueError(size_mismatch.format(len(heads), len(head_weights)))

    # Report the first head (in order) that has no name.
    unnamed_head = next(
        (candidate for candidate in heads if candidate.name is None), None)
    if unnamed_head is not None:
        raise ValueError(
            'All given heads must have name specified. Given: {}'.format(
                unnamed_head))

    self._heads = tuple(heads)
    if head_weights:
        self._head_weights = tuple(head_weights)
    else:
        self._head_weights = ()

    # TODO: Figure out a better way to set train_op_fn and optimizer
    # for _MultiRankingHead.
    # pylint: disable=protected-access
    tf.compat.v1.logging.info(
        'Use the train_op_fn and optimizer from the first head.')
    # The first head supplies both the train op function and the optimizer.
    first_head = self._heads[0]
    self._train_op_fn = first_head._train_op_fn
    self._optimizer = first_head._optimizer