in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java [357:611]
public static <T extends TNumber> List<Op> updateConfusionMatrixVariables(
Ops tf,
Map<ConfusionMatrixEnum, Variable<T>> variablesToUpdate,
Operand<T> labels,
Operand<T> predictions,
Operand<TFloat32> thresholds,
Integer topK,
Integer classIndex,
Operand<T> sampleWeight,
boolean multiLabel,
Operand<T> labelWeights) {
if (multiLabel && labelWeights != null)
throw new IllegalArgumentException(
"labelWeights for multilabel data should be handled outside of updateConfusionMatrixVariables when multiLabel is true.");
if (variablesToUpdate == null || variablesToUpdate.isEmpty()) {
return Collections.emptyList();
}
Operand<T> tLabels = labels;
Operand<T> tPredictions = predictions;
Operand<T> tSampleWeight = sampleWeight;
// We will tile data for threshold comparisons. We want a cross product of thresholds and
// predictions/labels:
// In the multilabel case, we want a data shape of (T, N, D).
// else (T, ND).
// where
// T is numThresholds (the size of the 0th dimension of thresholds)
// N is the number of examples (the 0th dimension of labels and predictions)
// Cx are the class dimensions of labels and predictions (shape (N, Cx) below);
// Dx == Cx, except that if classIndex != null the last dimension of Dx has size 1
// D is the product of all Dx
// ND is N * D
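// For example (illustrative values, not taken from this code): with T = 3 thresholds,
// N = 2 examples and D = 4 labels per example, the multilabel path works with tiles of
// shape (3, 2, 4), while the non-multilabel path flattens to ND = 8 and works with (3, 8).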
// size of the 0th dimension of thresholds
// reshape to scalar for operations later.
Operand<TInt32> numThresholds =
tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));
// if multilabel, then (rank(thresholds) == 1)
// else true
Operand<TBool> oneThresh;
if (multiLabel) {
oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds));
} else {
// TODO handle Ragged Tensors?
// The Python implementation calls:
//   [y_pred, y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], sampleWeights)
oneThresh = tf.constant(true);
}
List<Op> controlOps = new ArrayList<>();
Operand<TInt32> axes = allAxes(tf, tPredictions);
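// Assert that every prediction lies in [0, 1]. The assertion ops are added to the
// returned control op list so callers can gate the variable updates on them.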
controlOps.add(
tf.withSubScope("updateConfusionMatrixVariables-1")
.assertThat(
tf.reduceAll(
tf.math.greaterEqual(
tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
axes),
Collections.singletonList(tf.constant("predictions must be >= 0"))));
controlOps.add(
tf.withSubScope("updateConfusionMatrixVariables-2")
.assertThat(
tf.reduceAll(
tf.math.lessEqual(tPredictions, cast(tf, tf.constant(1), tPredictions.type())),
axes),
Collections.singletonList(tf.constant("predictions must be <= 1"))));
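// Reconcile the ranks of labels, predictions, and sample weights (e.g. squeeze or expand
// a trailing dimension of size 1) so the three tensors are dimensionally compatible.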
LossTuple<T> result =
LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight);
tPredictions = result.getTarget(); // shape (N, Cx)
tLabels = result.getLabels(); // shape (N, Cx)
tSampleWeight = result.getSampleWeights(); // broadcastable to (N, Dx)
if (!tPredictions.shape().isCompatibleWith(tLabels.shape()))
throw new IllegalArgumentException(
String.format(
"Shapes %s and %s are incompatible)",
tPredictions.shape().toString(), tLabels.shape().toString()));
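// If topK is set, restrict the comparison to each example's top-K prediction values
// (filterTopK suppresses the remaining entries).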
if (topK != null) {
tPredictions = filterTopK(tf, tPredictions, topK);
}
if (classIndex != null) {
// Slice to new shapes (N, Dx)
tLabels =
tf.squeeze(
tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(-1)),
Squeeze.axis(Collections.singletonList(1L)));
tPredictions =
tf.squeeze(
tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(-1)),
Squeeze.axis(Collections.singletonList(1L)));
}
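// Recompute the prediction shape after the optional topK / classIndex adjustments above;
// N is the size of its 0th dimension.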
org.tensorflow.op.core.Shape<TInt32> predShape = tf.shape(tPredictions);
Operand<TInt32> numExamples =
tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar()));
// number of labels (and predictions) per example (after possibly slicing by classIndex)
// In the notation we are using for comments, this is D.
Operand<TInt32> numLabels =
tf.select(
tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)),
tf.constant(1),
tf.reduceProd(
// take all but the first dimension
tf.shape.takeLast(
predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))),
tf.constant(0)));
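// For example (illustrative shapes): predictions of shape (N, 3, 4) give numLabels = 12,
// while rank-1 predictions of shape (N) give numLabels = 1.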
// threshLabelTile == numLabels except in one case:
// if multilabel and rank(thresholds) != 1, then threshLabelTile is 1
Operand<TInt32> threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1));
// if multilabel, then shape (1, N, Dx)
// else shape (1, ND),
Operand<T> predictionsExtraDim;
Operand<TBool> labelsExtraDim;
if (multiLabel) {
predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0));
labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0));
} else {
predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1)));
labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1)));
}
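// The leading dimension of size 1 added above is what dataTiles (below) repeats T times,
// once per threshold.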
// the shape thresholds are reshaped to before tiling
// if multilabel, then [T, 1, -1]
// else [T, -1]
List<Operand<TInt32>> threshPretileShape;
// the tiling multiples for thresholds
// We want to repeat the thresholds for each data position.
// if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels)
// else [1, ND]
List<Operand<TInt32>> threshTiles;
// tiling multiples for predictionsExtraDim and labelsExtraDim
// We want to repeat the predictions and labels for each threshold.
// If multilabel, then [T, 1, 1]
// else [T, 1]
List<Operand<TInt32>> dataTiles;
if (multiLabel) {
threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1));
threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile);
dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1));
} else {
threshPretileShape =
Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1));
Operand<TInt32> mul = tf.math.mul(numExamples, numLabels);
threshTiles = Arrays.asList(tf.constant(1), mul);
dataTiles = Arrays.asList(numThresholds, tf.constant(1));
}
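// Worked example for the non-multilabel path (illustrative values): with T = 3, N = 2, D = 4,
// thresholds are reshaped to (3, 1) and tiled by [1, 8] to (3, 8), while predictions/labels
// are reshaped to (1, 8) and tiled by [3, 1] to (3, 8), so they can be compared elementwise.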
// if multilabel, then shape (T, 1, T*)
// else shape (T, T*)
// where T* is the product of all threshold dimension sizes beyond 0
Operand<T> thresholdsReshaped =
tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape));
Operand<TInt32> threshTilesShape = tf.stack(threshTiles);
// if multilabel, then
// if thresholds has rank > 1, then shape (T, N, T*)
// else shape (T, N, D)
// else shape (T, ND)
Operand<T> threshTiled = tf.tile(thresholdsReshaped, threshTilesShape);
Operand<TInt32> dataTilesShape = tf.stack(dataTiles);
// if multilabel, then shape (T, N, D)
// else (T, ND)
Operand<T> predsTiled = tf.tile(predictionsExtraDim, dataTilesShape);
// Compare predictions and threshold.
Operand<TBool> predIsPos = tf.math.greater(predsTiled, threshTiled);
// Tile labels by number of thresholds
Operand<TBool> labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles));
Operand<T> weightsTiled;
if (tSampleWeight != null) {
tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions));
// if multilabel, then
// reshape tSampleWeight to (1, N, threshLabelTile)
// tile the result into shape (T, N, threshLabelTile)
// where threshLabelTile is typically D
// else
// reshape tSampleWeight to (1, ND)
// tile the result into shape (T, ND)
weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape);
} else {
weightsTiled = null;
}
if (labelWeights != null) {
// Change shape to (1, Dx).
Operand<T> lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0));
// Broadcast to shape (N, Dx).
lLabelWeights = tf.broadcastTo(lLabelWeights, tf.shape(tPredictions));
// If multilabel: shape (T, N, D)
// else: shape (T, ND)
Operand<T> labelWeightsTiled =
tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles));
if (weightsTiled == null) {
weightsTiled = labelWeightsTiled;
} else {
weightsTiled = tf.math.mul(weightsTiled, labelWeightsTiled);
}
}
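// Each confusion-matrix entry is the conjunction of a label predicate and a prediction
// predicate: TP = labelIsPos & predIsPos, FP = labelIsNeg & predIsPos,
// FN = labelIsPos & predIsNeg, TN = labelIsNeg & predIsNeg. The negated predicates are
// only materialized when a variable that needs them was requested.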
Map<ConfusionMatrixEnum, Operand[]> loopVars = new HashMap<>();
loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos});
Variable<T> updateTN = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES);
Variable<T> updateFP = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES);
Variable<T> updateFN = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES);
Operand<TBool> predIsNeg = null;
Operand<TBool> labelIsNeg;
if (updateFN != null || updateTN != null) {
predIsNeg = tf.math.logicalNot(predIsPos);
loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg});
}
if (updateFP != null || updateTN != null) {
labelIsNeg = tf.math.logicalNot(labelIsPos);
loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos});
if (updateTN != null) {
loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg});
}
}
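// For every requested variable, weightedAssignAdd combines the two predicates, applies the
// tiled weights when present, and accumulates the weighted counts into the variable; the
// resulting assign ops are returned alongside the assertions above.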
final Operand<T> weightsTiledF = weightsTiled;
loopVars
.keySet()
.forEach(
(c) -> {
if (variablesToUpdate.containsKey(c)) {
Operand[] op = loopVars.get(c);
// op[0] is the label predicate, op[1] is the prediction predicate
controlOps.add(
weightedAssignAdd(tf, op[0], op[1], weightsTiledF, variablesToUpdate.get(c)));
}
});
return controlOps;
}