in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java [80:155]
public static <T extends TNumber> Op assertBroadcastable(
    Ops tf, Operand<T> sampleWeights, Operand<T> values) {
  // Static check: when both shapes are fully known while building the graph, validate them
  // immediately and either throw NotBroadcastableException or return a no-op.
  Shape weightsShapeStatic = sampleWeights.shape();
  int weightsRankStatic = weightsShapeStatic.numDimensions();
  Shape valuesShapeStatic = values.shape();
  int valuesRankStatic = valuesShapeStatic.numDimensions();
  if (!weightsShapeStatic.isUnknown()
      && !valuesShapeStatic.isUnknown()
      && !weightsShapeStatic.hasUnknownDimension()
      && !valuesShapeStatic.hasUnknownDimension()) {
    // Scalar weights are accepted unconditionally.
    if (weightsRankStatic == 0) {
      return tf.withSubScope("staticScalarCheckSuccess")
          .withControlDependencies(Collections.emptyList())
          .noOp();
    }
    // Non-scalar weights must have exactly the same rank as values.
    if (weightsRankStatic != valuesRankStatic) {
      throw new NotBroadcastableException(
          String.format(
              "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.",
              ASSERT_BROADCAST_ERROR_PREFIX,
              valuesRankStatic,
              weightsRankStatic,
              valuesShapeStatic,
              weightsShapeStatic));
    }
    // Each weights dimension must be 1 or equal to the matching values dimension.
    for (int i = 0; i < valuesRankStatic; i++) {
      if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)
          && weightsShapeStatic.size(i) != 1) {
        throw new NotBroadcastableException(
            String.format(
                "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.",
                ASSERT_BROADCAST_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic));
      }
    }
    return tf.withSubScope("staticDimsCheckSuccess")
        .withControlDependencies(Collections.emptyList())
        .noOp();
  }

  // Dynamic check: at least one shape is not fully known, so emit a runtime assertion instead.
  Operand<TInt32> weightsShape = tf.shape(sampleWeights);
  // Rank of the ORIGINAL weights. This must be what the rank comparison sees so the dynamic
  // path enforces the same "same rank unless scalar" rule as the static path above.
  Operand<TInt32> weightsRank = tf.rank(sampleWeights);
  Operand<TInt32> valuesShape = tf.shape(values);
  Operand<TInt32> valuesRank = tf.rank(values);
  Operand<TBool> isScalar = tf.math.equal(weightsRank, tf.constant(0));
  // Diagnostic payload reported by the assertion when it fails.
  List<Operand<?>> data =
      Arrays.asList(
          tf.constant(ASSERT_BROADCAST_ERROR_PREFIX),
          tf.constant("weights.shape="),
          weightsShape,
          tf.constant("values.shape="),
          valuesShape,
          tf.constant("isScalar="),
          isScalar);
  // Workaround for the non-lazy select used for isValidShape: otherwise validNonscalar fails on a
  // scalar weight. If select were lazy, that branch would not be executed when isScalar is true.
  // Only the SHAPE of this broadcast tensor is passed on. Its rank is deliberately not used,
  // because broadcasting would promote lower-rank weights to the values' rank and let them slip
  // past the rank check.
  // NOTE(review): for weights that cannot broadcast against values, the mul below presumably
  // fails at runtime before the assertion is evaluated; confirm whether that error is acceptable.
  Operand<T> broadcastWeights =
      tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights);
  Operand<TBool> validNonscalar =
      canBroadcastNonscalarShapes(
          tf, weightsRank, tf.shape(broadcastWeights), valuesRank, valuesShape);
  Operand<TBool> isValidShape = tf.select(isScalar, isScalar, validNonscalar);
  return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data);
}