public Operand result()

in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java [848:922]


  public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType) {
    // Computes the area under the configured curve (ROC or PR) from the accumulated
    // truePositives / falsePositives / trueNegatives / falseNegatives values and casts the result
    // to resultType. PR + INTERPOLATION is delegated to interpolatePRAuc; every other combination
    // is a Riemann sum over adjacent thresholds.
    //
    // Throws IllegalArgumentException if the curve or summation method is not a recognized value.
    //
    // NOTE(review): init(tf) is assumed to prepare the metric state for this Ops instance --
    // confirm against its definition.
    init(tf);
    if (getCurve() == AUCCurve.PR && getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
      // This use case is different and is handled separately.
      return cast(tf, interpolatePRAuc(tf), resultType);
    }

    // Set the x and y coordinates of the curve at every threshold, based on the curve type.
    Operand<T> x;
    Operand<T> y;
    Operand<T> recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));

    switch (getCurve()) {
      case ROC:
        // x = false-positive rate, y = recall (true-positive rate).
        x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives));
        y = recall;
        break;
      case PR:
        // x = recall, y = precision.
        y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives));
        x = recall;
        break;
      default:
        throw new IllegalArgumentException("Unexpected AUCCurve value: " + getCurve());
    }

    // Find the rectangle heights based on `summationMethod`, pairing each threshold's y value
    // with the next threshold's y value.
    // y[:numThresholds - 1]
    Operand<T> ySlice1 = slice(tf, y, 0, getNumThresholds() - 1);
    // y[1:]
    Operand<T> ySlice2 = slice(tf, y, 1, -1);

    Operand<T> heights;
    switch (getSummationMethod()) {
      case INTERPOLATION:
        // Only reached for ROC here; PR + INTERPOLATION returned above.
        //noinspection SuspiciousNameCombination
        heights = tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type()));
        break;
      case MINORING:
        //noinspection SuspiciousNameCombination
        heights = tf.math.minimum(ySlice1, ySlice2);
        break;
      case MAJORING:
        //noinspection SuspiciousNameCombination
        heights = tf.math.maximum(ySlice1, ySlice2);
        break;
      default:
        throw new IllegalArgumentException(
            "Unexpected AUCSummationMethod value: " + getSummationMethod());
    }

    // Rectangle areas: (x[:numThresholds - 1] - x[1:]) * heights. Shared by both branches below.
    Operand<T> riemannTerms =
        tf.math.mul(
            tf.math.sub(slice(tf, x, 0, getNumThresholds() - 1), slice(tf, x, 1, -1)), heights);

    if (isMultiLabel()) {
      // Sum over axis 0 (the threshold axis) to get one AUC per label.
      Operand<T> byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0));
      Operand<T> labelWeights = getLabelWeights();

      if (labelWeights == null) {
        // Unweighted average of the label AUCs.
        return cast(tf, MetricsHelper.mean(tf, byLabelAuc), resultType);
      }

      // Weighted average of the label AUCs. The numerator is reduced over the axes of the
      // product itself (not those of labelWeights), so it collapses fully even when the weights
      // broadcast against byLabelAuc with a different rank.
      Operand<T> weightedAuc = tf.math.mul(byLabelAuc, labelWeights);
      return cast(
          tf,
          tf.math.divNoNan(
              tf.reduceSum(weightedAuc, allAxes(tf, weightedAuc)),
              tf.reduceSum(labelWeights, allAxes(tf, labelWeights))),
          resultType);
    }

    // Single-label: total area across all thresholds.
    return cast(tf, tf.reduceSum(riemannTerms, allAxes(tf, riemannTerms)), resultType);
  }