def __init__()

in tensorflow_estimator/python/estimator/head/multi_label_head.py [0:0]


  def __init__(self,
               n_classes,
               weight_column=None,
               thresholds=None,
               label_vocabulary=None,
               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,
               loss_fn=None,
               classes_for_class_based_metrics=None,
               name=None):
    """Validates the head configuration and precomputes metric summary keys.

    Args:
      n_classes: Number of classes; must be greater than 1.
      weight_column: Stored as-is in `self._weight_column`; not inspected here.
      thresholds: Optional iterable of numbers, each strictly inside (0, 1).
        Each threshold yields accuracy, precision and recall summary keys.
      label_vocabulary: Optional list or tuple whose length is `n_classes`.
      loss_reduction: Checked with `base_head.validate_loss_reduction`.
      loss_fn: Optional custom loss; when given, it is checked with
        `base_head.validate_loss_fn_args`.
      classes_for_class_based_metrics: Optional iterable of either integer
        class ids in `[0, n_classes)` or class-name strings taken from
        `label_vocabulary` (converted to their integer ids). Each selected
        class yields probability-mean, AUC and AUC-PR summary keys.
      name: Optional head name, stored in `self._name`.

    Raises:
      ValueError: For every check performed directly in this method, including
        a class name missing from `label_vocabulary` and a mix of class names
        and class ids in `classes_for_class_based_metrics`. The
        `base_head.validate_*` helpers may raise their own errors.
    """
    if n_classes is None or n_classes < 2:
      raise ValueError('n_classes must be > 1 for multi-label classification. '
                       'Given: {}'.format(n_classes))
    thresholds = tuple(thresholds) if thresholds else tuple()
    for threshold in thresholds:
      if (threshold <= 0.0) or (threshold >= 1.0):
        raise ValueError(
            'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
    if label_vocabulary is not None:
      if not isinstance(label_vocabulary, (list, tuple)):
        raise ValueError('label_vocabulary must be a list or tuple. '
                         'Given type: {}'.format(type(label_vocabulary)))
      if len(label_vocabulary) != n_classes:
        raise ValueError('Length of label_vocabulary must be n_classes ({}). '
                         'Given: {}'.format(n_classes, len(label_vocabulary)))

    if loss_fn:
      base_head.validate_loss_fn_args(loss_fn)
    base_head.validate_loss_reduction(loss_reduction)
    if classes_for_class_based_metrics:
      classes_for_class_based_metrics = tuple(classes_for_class_based_metrics)
      # The type of the first entry decides how the whole sequence is read:
      # as class names (strings) or as integer class ids. Every entry is then
      # checked, so unknown or mixed entries fail with a clear ValueError
      # instead of a bare `index` error or a str/int comparison TypeError.
      if isinstance(classes_for_class_based_metrics[0], six.string_types):
        if not label_vocabulary:
          raise ValueError('label_vocabulary must be provided when '
                           'classes_for_class_based_metrics are strings.')
        for class_string in classes_for_class_based_metrics:
          if class_string not in label_vocabulary:
            raise ValueError(
                'All classes_for_class_based_metrics must be in '
                'label_vocabulary {}. Given: {!r}'.format(
                    label_vocabulary, class_string))
        # Convert class names to integer ids so later code handles one form.
        classes_for_class_based_metrics = tuple(
            label_vocabulary.index(class_string)
            for class_string in classes_for_class_based_metrics)
      else:
        for class_id in classes_for_class_based_metrics:
          if (isinstance(class_id, six.string_types) or
              (class_id < 0) or (class_id >= n_classes)):
            raise ValueError(
                'All classes_for_class_based_metrics must be in range [0, {}]. '
                'Given: {}'.format(n_classes - 1, class_id))
    else:
      classes_for_class_based_metrics = tuple()
    self._n_classes = n_classes
    self._weight_column = weight_column
    self._thresholds = thresholds
    self._label_vocabulary = label_vocabulary
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._classes_for_class_based_metrics = classes_for_class_based_metrics
    self._name = name
    # Metric keys, each built with `self._summary_key` (defined elsewhere).
    keys = metric_keys.MetricKeys
    self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)
    self._auc_key = self._summary_key(keys.AUC)
    self._auc_pr_key = self._summary_key(keys.AUC_PR)
    self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)
    # One accuracy/precision/recall key per threshold, in threshold order.
    accuracy_keys = []
    precision_keys = []
    recall_keys = []
    for threshold in self._thresholds:
      accuracy_keys.append(
          self._summary_key(keys.ACCURACY_AT_THRESHOLD % threshold))
      precision_keys.append(
          self._summary_key(keys.PRECISION_AT_THRESHOLD % threshold))
      recall_keys.append(
          self._summary_key(keys.RECALL_AT_THRESHOLD % threshold))
    self._accuracy_keys = tuple(accuracy_keys)
    self._precision_keys = tuple(precision_keys)
    self._recall_keys = tuple(recall_keys)
    # One probability-mean/AUC/AUC-PR key per selected class. Without a
    # vocabulary the keys embed the class id; with one they embed the name.
    prob_keys = []
    auc_keys = []
    auc_pr_keys = []
    for class_id in self._classes_for_class_based_metrics:
      if self._label_vocabulary is None:
        prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
        auc_key = keys.AUC_AT_CLASS % class_id
        auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
      else:
        prob_key = (
            keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id])
        auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id]
        auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id]
      prob_keys.append(self._summary_key(prob_key))
      auc_keys.append(self._summary_key(auc_key))
      auc_pr_keys.append(self._summary_key(auc_pr_key))
    self._prob_keys = tuple(prob_keys)
    self._auc_keys = tuple(auc_keys)
    self._auc_pr_keys = tuple(auc_pr_keys)