def _check_logits_and_labels()

in tensorflow_ranking/python/head.py [0:0]


  def _check_logits_and_labels(self, logits, labels=None):
    """Validates the keys of logits and labels against the head names.

    Args:
      logits: A `dict` keyed by head name; every head in `self._heads` must
        have an entry. Extra keys are not checked.
      labels: Optional `dict` keyed by head name with the same requirement as
        `logits`. The check is skipped when `labels` is `None`.

    Raises:
      ValueError: If two heads in `self._heads` share a name, if `logits` (or a
        non-`None` `labels`) is not a `dict`, or if either is missing a key for
        some head. Missing head names are listed in sorted order.
    """
    head_names = [head.name for head in self._heads]
    unique_head_names = set(head_names)
    if len(head_names) != len(unique_head_names):
      raise ValueError('Duplicated names in heads.')

    # Check the logits keys.
    if not isinstance(logits, dict):
      raise ValueError('logits in _MultiRankingHead should be a dict.')
    # Sorted so the error message is deterministic; raw set iteration order
    # varies between runs. key=str tolerates non-string head names.
    logits_missing_names = sorted(
        unique_head_names.difference(logits), key=str)
    if logits_missing_names:
      raise ValueError('logits has missing values for head(s): {}.'.format(
          logits_missing_names))

    # Check the labels keys.
    if labels is not None:
      if not isinstance(labels, dict):
        raise ValueError('labels in _MultiRankingHead should be a dict.')
      labels_missing_names = sorted(
          unique_head_names.difference(labels), key=str)
      if labels_missing_names:
        raise ValueError('labels has missing values for head(s): {}.'.format(
            labels_missing_names))