in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java [587:643]
protected void init(Ops tf) {
  // Builds this metric's graph state the first time it is called with a non-null shape:
  //  - if labelWeights were supplied, replaces the field with an identity op that carries a
  //    control dependency on an assertion that every weight is >= 0;
  //  - computes variableShape: (numThresholds, numLabels) when multiLabel is set,
  //    otherwise (numThresholds);
  //  - creates (in the init scope) whichever of the true/false positive/negative
  //    accumulator variables are still null, each initialized from the value produced by
  //    the `zeros` initializer for variableShape (presumably all zeros);
  //  - marks the metric as initialized.
  // Does nothing (beyond checkIsGraph) when shape is null or the metric is already
  // initialized.
  //
  // tf: the TensorFlow Ops; checkIsGraph(tf) is invoked first (presumably rejecting
  //     non-Graph execution environments -- confirm against checkIsGraph).
  // throws IllegalArgumentException if multiLabel is set and shape is not rank 2.
  checkIsGraph(tf);
  if (shape == null || isInitialized()) {
    return;
  }
  setTF(tf);

  if (labelWeights != null) {
    // assert that labelWeights are non-negative.
    Op checks =
        tf.withSubScope("AUC")
            .assertThat(
                tf.math.greaterEqual(
                    labelWeights, cast(tf, tf.constant(0), labelWeights.type())),
                Collections.singletonList(
                    tf.constant("All values of labelWeights must be non-negative.")));
    // Later reads of the labelWeights field see this identity op, whose control
    // dependency forces the non-negativity assertion to run before the weights are used.
    Ops ltf =
        tf.withSubScope("updateState")
            .withControlDependencies(Collections.singletonList(checks));
    this.labelWeights = ltf.identity(this.labelWeights);
  }

  if (isMultiLabel()) {
    // shape is guaranteed non-null by the guard above; only its rank needs validating.
    if (shape.numDimensions() != 2) {
      throw new IllegalArgumentException(
          String.format(
              "labels must have rank=2 when multiLabel is true. Found rank %d.",
              shape.numDimensions()));
    }
    numLabels = (int) shape.size(1);
    variableShape = Shape.of(numThresholds, numLabels);
  } else {
    variableShape = Shape.of(numThresholds);
  }

  // Create metric variables; any that already exist are left untouched.
  Operand<T> zero = zeros.call(tf, tf.constant(variableShape), type);
  if (truePositives == null) {
    truePositives = tf.withName(getTruePositivesName()).withInitScope().variable(zero);
  }
  if (falsePositives == null) {
    falsePositives = tf.withName(getFalsePositivesName()).withInitScope().variable(zero);
  }
  if (trueNegatives == null) {
    trueNegatives = tf.withName(getTrueNegativesName()).withInitScope().variable(zero);
  }
  if (falseNegatives == null) {
    falseNegatives = tf.withName(getFalseNegativesName()).withInitScope().variable(zero);
  }
  setInitialized(true);
}