in tensorflow_ranking/python/head.py [0:0]
def _check_logits_and_labels(self, logits, labels=None):
    """Validates the keys of logits and labels.

    Every head in `self._heads` must have a unique `name`. Every head name must
    appear as a key of `logits`, and also of `labels` when labels are given.
    Extra keys that do not match any head name are not rejected here.

    Args:
      logits: A dict keyed by head name.
      labels: An optional dict keyed by head name. When None, labels are not
        checked.

    Raises:
      ValueError: If head names are duplicated, if `logits` (or a non-None
        `labels`) is not a dict, or if either is missing an entry for some
        head. Missing head names are listed in sorted order so the message is
        deterministic.
    """
    head_names = [head.name for head in self._heads]
    head_name_set = set(head_names)
    if len(head_names) != len(head_name_set):
        raise ValueError('Duplicated names in heads.')
    # Check the logits keys.
    if not isinstance(logits, dict):
        raise ValueError('logits in _MultiRankingHead should be a dict.')
    # Sort (by str, so any name type is orderable) so the reported list does
    # not depend on set iteration order.
    logits_missing_names = sorted(head_name_set - set(logits), key=str)
    if logits_missing_names:
        raise ValueError('logits has missing values for head(s): {}.'.format(
            logits_missing_names))
    # Check the labels keys.
    if labels is not None:
        if not isinstance(labels, dict):
            raise ValueError('labels in _MultiRankingHead should be a dict.')
        labels_missing_names = sorted(head_name_set - set(labels), key=str)
        if labels_missing_names:
            raise ValueError('labels has missing values for head(s): {}.'.format(
                labels_missing_names))