in tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java [91:138]
public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights) {
// Step 1: when labels are present, let removeSqueezableDimensions reconcile labels/predictions
// (static-rank shortcut mirrors the Keras reference: skip it only when predictions are exactly
// one rank higher than labels and the trailing predictions dimension is not 1).
// Step 2: when sampleWeights are present and not scalar, squeeze or expand the LAST dimension
// of sampleWeights so its rank matches the rank of the (possibly squeezed) predictions.
// labels and sampleWeights may be null; a null sampleWeights returns the step-1 tuple as-is.
Shape predictionsShape = predictions.shape();
long predictionsRank = predictionsShape.numDimensions();
// Default case when no modifications are made.
LossTuple<T> lossTuple = new LossTuple<>(labels, predictions, sampleWeights);
if (labels != null) {
Shape labelsShape = labels.shape();
long labelsRank = labelsShape.numDimensions();
if (labelsRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) {
// Use static rank for 'label' and 'prediction'. size(-1) is only reached when
// predictionsRank - labelsRank == 1, so predictions has at least one dimension there.
if (predictionsRank - labelsRank != 1 || predictionsShape.size(-1) == 1) {
lossTuple = removeSqueezableDimensions(tf, labels, predictions);
}
} else { // use dynamic rank
lossTuple = removeSqueezableDimensions(tf, labels, predictions);
}
}
if (sampleWeights == null) { // nothing more to do.
return lossTuple;
}
Shape weightsShape = sampleWeights.shape();
long weightsRank = weightsShape.numDimensions();
if (weightsRank == 0) { // scalar
return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights);
}
// Align the weights with the predictions as they are AFTER step 1. Using the original
// predictions rank here would, e.g. for labels [B], predictions [B,1], weights [B], squeeze the
// predictions to [B] but then expand the weights to [B,1], leaving them mismatched.
Operand<T> target = lossTuple.getTarget();
long targetRank = target.shape().numDimensions();
if (targetRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) {
if (weightsRank - targetRank == 1) {
// Squeeze only the trailing dimension; a bare squeeze would drop every size-1
// dimension (e.g. [2,1,1] -> [2] instead of [2,1]).
sampleWeights =
tf.squeeze(
sampleWeights,
org.tensorflow.op.core.Squeeze.axis(java.util.Collections.singletonList(-1L)));
} else if (targetRank - weightsRank == 1) {
sampleWeights = tf.expandDims(sampleWeights, tf.constant(-1L));
}
return new LossTuple<>(lossTuple.getLabels(), target, sampleWeights);
}
// Use dynamic rank.
// NOTE(review): tf.select evaluates both branches and requires broadcast-compatible shapes,
// unlike a true conditional; behavior here depends on maybeAdjustWeights — confirm.
Operand<TInt32> weightsRankTensor = tf.rank(sampleWeights);
Operand<TInt32> rankDiff = tf.math.sub(weightsRankTensor, tf.rank(target));
sampleWeights =
tf.select(
tf.math.equal(weightsRankTensor, tf.constant(0)),
sampleWeights,
maybeAdjustWeights(tf, sampleWeights, rankDiff));
return new LossTuple<>(lossTuple.getLabels(), target, sampleWeights);
}