in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java [125:203]
public List<Op> updateStateList(
    Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) {
  // Builds the ops that accumulate this batch into the metric state:
  //  - `total` += sum of the (optionally weighted) values;
  //  - for reductions other than SUM, `count` += number of values (or sum of the weights for
  //    WEIGHTED_MEAN when weights are supplied).
  // The assign-add ops are returned in creation order (total first, then count if any).
  if (values == null) {
    throw new IllegalArgumentException("values is required.");
  }
  init(tf);
  List<Op> updateOperations = new ArrayList<>();
  // cast everything to match the variables
  Operand<T> tSampleWeights = null;
  Operand<T> tValues = cast(tf, values, getInternalType());
  if (sampleWeights != null) {
    tSampleWeights = cast(tf, sampleWeights, getInternalType());
    // Update dimensions of weights to match with values if possible.
    LossTuple<T> tuple =
        LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights);
    tValues = tuple.getTarget();
    tSampleWeights = tuple.getSampleWeights();
    try {
      // Broadcast weights if possible
      tSampleWeights = MetricsHelper.broadcastWeights(tf, tSampleWeights, tValues);
    } catch (IllegalArgumentException ex) {
      // The weights cannot be broadcast to the values (per the original comment: static shapes
      // with different ranks or different dimension sizes). Reduce the trailing axes of the
      // values so they have the same rank as the weights.
      int valuesRank = tValues.shape().numDimensions();
      int weightsRank = tSampleWeights.shape().numDimensions();
      // Shape.numDimensions() reports -1 for an unknown rank; only reduce when both ranks are
      // known and the values have more dimensions than the weights.
      if (weightsRank >= 0 && valuesRank > weightsRank) {
        int numAxes = valuesRank - weightsRank;
        int[] axes = new int[numAxes];
        for (int i = 0; i < numAxes; i++) {
          // reduce axes weightsRank .. valuesRank-1
          axes[i] = weightsRank + i;
        }
        if (reduction == MetricReduction.SUM) {
          tValues = tf.reduceSum(tValues, tf.constant(axes));
        } else {
          tValues = tf.math.mean(tValues, tf.constant(axes));
        }
      }
    }
    tValues = tf.math.mul(tValues, tSampleWeights);
  }
  Operand<? extends TNumber> weightedValueSum =
      tf.reduceSum(tValues, LossesHelper.allAxes(tf, tValues));
  Operand<T> totalUpdate = tf.assignAdd(total, cast(tf, weightedValueSum, total.type()));
  updateOperations.add(totalUpdate);
  // Plain SUM has no denominator, so `count` is only updated for the other reductions.
  if (reduction != MetricReduction.SUM) {
    Operand<T> numValues;
    switch (reduction) {
      case SUM_OVER_BATCH_SIZE:
        // Runtime element count: the static Shape.size() is -1 when any dimension is unknown.
        numValues = cast(tf, tf.size(tValues), internalType);
        break;
      case WEIGHTED_MEAN:
        if (tSampleWeights == null) {
          numValues = cast(tf, tf.size(tValues), internalType);
        } else {
          numValues =
              cast(
                  tf,
                  tf.reduceSum(tSampleWeights, LossesHelper.allAxes(tf, tSampleWeights)),
                  internalType);
        }
        break;
      default:
        throw new UnsupportedOperationException(
            String.format("reduction [%s] not implemented", reduction));
    }
    Operand<T> totalCount = tf.assignAdd(this.count, numValues);
    updateOperations.add(totalCount);
  }
  return updateOperations;
}