def _build_metric_impl()

in easy_rec/python/model/rank_model.py


  def _build_metric_impl(self,
                         metric,
                         loss_type,
                         label_name,
                         num_class=1,
                         suffix=''):
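    """Build eval metric ops for a single configured metric.

    Args:
      metric: eval metric config proto; its 'metric' oneof selects which metric
        to build (auc, gauc, session_auc, max_f1, recall_at_topk,
        mean_absolute_error, mean_squared_error, root_mean_squared_error,
        accuracy).
      loss_type: a LossType value or a set of LossType values the model is
        trained with; metrics are only built for compatible losses.
      label_name: key of the label tensor in self._labels.
      num_class: number of output classes; 1 for single-logit binary or
        regression outputs.
      suffix: suffix appended to prediction keys and metric names.

    Returns:
      dict mapping metric name to the corresponding metric op(s).
    """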
    if not isinstance(loss_type, set):
      loss_type = {loss_type}
    from easy_rec.python.core.easyrec_metrics import metrics_tf
    from easy_rec.python.core import metrics as metrics_lib
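    # Loss types whose predictions include a binary probability ('probs'),
    # which the AUC-style and max_f1 metrics below require.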
    binary_loss_set = {
        LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
        LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
        LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
        LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
        LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS
    }
    metric_dict = {}
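    # auc: single-output models (num_class == 1), as well as JRC/ZILN losses
    # whose 'probs' output is already the positive-class probability, use
    # 'probs' directly; two-class softmax models use the column probs[:, 1].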
    if metric.WhichOneof('metric') == 'auc':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['auc' + suffix] = metrics_tf.auc(
            label,
            self._prediction_dict['probs' + suffix],
            num_thresholds=metric.auc.num_thresholds)
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['auc' + suffix] = metrics_tf.auc(
            label,
            self._prediction_dict['probs' + suffix][:, 1],
            num_thresholds=metric.auc.num_thresholds)
      else:
        raise ValueError('Wrong class number')
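    # gauc: grouped AUC averaged over users identified by uid_field; sparse
    # uid tensors are densified and flattened before use.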
    elif metric.WhichOneof('metric') == 'gauc':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        uids = self._feature_dict[metric.gauc.uid_field]
        if isinstance(uids, tf.sparse.SparseTensor):
          uids = tf.sparse_to_dense(
              uids.indices, uids.dense_shape, uids.values, default_value='')
          uids = tf.reshape(uids, [-1])
        metric_dict['gauc' + suffix] = metrics_lib.gauc(
            label,
            self._prediction_dict['probs' + suffix],
            uids=uids,
            reduction=metric.gauc.reduction)
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['gauc' + suffix] = metrics_lib.gauc(
            label,
            self._prediction_dict['probs' + suffix][:, 1],
            uids=self._feature_dict[metric.gauc.uid_field],
            reduction=metric.gauc.reduction)
      else:
        raise ValueError('Wrong class number')
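    # session_auc: AUC grouped by session_id_field instead of a user id.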
    elif metric.WhichOneof('metric') == 'session_auc':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
            label,
            self._prediction_dict['probs' + suffix],
            session_ids=self._feature_dict[metric.session_auc.session_id_field],
            reduction=metric.session_auc.reduction)
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
            label,
            self._prediction_dict['probs' + suffix][:, 1],
            session_ids=self._feature_dict[metric.session_auc.session_id_field],
            reduction=metric.session_auc.reduction)
      else:
        raise ValueError('Wrong class number')
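    # max_f1: best F1 score over decision thresholds, computed from logits.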
    elif metric.WhichOneof('metric') == 'max_f1':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
            label, self._prediction_dict['logits' + suffix])
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
            label, self._prediction_dict['logits' + suffix][:, 1])
      else:
        raise ValueError('Wrong class number')
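    # recall_at_topk: only defined for multi-class outputs (num_class > 1).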
    elif metric.WhichOneof('metric') == 'recall_at_topk':
      assert loss_type & binary_loss_set
      assert num_class > 1
      label = tf.to_int64(self._labels[label_name])
      metric_dict['recall_at_topk' + suffix] = metrics_tf.recall_at_k(
          label, self._prediction_dict['logits' + suffix],
          metric.recall_at_topk.topk)
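    # Regression-style metrics (mae / mse / rmse): regression losses are
    # evaluated against the raw prediction 'y', while single-class (binary)
    # classification is evaluated against 'probs'.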
    elif metric.WhichOneof('metric') == 'mean_absolute_error':
      label = tf.to_float(self._labels[label_name])
      if loss_type & {
          LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
      }:
        metric_dict['mean_absolute_error' +
                    suffix] = metrics_tf.mean_absolute_error(
                        label, self._prediction_dict['y' + suffix])
      elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
        metric_dict['mean_absolute_error' +
                    suffix] = metrics_tf.mean_absolute_error(
                        label, self._prediction_dict['probs' + suffix])
      else:
        assert False, 'mean_absolute_error is not supported for this model'
    elif metric.WhichOneof('metric') == 'mean_squared_error':
      label = tf.to_float(self._labels[label_name])
      if loss_type & {
          LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
      }:
        metric_dict['mean_squared_error' +
                    suffix] = metrics_tf.mean_squared_error(
                        label, self._prediction_dict['y' + suffix])
      elif num_class == 1 and loss_type & binary_loss_set:
        metric_dict['mean_squared_error' +
                    suffix] = metrics_tf.mean_squared_error(
                        label, self._prediction_dict['probs' + suffix])
      else:
        assert False, 'mean_squared_error is not supported for this model'
    elif metric.WhichOneof('metric') == 'root_mean_squared_error':
      label = tf.to_float(self._labels[label_name])
      if loss_type & {
          LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
      }:
        metric_dict['root_mean_squared_error' +
                    suffix] = metrics_tf.root_mean_squared_error(
                        label, self._prediction_dict['y' + suffix])
      elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
        metric_dict['root_mean_squared_error' +
                    suffix] = metrics_tf.root_mean_squared_error(
                        label, self._prediction_dict['probs' + suffix])
      else:
        assert False, 'root_mean_squared_error is not supported for this model'
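    # accuracy: multi-class classification only; 'y' is expected to hold the
    # predicted class ids here.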
    elif metric.WhichOneof('metric') == 'accuracy':
      assert loss_type & {LossType.CLASSIFICATION}
      assert num_class > 1
      label = tf.to_int64(self._labels[label_name])
      metric_dict['accuracy' + suffix] = metrics_tf.accuracy(
          label, self._prediction_dict['y' + suffix])
    return metric_dict
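
A minimal call-site sketch (not from the source): it assumes a TF1 Estimator setup, that the model exposes attributes such as self._loss_type / self._label_name / self._num_class, and that the metric ops returned above are (value, update_op) pairs as accepted by eval_metric_ops.

    # Hypothetical call site inside a metric-graph-building method; metric_cfg
    # is one eval-metric proto whose 'metric' oneof has been set.
    metric_ops = self._build_metric_impl(
        metric_cfg,
        loss_type=self._loss_type,    # e.g. {LossType.CLASSIFICATION}
        label_name=self._label_name,  # e.g. 'clk'
        num_class=self._num_class)    # 1 for a single-logit binary model
    spec = tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,                    # evaluation loss tensor (assumed defined)
        eval_metric_ops=metric_ops)   # name -> (value, update_op)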