in easy_rec/python/model/rank_model.py [0:0]
def _build_metric_impl(self,
                       metric,
                       loss_type,
                       label_name,
                       num_class=1,
                       suffix=''):
  """Build the evaluation metric op(s) selected by one metric config.

  Args:
    metric: metric config message; its active ``metric`` oneof field selects
      what is built: auc, gauc, session_auc, max_f1, recall_at_topk,
      mean_absolute_error, mean_squared_error, root_mean_squared_error or
      accuracy.
    loss_type: a LossType value or a set of LossType values; decides whether
      the metric is applicable and which prediction tensor ('probs', 'logits'
      or 'y') it is computed on.
    label_name: key into ``self._labels`` for the ground-truth tensor.
    num_class: number of classes of the prediction head.
    suffix: appended both to the ``self._prediction_dict`` keys that are read
      and to the metric names that are written.

  Returns:
    dict mapping ``<metric name> + suffix`` to the result of the underlying
    metrics function. The dict is empty when the oneof field matches none of
    the names above.

  Raises:
    ValueError: 'Wrong class number' for auc/gauc/session_auc/max_f1 when
      num_class is neither 1 nor 2 and the loss is not JRC/ZILN.
    AssertionError: when the metric is not supported for loss_type/num_class.
  """
  if not isinstance(loss_type, set):
    loss_type = {loss_type}
  from easy_rec.python.core.easyrec_metrics import metrics_tf
  from easy_rec.python.core import metrics as metrics_lib
  binary_loss_set = {
      LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
      LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
      LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
      LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
      LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS
  }
  # For these losses the prediction tensor is used as-is, whatever num_class.
  as_is_losses = {LossType.JRC_LOSS, LossType.ZILN_LOSS}
  # Losses whose error metrics are computed on the 'y' prediction.
  regression_losses = {
      LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
  }

  def _binary_score(key):
    """Return the per-sample positive score from prediction ``key``."""
    # The class count is validated before the prediction dict is read, so a
    # bad num_class raises ValueError rather than a lookup error.
    if num_class == 1 or loss_type & as_is_losses:
      return self._prediction_dict[key + suffix]
    if num_class == 2:
      # Two-column output: column 1 is taken as the positive-class score.
      return self._prediction_dict[key + suffix][:, 1]
    raise ValueError('Wrong class number')

  def _dense_ids(field):
    """Fetch an id feature, densify it if sparse and flatten it to 1-D."""
    ids = self._feature_dict[field]
    if isinstance(ids, tf.sparse.SparseTensor):
      # NOTE(review): default_value='' assumes string-typed ids.
      ids = tf.sparse_to_dense(
          ids.indices, ids.dense_shape, ids.values, default_value='')
    return tf.reshape(ids, [-1])

  metric_type = metric.WhichOneof('metric')
  metric_dict = {}
  if metric_type == 'auc':
    assert loss_type & binary_loss_set
    probs = _binary_score('probs')
    label = tf.to_int64(self._labels[label_name])
    metric_dict['auc' + suffix] = metrics_tf.auc(
        label, probs, num_thresholds=metric.auc.num_thresholds)
  elif metric_type == 'gauc':
    assert loss_type & binary_loss_set
    probs = _binary_score('probs')
    label = tf.to_int64(self._labels[label_name])
    # uids are normalized for every class layout, not only num_class == 1.
    metric_dict['gauc' + suffix] = metrics_lib.gauc(
        label,
        probs,
        uids=_dense_ids(metric.gauc.uid_field),
        reduction=metric.gauc.reduction)
  elif metric_type == 'session_auc':
    assert loss_type & binary_loss_set
    probs = _binary_score('probs')
    label = tf.to_int64(self._labels[label_name])
    metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
        label,
        probs,
        session_ids=_dense_ids(metric.session_auc.session_id_field),
        reduction=metric.session_auc.reduction)
  elif metric_type == 'max_f1':
    assert loss_type & binary_loss_set
    # NOTE(review): max_f1 is fed logits, not probs. With num_class == 2,
    # logits[:, 1] alone is not monotone in P(class 1). Confirm this is what
    # metrics_lib.max_f1 expects.
    logits = _binary_score('logits')
    label = tf.to_int64(self._labels[label_name])
    metric_dict['max_f1' + suffix] = metrics_lib.max_f1(label, logits)
  elif metric_type == 'recall_at_topk':
    assert loss_type & binary_loss_set
    assert num_class > 1
    label = tf.to_int64(self._labels[label_name])
    metric_dict['recall_at_topk' + suffix] = metrics_tf.recall_at_k(
        label, self._prediction_dict['logits' + suffix],
        metric.recall_at_topk.topk)
  elif metric_type in ('mean_absolute_error', 'mean_squared_error',
                       'root_mean_squared_error'):
    label = tf.to_float(self._labels[label_name])
    # The same applicability rule now holds for all three error metrics:
    # regression output 'y', else 'probs' for a single-output binary head.
    if loss_type & regression_losses:
      preds = self._prediction_dict['y' + suffix]
    elif num_class == 1 and loss_type & binary_loss_set:
      preds = self._prediction_dict['probs' + suffix]
    else:
      # Explicit raise (not `assert False`) so it survives `python -O`.
      raise AssertionError('%s is not supported for this model' % metric_type)
    metric_fn = {
        'mean_absolute_error': metrics_tf.mean_absolute_error,
        'mean_squared_error': metrics_tf.mean_squared_error,
        'root_mean_squared_error': metrics_tf.root_mean_squared_error,
    }[metric_type]
    metric_dict[metric_type + suffix] = metric_fn(label, preds)
  elif metric_type == 'accuracy':
    assert loss_type & {LossType.CLASSIFICATION}
    assert num_class > 1
    label = tf.to_int64(self._labels[label_name])
    metric_dict['accuracy' + suffix] = metrics_tf.accuracy(
        label, self._prediction_dict['y' + suffix])
  return metric_dict