in tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ConfusionMatrix.java [277:317]
private static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
    Scope scope, Operand<T> labels, Operand<T> predictions, int expectedRankDiff) {
  // Squeezes the trailing dimension of predictions (or labels) when
  // rank(predictions) - rank(labels) is one more (or one less) than expectedRankDiff and that
  // trailing dimension is compatible with size 1. Returns the possibly-squeezed labels and
  // predictions as a LossTuple.
  Scope lScope = scope.withSubScope("removeSqueezableDimensions");
  Shape predictionsShape = predictions.shape();
  int predictionsRank = predictionsShape.numDimensions();
  Shape labelsShape = labels.shape();
  int labelsRank = labelsShape.numDimensions();

  if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) {
    // Both ranks are statically known, so decide at graph-construction time. Both must be known:
    // a difference involving UNKNOWN_SIZE would not be a real rank difference.
    int rankDiff = predictionsRank - labelsRank;
    if (rankDiff == expectedRankDiff + 1
        && predictionsRank > 0
        && Shape.isCompatible(predictionsShape.size(-1), 1)) {
      // Squeeze only the last axis; a plain squeeze would also drop other size-1 dimensions
      // (e.g. a batch dimension of 1).
      predictions =
          Squeeze.create(lScope, predictions, Squeeze.axis(Collections.singletonList(-1L)));
    } else if (rankDiff == expectedRankDiff - 1
        && labelsRank > 0
        && Shape.isCompatible(labelsShape.size(-1), 1)) {
      labels = Squeeze.create(lScope, labels, Squeeze.axis(Collections.singletonList(-1L)));
    }
    return new LossTuple<>(labels, predictions);
  }

  // At least one rank is unknown, so the rank difference can only be known at run time.
  // TODO: once a lazily evaluated select/cond is available, compute
  //   rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels))
  // and squeeze predictions only when rankDiff == expectedRankDiff + 1, and labels only when
  // rankDiff == expectedRankDiff - 1 (keeping the operand unchanged otherwise).
  // Until then, an operand whose rank is unknown has its last axis squeezed unconditionally;
  // tf.squeeze with an explicit axis raises an error at run time if that axis is not size 1.
  if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
    predictions =
        Squeeze.create(lScope, predictions, Squeeze.axis(Collections.singletonList(-1L)));
  }
  if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) {
    labels = Squeeze.create(lScope, labels, Squeeze.axis(Collections.singletonList(-1L)));
  }
  return new LossTuple<>(labels, predictions);
}