public List updateStateList()

in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java [129:175]


  public List<Op> updateStateList(
      Ops tf,
      Operand<? extends TNumber> labels,
      Operand<? extends TNumber> predictions,
      Operand<? extends TNumber> sampleWeights) {
    // Builds the op that adds this batch's confusion matrix into totalConfusionMatrix and
    // returns it as a single-element list.
    init(tf);

    // Static validation only. A rank or size equal to Shape.UNKNOWN_SIZE is not known at
    // graph-construction time, so the corresponding comparison is skipped instead of failing.
    if (sampleWeights != null) {
      long weightsRank = sampleWeights.shape().numDimensions();
      long labelsRank = labels.shape().numDimensions();
      // Rank-0 (scalar) weights are always accepted; otherwise ranks must match when known.
      if (weightsRank != 0
          && weightsRank != Shape.UNKNOWN_SIZE
          && labelsRank != Shape.UNKNOWN_SIZE
          && weightsRank != labelsRank) {
        throw new IllegalArgumentException(
            String.format(
                "Weights must either have rank 0, or the same rank as labels, weights rank = %d, labels rank = %d",
                weightsRank, labelsRank));
      }
    }
    long labelsSize = labels.shape().size();
    long predictionsSize = predictions.shape().size();
    // Shape.size() yields Shape.UNKNOWN_SIZE when any dimension is unknown (e.g. a dynamic
    // batch dimension). Comparing such a value against a known size would reject inputs that
    // may well match at run time, so only compare when both element counts are known.
    if (labelsSize != Shape.UNKNOWN_SIZE
        && predictionsSize != Shape.UNKNOWN_SIZE
        && labelsSize != predictionsSize) {
      throw new IllegalArgumentException(
          String.format(
              "labels and predictions must have the same size, labels size = %d, predictions size = %d",
              labelsSize, predictionsSize));
    }

    // Cast everything to the metric's type; inputs of rank > 1 are flattened to 1-D before
    // being handed to the confusion-matrix helper.
    Operand<T> tLabels = cast(tf, labels, type);
    if (tLabels.shape().numDimensions() > 1) {
      tLabels = tf.shape.flatten(tLabels);
    }
    Operand<T> tPredictions = cast(tf, predictions, type);
    if (tPredictions.shape().numDimensions() > 1) {
      tPredictions = tf.shape.flatten(tPredictions);
    }
    // sampleWeights is optional; null is forwarded to the confusion-matrix helper unchanged.
    Operand<T> tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
    if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) {
      tSampleWeights = tf.shape.flatten(tSampleWeights);
    }

    // Accumulate the prediction to current confusion matrix.
    Operand<T> currentCM =
        MetricsHelper.confusionMatrix(
            tf, tLabels, tPredictions, tf.constant(numClasses), tSampleWeights, type);
    return Collections.singletonList(tf.assignAdd(totalConfusionMatrix, currentCM));
  }