in tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java [204:244]
/**
 * Removes the trailing size-1 dimension from either {@code labels} or {@code predictions} when
 * their rank difference is one away from {@code expectedRankDiff}.
 *
 * <p>When both ranks are known statically, {@code predictions} is squeezed along its last axis if
 * {@code rank(predictions) - rank(labels) == expectedRankDiff + 1}. Otherwise, {@code labels} is
 * squeezed along its last axis if the difference is {@code expectedRankDiff - 1}. In both cases the
 * squeeze happens only when that axis is compatible with size 1. Only the last axis is removed;
 * any other size-1 dimensions are preserved.
 *
 * <p>When exactly one rank is unknown, the rank difference cannot be computed statically, and both
 * operands are returned unchanged.
 *
 * <p>When both ranks are unknown, each operand whose last dimension is compatible with 1 is squeezed
 * along that axis unconditionally. This stands in for a lazily evaluated select, which is not
 * available yet. The squeeze op rejects an axis whose runtime size is not 1.
 *
 * @param tf the TensorFlow Ops
 * @param labels the label values
 * @param predictions the predicted values
 * @param expectedRankDiff the expected value of {@code rank(predictions) - rank(labels)}
 * @param <T> the numeric data type of the operands
 * @return a tuple holding the (possibly squeezed) labels and predictions
 */
public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
    Ops tf, Operand<T> labels, Operand<T> predictions, int expectedRankDiff) {
  tf = tf.withSubScope("removeSqueezableDimensions");
  Shape predictionsShape = predictions.shape();
  int predictionsRank = predictionsShape.numDimensions();
  Shape labelsShape = labels.shape();
  int labelsRank = labelsShape.numDimensions();
  boolean predictionsRankKnown = predictionsRank != Shape.UNKNOWN_SIZE;
  boolean labelsRankKnown = labelsRank != Shape.UNKNOWN_SIZE;

  if (predictionsRankKnown && labelsRankKnown) {
    // Use static rank. Both ranks are required: subtracting the UNKNOWN_SIZE sentinel would yield
    // a meaningless rank difference.
    int rankDiff = predictionsRank - labelsRank;
    // The rank > 0 checks keep size(-1) from being asked of a scalar shape.
    if (rankDiff == expectedRankDiff + 1
        && predictionsRank > 0
        && Shape.isCompatible(predictionsShape.size(-1), 1)) {
      // Squeeze only the trailing axis, not every size-1 dimension.
      predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
    } else if (rankDiff == expectedRankDiff - 1
        && labelsRank > 0
        && Shape.isCompatible(labelsShape.size(-1), 1)) {
      labels = tf.squeeze(labels, Squeeze.axis(Collections.singletonList(-1L)));
    }
    return new LossTuple<>(labels, predictions);
  }

  if (predictionsRankKnown || labelsRankKnown) {
    // Exactly one rank is unknown, so the rank difference only exists at run time. Without a
    // lazily evaluated select there is no way to squeeze conditionally, so leave both unchanged.
    // TODO: squeeze conditionally on tf.math.sub(tf.rank(predictions), tf.rank(labels)) once a
    //  select with lazy evaluation is available.
    return new LossTuple<>(labels, predictions);
  }

  // Both ranks are unknown: use dynamic rank.
  // TODO: once a select with lazy evaluation exists, squeeze conditionally instead:
  //   rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
  //   predictions = select(equal(expectedRankDiff + 1, rankDiff),
  //       squeeze(predictions, axis -1), predictions);
  //   labels = select(equal(expectedRankDiff - 1, rankDiff),
  //       squeeze(labels, axis -1), labels);
  // Until then, each operand is squeezed along its last axis unconditionally.
  if (Shape.isCompatible(predictionsShape.size(-1), 1)) {
    predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
  }
  if (Shape.isCompatible(labelsShape.size(-1), 1)) {
    labels = tf.squeeze(labels, Squeeze.axis(Collections.singletonList(-1L)));
  }
  return new LossTuple<>(labels, predictions);
}