in tensorflow_estimator/python/estimator/head/multi_label_head.py [0:0]
def __init__(self,
             n_classes,
             weight_column=None,
             thresholds=None,
             label_vocabulary=None,
             loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,
             loss_fn=None,
             classes_for_class_based_metrics=None,
             name=None):
  """Validates constructor arguments and precomputes metric summary keys.

  All argument checking happens eagerly here, raising `ValueError` on bad
  input. Optional collections are normalized to tuples, and the summary
  keys for per-threshold and per-class metrics are built once up front.
  """
  # Multi-label classification is only meaningful with at least two classes.
  if n_classes is None or n_classes < 2:
    raise ValueError('n_classes must be > 1 for multi-label classification. '
                     'Given: {}'.format(n_classes))
  thresholds = tuple(thresholds) if thresholds else tuple()
  # Reject the first threshold (if any) that falls outside the open
  # interval (0, 1); `next` with a default finds the first offender.
  bad_threshold = next(
      (t for t in thresholds if t <= 0.0 or t >= 1.0), None)
  if bad_threshold is not None:
    raise ValueError(
        'thresholds must be in (0, 1) range. Given: {}'.format(bad_threshold))
  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      raise ValueError('label_vocabulary must be a list or tuple. '
                       'Given type: {}'.format(type(label_vocabulary)))
    if len(label_vocabulary) != n_classes:
      raise ValueError('Length of label_vocabulary must be n_classes ({}). '
                       'Given: {}'.format(n_classes, len(label_vocabulary)))
  if loss_fn:
    base_head.validate_loss_fn_args(loss_fn)
  base_head.validate_loss_reduction(loss_reduction)
  if classes_for_class_based_metrics:
    classes_for_class_based_metrics = tuple(classes_for_class_based_metrics)
    if isinstance(classes_for_class_based_metrics[0], six.string_types):
      # String class names can only be resolved to ids via the vocabulary.
      if not label_vocabulary:
        raise ValueError('label_vocabulary must be provided when '
                         'classes_for_class_based_metrics are strings.')
      # `list.index` raises ValueError for names absent from the vocabulary.
      classes_for_class_based_metrics = tuple(
          label_vocabulary.index(class_string)
          for class_string in classes_for_class_based_metrics)
    else:
      # Integer class ids must address a valid class.
      for candidate in classes_for_class_based_metrics:
        if not 0 <= candidate < n_classes:
          raise ValueError(
              'All classes_for_class_based_metrics must be in range [0, {}]. '
              'Given: {}'.format(n_classes - 1, candidate))
  else:
    classes_for_class_based_metrics = tuple()
  self._n_classes = n_classes
  self._weight_column = weight_column
  self._thresholds = thresholds
  self._label_vocabulary = label_vocabulary
  self._loss_reduction = loss_reduction
  self._loss_fn = loss_fn
  self._classes_for_class_based_metrics = classes_for_class_based_metrics
  self._name = name
  # Precompute summary keys for the metrics this head always reports.
  keys = metric_keys.MetricKeys
  self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)
  self._auc_key = self._summary_key(keys.AUC)
  self._auc_pr_key = self._summary_key(keys.AUC_PR)
  self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)
  # One accuracy/precision/recall key per configured threshold.
  self._accuracy_keys = tuple(
      self._summary_key(keys.ACCURACY_AT_THRESHOLD % t)
      for t in self._thresholds)
  self._precision_keys = tuple(
      self._summary_key(keys.PRECISION_AT_THRESHOLD % t)
      for t in self._thresholds)
  self._recall_keys = tuple(
      self._summary_key(keys.RECALL_AT_THRESHOLD % t)
      for t in self._thresholds)
  # Per-class metric keys: named by vocabulary entry when one is available,
  # otherwise by numeric class id.
  prob_keys = []
  auc_keys = []
  auc_pr_keys = []
  for class_id in self._classes_for_class_based_metrics:
    if self._label_vocabulary is None:
      prob_keys.append(
          self._summary_key(keys.PROBABILITY_MEAN_AT_CLASS % class_id))
      auc_keys.append(self._summary_key(keys.AUC_AT_CLASS % class_id))
      auc_pr_keys.append(self._summary_key(keys.AUC_PR_AT_CLASS % class_id))
    else:
      class_name = self._label_vocabulary[class_id]
      prob_keys.append(
          self._summary_key(keys.PROBABILITY_MEAN_AT_NAME % class_name))
      auc_keys.append(self._summary_key(keys.AUC_AT_NAME % class_name))
      auc_pr_keys.append(self._summary_key(keys.AUC_PR_AT_NAME % class_name))
  self._prob_keys = tuple(prob_keys)
  self._auc_keys = tuple(auc_keys)
  self._auc_pr_keys = tuple(auc_pr_keys)