in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java [801:871]
public static <T extends TNumber> Operand<T> confusionMatrix(
    Ops tf,
    Operand<T> labels,
    Operand<T> predictions,
    Operand<TInt64> numClasses,
    Operand<T> weights,
    Class<T> type) {
  // Builds a dense [numClasses, numClasses] matrix of type `type`. Entry [i, j] accumulates the
  // weight (or 1 when `weights` is null) of every element whose label is i and whose prediction
  // is j: rows are indexed by labels, columns by predictions.
  //
  // Throws IllegalArgumentException when the static shapes of predictions/labels, or
  // predictions/weights, are incompatible. Negative values, and values >= numClasses when
  // numClasses is supplied, are rejected at graph run time through assertThat control deps.
  if (!predictions.shape().isCompatibleWith(labels.shape())) {
    throw new IllegalArgumentException(
        String.format(
            "Prediction shape %s is not compatible with labels shape %s",
            predictions.shape().toString(), labels.shape().toString()));
  }
  tf = tf.withSubScope("confusionMatrix");
  // NOTE(review): assumes squeezeOrExpandDimensions(tf, labels, predictions, sampleWeights)
  // returns labels via getLabels() and predictions via getTarget(), as other callers of the
  // helper read it. The previous call passed (predictions, labels), which swapped the two.
  LossTuple<T> ops = LossesHelper.squeezeOrExpandDimensions(tf, labels, predictions, null);
  Operand<TInt64> tPredictions = cast(tf, ops.getTarget(), TInt64.class);
  Operand<TInt64> tLabels = cast(tf, ops.getLabels(), TInt64.class);

  List<Op> labelControls = new ArrayList<>();
  List<Op> predictionControls = new ArrayList<>();
  // Every element must be non-negative. Use reduceAll (not reduceAny) so that a single bad
  // element trips the assertion.
  labelControls.add(
      tf.assertThat(
          tf.reduceAll(tf.math.greaterEqual(tLabels, tf.constant(0L)), allAxes(tf, tLabels)),
          Collections.singletonList(tf.constant("`labels` contains negative values"))));
  predictionControls.add(
      tf.assertThat(
          tf.reduceAll(
              tf.math.greaterEqual(tPredictions, tf.constant(0L)), allAxes(tf, tPredictions)),
          Collections.singletonList(tf.constant("`predictions` contains negative values"))));

  if (numClasses == null) {
    // Class ids are 0-based, so the number of classes is one more than the largest id seen.
    numClasses =
        tf.math.add(
            tf.math.maximum(
                tf.reduceMax(tPredictions, allAxes(tf, tPredictions)),
                tf.reduceMax(tLabels, allAxes(tf, tLabels))),
            tf.constant(1L));
  } else {
    // Every element must be a valid row/column index of the output matrix.
    labelControls.add(
        tf.assertThat(
            tf.reduceAll(tf.math.less(tLabels, numClasses), allAxes(tf, tLabels)),
            Collections.singletonList(tf.constant("`labels` out of bounds"))));
    predictionControls.add(
        tf.assertThat(
            tf.reduceAll(tf.math.less(tPredictions, numClasses), allAxes(tf, tPredictions)),
            Collections.singletonList(tf.constant("`predictions` out of bounds"))));
  }

  if (weights != null && !tPredictions.shape().isCompatibleWith(weights.shape())) {
    throw new IllegalArgumentException(
        String.format(
            "Prediction shape %s is not compatible with weights shape %s",
            tPredictions.shape().toString(), weights.shape().toString()));
  }

  // Route labels/predictions through identity ops that carry the assertions as control
  // dependencies, so the checks run before the values are used to build the matrix.
  Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls);
  tLabels = tfc.identity(tLabels);
  tfc = tf.withSubScope("confusionMatrixPredictions").withControlDependencies(predictionControls);
  tPredictions = tfc.identity(tPredictions);

  Operand<TInt64> shape = tf.stack(Arrays.asList(numClasses, numClasses));
  // Each (label, prediction) pair becomes one [row, column] coordinate.
  Operand<TInt64> indices = tf.stack(Arrays.asList(tLabels, tPredictions), Stack.axis(1L));
  Operand<T> values =
      weights == null ? cast(tf, tf.onesLike(tPredictions), type) : cast(tf, weights, type);
  SparseTensor<T> cmSparse = new SparseTensor<>(indices, values, shape);
  Operand<T> zeroMatrix = tf.zeros(shape, type);
  // Scatter the sparse (coordinate, value) entries onto a zero matrix to obtain the dense result.
  return tf.sparse.sparseTensorDenseAdd(
      cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix);
}