public static Op assertBroadcastable()

in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java [80:155]


  public static <T extends TNumber> Op assertBroadcastable(
      Ops tf, Operand<T> sampleWeights, Operand<T> values) {

    // try static check for exact match

    Shape weightsShapeStatic = sampleWeights.shape();
    int weightsRankStatic = weightsShapeStatic.numDimensions();

    Shape valuesShapeStatic = values.shape();
    int valuesRankStatic = valuesShapeStatic.numDimensions();

    // Static check: all dimensions of both shapes are known at graph construction time.
    if (!weightsShapeStatic.isUnknown()
        && !valuesShapeStatic.isUnknown()
        && !weightsShapeStatic.hasUnknownDimension()
        && !valuesShapeStatic.hasUnknownDimension()) {
      if (weightsRankStatic == 0) {
        // A scalar weight can always be broadcast to the values.
        return tf.withSubScope("staticScalarCheckSuccess")
            .withControlDependencies(Collections.emptyList())
            .noOp();
      }
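      // With fully known shapes, the weights rank must match the values rank.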
      if (weightsRankStatic != valuesRankStatic) {
        throw new NotBroadcastableException(
            String.format(
                "%s values.rank=%d. weights.rank=%d.  values.shape=%s. weights.shape=%s.",
                ASSERT_BROADCAST_ERROR_PREFIX,
                valuesRankStatic,
                weightsRankStatic,
                valuesShapeStatic,
                weightsShapeStatic));
      }

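      // Each weights dimension must either match the corresponding values dimension or be 1.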
      for (int i = 0; i < valuesRankStatic; i++) {
        if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)
            && weightsShapeStatic.size(i) != 1) {
          throw new NotBroadcastableException(
              String.format(
                  "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.",
                  ASSERT_BROADCAST_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic));
        }
      }
      return tf.withSubScope("staticDimsCheckSuccess")
          .withControlDependencies(Collections.EMPTY_LIST)
          .noOp();
    }
    // Dynamic checks: the shapes are not fully known at graph construction time, so build the
    // checks as operations that execute with the graph.
    Operand<TInt32> weightsShape = tf.shape(sampleWeights);
    Operand<TInt32> weightsRank = tf.rank(sampleWeights);
    Operand<TInt32> valuesShape = tf.shape(values);
    Operand<TInt32> valuesRank = tf.rank(values);

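    // Determine at run time whether the weights are a scalar.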
    Operand<TBool> isScalar = tf.math.equal(weightsRank, tf.constant(0));
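    // Tensors included in the error message if the assertion fails.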
    List<Operand<?>> data =
        Arrays.asList(
            tf.constant(ASSERT_BROADCAST_ERROR_PREFIX),
            tf.constant("weights.shape="),
            weightsShape,
            tf.constant("values.shape="),
            valuesShape,
            tf.constant("isScalar="),
            isScalar);

    // Work around the non-lazy tf.select used for isValidShape below: both branches are always
    // evaluated, so the non-scalar check would fail on a scalar weight. Broadcasting a scalar
    // weight up to the shape of the values keeps that branch valid. If select were lazy, the
    // non-scalar branch would never execute when isScalar is true.
    Operand<T> reshapedWeights =
        tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights);
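    // Recompute the dynamic shape and rank from the (possibly broadcast) weights.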
    weightsShape = tf.shape(reshapedWeights);
    weightsRank = tf.rank(reshapedWeights);

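    // Check whether the non-scalar weights can be broadcast to the shape of the values.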
    Operand<TBool> validNonscalar =
        canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape);

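    // A scalar weight is always valid; otherwise rely on the non-scalar check.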
    Operand<TBool> isValidShape = tf.select(isScalar, isScalar, validNonscalar);

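    // Fails at run time, reporting both shapes, if the weights cannot be broadcast.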
    return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data);
  }
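
Callers are expected to run the returned Op before applying the weights, typically by adding it as
a control dependency, so the computation fails fast when the shapes are incompatible. A minimal
sketch of that calling pattern, assuming an existing Ops instance tf and two TFloat32 operands
values and sampleWeights (illustrative names, not part of this file):

  // Sketch only: build the broadcast assertion for the given weights and values.
  Op broadcastCheck = MetricsHelper.assertBroadcastable(tf, sampleWeights, values);

  // Apply the weights only after the assertion has succeeded.
  Operand<TFloat32> weightedValues =
      tf.withControlDependencies(Collections.singletonList(broadcastCheck))
          .math
          .mul(values, sampleWeights);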