in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java [129:175]
/*
 * Builds the operations that fold one batch of labels/predictions into the running
 * confusion matrix held in totalConfusionMatrix.
 *
 * tf             the TensorFlow Ops used to create the operations; also passed to init(tf).
 * labels         the ground-truth values; cast to `type` and flattened when rank > 1.
 * predictions    the predicted values; cast to `type` and flattened when rank > 1.
 * sampleWeights  optional weights (may be null); when present they must have rank 0 or
 *                the same rank as labels, whenever both ranks are statically known.
 *
 * Returns a single-element list holding the assignAdd op that accumulates the batch
 * confusion matrix into totalConfusionMatrix.
 *
 * Throws IllegalArgumentException if the statically known ranks of sampleWeights and
 * labels are incompatible, or if the statically known element counts of labels and
 * predictions differ.
 */
public List<Op> updateStateList(
    Ops tf,
    Operand<? extends TNumber> labels,
    Operand<? extends TNumber> predictions,
    Operand<? extends TNumber> sampleWeights) {
  init(tf);

  if (sampleWeights != null) {
    long weightsRank = sampleWeights.shape().numDimensions();
    long labelsRank = labels.shape().numDimensions();
    // Only reject when both ranks are statically known; unknown ranks cannot be validated here.
    if (weightsRank != 0
        && weightsRank != Shape.UNKNOWN_SIZE
        && labelsRank != Shape.UNKNOWN_SIZE
        && weightsRank != labelsRank) {
      throw new IllegalArgumentException(
          String.format(
              "Weights must either have rank 0, or the same rank as labels, weights rank = %d, labels rank = %d",
              weightsRank, labelsRank));
    }
  }

  long labelsSize = labels.shape().size();
  long predictionsSize = predictions.shape().size();
  // Shape.size() yields UNKNOWN_SIZE when any dimension is unknown. Comparing an unknown
  // size against a known one would raise a spurious error, so the check is applied only
  // when both element counts are statically known (mirrors the rank check above).
  if (labelsSize != Shape.UNKNOWN_SIZE
      && predictionsSize != Shape.UNKNOWN_SIZE
      && labelsSize != predictionsSize) {
    throw new IllegalArgumentException(
        String.format(
            "labels and predictions must have the same size, labels size = %d, predictions size = %d",
            labelsSize, predictionsSize));
  }

  // Cast every input to the metric's type and collapse multi-dimensional inputs via
  // tf.shape.flatten before handing them to the confusion-matrix helper.
  Operand<T> tLabels = cast(tf, labels, type);
  if (tLabels.shape().numDimensions() > 1) {
    tLabels = tf.shape.flatten(tLabels);
  }

  Operand<T> tPredictions = cast(tf, predictions, type);
  if (tPredictions.shape().numDimensions() > 1) {
    tPredictions = tf.shape.flatten(tPredictions);
  }

  Operand<T> tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
  if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) {
    tSampleWeights = tf.shape.flatten(tSampleWeights);
  }

  // Accumulate the prediction to current confusion matrix.
  Operand<T> currentCM =
      MetricsHelper.confusionMatrix(
          tf, tLabels, tPredictions, tf.constant(numClasses), tSampleWeights, type);
  return Collections.singletonList(tf.assignAdd(totalConfusionMatrix, currentCM));
}