in tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ConfusionMatrix.java [175:265]
/**
 * Computes a confusion matrix from labels and predictions.
 *
 * <p>Squeezable dimensions are first removed from {@code labels} and {@code predictions} via
 * {@code removeSqueezableDimensions}; both are then cast to {@code TInt64}. Each (label,
 * prediction) pair becomes an index into a {@code [numClasses, numClasses]} matrix, whose
 * entries are filled with {@code ScatterNd} using either {@code weights} or ones as the updates.
 *
 * <p>Runtime assertions are added to the graph that all labels and predictions are non-negative
 * and, when {@code numClasses} is supplied, strictly less than {@code numClasses}.
 *
 * @param scope the TensorFlow scope; ops are created in a {@code confusionMatrix} sub-scope
 * @param labels the true class labels (presumably 1-D once squeezed — TODO confirm)
 * @param predictions the predicted class labels
 * @param weights optional per-element weights, may be null (ones are used instead); must have a
 *     static shape compatible with the squeezed {@code predictions}
 * @param numClasses the number of classes, may be null, in which case it is computed as one more
 *     than the larger of the maxima (reduced along axis 0) of labels and predictions
 * @param <T> the data type of labels, predictions, weights and the result
 * @return the confusion matrix, of type {@code T} and shape {@code [numClasses, numClasses]}
 * @throws IllegalArgumentException if {@code weights} is non-null and its shape is not compatible
 *     with the shape of the squeezed {@code predictions}
 */
public static <T extends TNumber> Operand<T> confusionMatrix(
    Scope scope,
    Operand<T> labels,
    Operand<T> predictions,
    Operand<T> weights,
    Operand<TInt64> numClasses) {
  Scope lScope = scope.withSubScope("confusionMatrix");
  LossTuple<T> tuple = removeSqueezableDimensions(lScope, labels, predictions, 0);
  Operand<T> squeezedPredictions = tuple.getTarget();

  // Validate against the *squeezed* predictions: they are what index the scatter below, so the
  // weights must line up with them rather than with the caller's unsqueezed shape. Checking
  // before building the remaining ops avoids leaving orphan ops in the graph on failure.
  if (weights != null && !squeezedPredictions.shape().isCompatibleWith(weights.shape())) {
    throw new IllegalArgumentException(
        String.format(
            "predictions.shape() [%s], is not compatible with weights.shape() [ %s].",
            squeezedPredictions.shape(), weights.shape()));
  }

  Operand<TInt64> lLabels = Cast.create(lScope, tuple.getLabels(), TInt64.class);
  Operand<TInt64> lPredictions = Cast.create(lScope, squeezedPredictions, TInt64.class);
  Operand<TInt64> zero = Constant.scalarOf(lScope, 0L);

  lLabels =
      withAllTrueAssertion(
          lScope,
          GreaterEqual.create(lScope, lLabels, zero),
          "labels contains negative values",
          lLabels);
  lPredictions =
      withAllTrueAssertion(
          lScope,
          GreaterEqual.create(lScope, lPredictions, zero),
          "predictions contains negative values",
          lPredictions);

  Operand<TInt64> lNumClasses;
  if (numClasses == null) {
    // Infer the class count as one more than the largest observed label or prediction.
    lNumClasses =
        Add.create(
            lScope,
            Maximum.create(
                lScope,
                ReduceMax.create(lScope, lPredictions, zero),
                ReduceMax.create(lScope, lLabels, zero)),
            Constant.scalarOf(lScope, 1L));
  } else {
    // numClasses is already TInt64, so no cast is needed.
    lNumClasses = numClasses;
    lLabels =
        withAllTrueAssertion(
            lScope,
            Less.create(lScope, lLabels, lNumClasses),
            "labels out of bounds",
            lLabels);
    lPredictions =
        withAllTrueAssertion(
            lScope,
            Less.create(lScope, lPredictions, lNumClasses),
            "predictions out of bounds",
            lPredictions);
  }

  Operand<TInt64> shape = Stack.create(lScope, Arrays.asList(lNumClasses, lNumClasses));
  // Each row of indices is a (label, prediction) pair; ScatterNd sums updates at duplicate
  // indices, which is what accumulates the counts/weights per cell.
  Operand<TInt64> indices =
      Stack.create(lScope, Arrays.asList(lLabels, lPredictions), Stack.axis(1L));
  // Updates must match the squeezed predictions (one per index row), not the original input.
  Operand<T> values = weights == null ? OnesLike.create(lScope, squeezedPredictions) : weights;
  return ScatterNd.create(lScope, indices, values, shape);
}

/**
 * Returns {@code value} wrapped in an {@code Identity} op that has a control dependency on an
 * {@code AssertThat} op checking that every element of {@code condition} is true.
 *
 * <p>Routing downstream ops through the returned identity makes them depend on the assertion.
 *
 * @param scope the scope in which to create the ops
 * @param condition the boolean tensor that must be all true
 * @param message the assertion message emitted when the condition fails
 * @param value the operand to pass through
 * @return an identity of {@code value} gated on the assertion
 */
private static Operand<TInt64> withAllTrueAssertion(
    Scope scope, Operand<TBool> condition, String message, Operand<TInt64> value) {
  AssertThat assertion =
      AssertThat.create(
          scope,
          ReduceAll.create(scope, condition, Axes.allAxes(scope, condition)),
          Collections.singletonList(Constant.scalarOf(scope, message)));
  return Identity.create(
      scope.withControlDependencies(Collections.singletonList(assertion)), value);
}