in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java [848:922]
public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType) {
  // Computes the AUC as a Riemann sum over curve points derived from the accumulated
  // truePositives / falsePositives / trueNegatives / falseNegatives variables, then casts the
  // value to resultType. Throws IllegalArgumentException for an unrecognized curve or
  // summation method.
  init(tf);

  // A PR curve combined with interpolation has its own dedicated computation.
  boolean interpolatedPr =
      getCurve() == AUCCurve.PR && getSummationMethod() == AUCSummationMethod.INTERPOLATION;
  if (interpolatedPr) {
    return cast(tf, interpolatePRAuc(tf), resultType);
  }

  // recall = TP / (TP + FN); used as an axis by both curve types.
  Operand<T> recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));

  Operand<T> xCoords;
  Operand<T> yCoords;
  switch (getCurve()) {
    case ROC:
      {
        // ROC: x = false positive rate = FP / (FP + TN), y = recall.
        Operand<T> fpRate =
            tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives));
        xCoords = fpRate;
        yCoords = recall;
        break;
      }
    case PR:
      {
        // PR: x = recall, y = precision = TP / (TP + FP).
        Operand<T> precision =
            tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives));
        xCoords = recall;
        yCoords = precision;
        break;
      }
    default:
      throw new IllegalArgumentException("Unexpected AUCCurve value: " + getCurve());
  }

  // Each rectangle is bounded by adjacent y values: y[0 : numThresholds - 1] and y[1 :].
  Operand<T> yHead = slice(tf, yCoords, 0, getNumThresholds() - 1);
  Operand<T> yTail = slice(tf, yCoords, 1, -1);

  Operand<T> heights;
  switch (getSummationMethod()) {
    case INTERPOLATION:
      {
        // Average of the two adjacent heights (the PR variant already returned above).
        Operand<T> pairSum = tf.math.add(yHead, yTail);
        Operand<T> two = cast(tf, tf.constant(2), yCoords.type());
        heights = tf.math.div(pairSum, two);
        break;
      }
    case MINORING:
      heights = tf.math.minimum(yHead, yTail);
      break;
    case MAJORING:
      heights = tf.math.maximum(yHead, yTail);
      break;
    default:
      throw new IllegalArgumentException(
          "Unexpected AUCSummationMethod value: " + getSummationMethod());
  }

  // Rectangle widths are x[0 : numThresholds - 1] - x[1 :]; each term is width * height.
  Operand<T> widths =
      tf.math.sub(slice(tf, xCoords, 0, getNumThresholds() - 1), slice(tf, xCoords, 1, -1));
  Operand<T> riemannTerms = tf.math.mul(widths, heights);

  if (!isMultiLabel()) {
    // Single label: the AUC is the total area of all rectangles.
    return cast(tf, tf.reduceSum(riemannTerms, allAxes(tf, riemannTerms)), resultType);
  }

  // Multi-label: reduce along axis 0 to obtain one AUC per label
  // (NOTE(review): assumes axis 0 is the threshold axis of the accumulators — confirm).
  Operand<T> byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0));
  if (getLabelWeights() == null) {
    // Unweighted mean of the per-label AUCs.
    return cast(tf, MetricsHelper.mean(tf, byLabelAuc), resultType);
  }

  // Weighted mean of the per-label AUCs; divNoNan returns 0 when the weights sum to 0.
  Operand<T> weightedSum =
      tf.reduceSum(tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights()));
  Operand<T> weightTotal = tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()));
  return cast(tf, tf.math.divNoNan(weightedSum, weightTotal), resultType);
}