public static LossTuple squeezeOrExpandDimensions()

in tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java [91:138]


  /**
   * Aligns the ranks of {@code labels}, {@code predictions} and {@code sampleWeights} so they can
   * be combined in a loss computation.
   *
   * <p>If {@code labels} is non-null, labels and predictions are passed through {@code
   * removeSqueezableDimensions}. The exception is when both static ranks are known, {@code
   * predictions} has exactly one more dimension than {@code labels}, and that last dimension is
   * not 1; in that case both are left unchanged.
   *
   * <p>If {@code sampleWeights} is non-null and not a scalar, only its LAST dimension is squeezed
   * or expanded when its rank differs from the rank of {@code predictions} by exactly one.
   *
   * @param tf the TensorFlow Ops
   * @param labels the labels, may be null
   * @param predictions the predictions
   * @param sampleWeights the sample weights, may be null
   * @param <T> the data type of the operands
   * @return a LossTuple holding the (possibly adjusted) labels, predictions and sample weights
   */
  public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
      Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights) {

    Shape predictionsShape = predictions.shape();
    long predictionsRank = predictionsShape.numDimensions();

    // Default case when no modifications are made.
    LossTuple<T> lossTuple = new LossTuple<>(labels, predictions, sampleWeights);
    if (labels != null) {
      Shape labelsShape = labels.shape();
      long labelsRank = labelsShape.numDimensions();
      if (labelsRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) {
        // Use static rank for 'label' and 'prediction'. Skip squeezing only when predictions has
        // exactly one extra dimension whose size is not 1.
        if (predictionsRank - labelsRank != 1 || predictionsShape.size(-1) == 1) {
          lossTuple = removeSqueezableDimensions(tf, labels, predictions);
        }
      } else { // use dynamic rank
        // NOTE(review): unlike the static branch, this always delegates to
        // removeSqueezableDimensions; confirm the helper handles a rank difference of 1 whose
        // last dimension is not 1.
        lossTuple = removeSqueezableDimensions(tf, labels, predictions);
      }
    }
    if (sampleWeights == null) { // nothing more to do.
      return lossTuple;
    }
    Shape weightsShape = sampleWeights.shape();
    long weightsRank = weightsShape.numDimensions();
    if (weightsRank == 0) { // scalar weights are returned unchanged
      return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights);
    }

    if (predictionsRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) {
      // Use static rank: adjust only the LAST dimension of the weights. A squeeze without an axis
      // would drop every size-1 dimension (e.g. [2, 1, 1] -> [2] instead of [2, 1]), leaving the
      // weights with a different rank than the predictions.
      if (weightsRank - predictionsRank == 1) {
        sampleWeights =
            tf.squeeze(
                sampleWeights,
                org.tensorflow.op.core.Squeeze.axis(java.util.Collections.singletonList(-1L)));
      } else if (predictionsRank - weightsRank == 1) {
        sampleWeights = tf.expandDims(sampleWeights, tf.constant(-1L));
      }
      return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights);
    }
    // Use dynamic rank.
    // NOTE(review): this compares against the rank of the original predictions, not
    // lossTuple.getTarget(), which may have been squeezed above -- confirm which is intended.
    Operand<TInt32> weightsRankTensor = tf.rank(sampleWeights);
    Operand<TInt32> rankDiff = tf.math.sub(weightsRankTensor, tf.rank(predictions));
    // Statically-scalar weights already returned above, so this scalar guard only matters when
    // the weights' rank is not known statically.
    sampleWeights =
        tf.select(
            tf.math.equal(weightsRankTensor, tf.constant(0)),
            sampleWeights,
            maybeAdjustWeights(tf, sampleWeights, rankDiff));
    return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights);
  }