private static LossTuple removeSqueezableDimensions()

in tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ConfusionMatrix.java [277:317]


  /**
   * Squeezes the trailing dimension of {@code labels} or {@code predictions} when their rank
   * difference is off by one from {@code expectedRankDiff}.
   *
   * <p>When both static ranks are known, at most one operand is squeezed, and only along its last
   * axis, and only if that axis is statically compatible with size 1:
   *
   * <ul>
   *   <li>{@code predictions}, if {@code rank(predictions) - rank(labels) == expectedRankDiff + 1}
   *   <li>{@code labels}, if {@code rank(predictions) - rank(labels) == expectedRankDiff - 1}
   * </ul>
   *
   * <p>When either static rank is unknown, the rank difference cannot be decided while building
   * the graph. As an approximation, each operand whose own rank is unknown is squeezed along its
   * last axis unconditionally, because no lazily-evaluated select is available. An operand with a
   * known rank is left unchanged in that case. NOTE(review): a squeeze with an explicit axis is
   * expected to fail at run time if that dimension is not actually 1 — confirm this is acceptable
   * for callers.
   *
   * @param scope the current scope; ops are created in a "removeSqueezableDimensions" sub-scope
   * @param labels the label values
   * @param predictions the predicted values
   * @param expectedRankDiff the expected value of {@code rank(predictions) - rank(labels)}
   * @param <T> the data type of the labels and predictions
   * @return a tuple holding the (possibly squeezed) labels and predictions
   */
  private static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
      Scope scope, Operand<T> labels, Operand<T> predictions, int expectedRankDiff) {
    Scope lScope = scope.withSubScope("removeSqueezableDimensions");
    Shape predictionsShape = predictions.shape();
    int predictionsRank = predictionsShape.numDimensions();
    Shape labelsShape = labels.shape();
    int labelsRank = labelsShape.numDimensions();

    if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) {
      // Use static rank. The difference is only meaningful when BOTH ranks are known; with an
      // unknown rank (UNKNOWN_SIZE) it would be an arbitrary number.
      int rankDiff = predictionsRank - labelsRank;
      // Squeeze only the last axis: squeezing every size-1 dimension would also collapse
      // unrelated dimensions (e.g. a batch dimension of size 1). The rank > 0 checks keep
      // size(-1) from indexing into a scalar shape.
      if (rankDiff == expectedRankDiff + 1
          && predictionsRank > 0
          && Shape.isCompatible(predictionsShape.size(-1), 1)) {
        predictions =
            Squeeze.create(lScope, predictions, Squeeze.axis(Collections.singletonList(-1L)));
      } else if (rankDiff == expectedRankDiff - 1
          && labelsRank > 0
          && Shape.isCompatible(labelsShape.size(-1), 1)) {
        labels = Squeeze.create(lScope, labels, Squeeze.axis(Collections.singletonList(-1L)));
      }
      return new LossTuple<>(labels, predictions);
    }

    // Use dynamic rank (at least one static rank is unknown).
    // TODO: hold for lazy select feature,
    //  Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
    if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
      /*
       * TODO: if we ever get a select that does lazy evaluation, this should become
       *   predictions = tf.select(tf.math.equal(tf.constant(expectedRankDiff + 1), rankDiff),
       *       tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions);
       * For now the squeeze is applied unconditionally.
       */
      predictions =
          Squeeze.create(lScope, predictions, Squeeze.axis(Collections.singletonList(-1L)));
    }
    if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) {
      /*
       * TODO: if we ever get a select that does lazy evaluation, this should become
       *   labels = tf.select(tf.math.equal(tf.constant(expectedRankDiff - 1), rankDiff),
       *       tf.squeeze(labels, Squeeze.axis(Arrays.asList(-1L))), labels);
       * For now the squeeze is applied unconditionally.
       */
      labels = Squeeze.create(lScope, labels, Squeeze.axis(Collections.singletonList(-1L)));
    }
    return new LossTuple<>(labels, predictions);
  }