public static Operand confusionMatrix()

in tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ConfusionMatrix.java [175:265]


  public static <T extends TNumber> Operand<T> confusionMatrix(
      Scope scope,
      Operand<T> labels,
      Operand<T> predictions,
      Operand<T> weights,
      Operand<TInt64> numClasses) {
    // Builds a [numClasses, numClasses] matrix whose entry (i, j) accumulates the weight of each
    // example with label i and prediction j (weight 1 per example when weights is null).
    // Every op is created under the "confusionMatrix" sub-scope of the caller's scope.
    Scope lScope = scope.withSubScope("confusionMatrix");

    // removeSqueezableDimensions (defined elsewhere) may drop a squeezable dimension from either
    // operand. Everything below -- the weights check, the scatter indices and the default all-ones
    // weights -- must use the operands it returns; mixing in the raw predictions would give
    // updates whose shape does not match the indices built from the squeezed operands.
    LossTuple<T> tuple = removeSqueezableDimensions(lScope, labels, predictions, 0);
    Operand<T> tPredictions = tuple.getTarget();

    // Fail fast, before the remaining ops are added, if weights cannot line up with predictions.
    if (weights != null && !tPredictions.shape().isCompatibleWith(weights.shape())) {
      throw new IllegalArgumentException(
          String.format(
              "predictions.shape() [%s], is not compatible with weights.shape() [ %s].",
              tPredictions.shape(), weights.shape()));
    }

    // Class ids become the scatter indices and the matrix shape, both int64.
    Operand<TInt64> lLabels = Cast.create(lScope, tuple.getLabels(), TInt64.class);
    Operand<TInt64> lPredictions = Cast.create(lScope, tPredictions, TInt64.class);
    Operand<TInt64> zero = Constant.scalarOf(lScope, 0L);

    // Class ids address matrix rows/columns, so they must be non-negative.
    lLabels =
        withAllTrueAssertion(
            lScope,
            GreaterEqual.create(lScope, lLabels, zero),
            "labels contains negative values",
            lLabels);
    lPredictions =
        withAllTrueAssertion(
            lScope,
            GreaterEqual.create(lScope, lPredictions, zero),
            "predictions contains negative values",
            lPredictions);

    Operand<TInt64> lNumClasses;
    if (numClasses == null) {
      // Infer the class count as max(max(predictions), max(labels)) + 1 (maxima taken along
      // axis 0).
      lNumClasses =
          Add.create(
              lScope,
              Maximum.create(
                  lScope,
                  ReduceMax.create(lScope, lPredictions, zero),
                  ReduceMax.create(lScope, lLabels, zero)),
              Constant.scalarOf(lScope, 1L));
    } else {
      // With an explicit class count, every id must also be strictly below numClasses.
      lNumClasses = Cast.create(lScope, numClasses, TInt64.class);
      lLabels =
          withAllTrueAssertion(
              lScope, Less.create(lScope, lLabels, lNumClasses), "labels out of bounds", lLabels);
      lPredictions =
          withAllTrueAssertion(
              lScope,
              Less.create(lScope, lPredictions, lNumClasses),
              "predictions  out of bounds",
              lPredictions);
    }

    // Each example contributes its weight at (label, prediction). Per the ScatterNd op
    // documentation, updates landing on the same index are summed, which yields the counts.
    Operand<TInt64> shape = Stack.create(lScope, Arrays.asList(lNumClasses, lNumClasses));
    Operand<TInt64> indices =
        Stack.create(lScope, Arrays.asList(lLabels, lPredictions), Stack.axis(1L));
    Operand<T> values = weights == null ? OnesLike.create(lScope, tPredictions) : weights;
    return ScatterNd.create(lScope, indices, values, shape);
  }

  /**
   * Returns {@code value} passed through an {@code Identity} that carries a control dependency on
   * an {@code AssertThat} requiring every element of {@code condition} to be true. When the check
   * fails, the assertion reports {@code message}.
   */
  private static Operand<TInt64> withAllTrueAssertion(
      Scope scope, Operand<TBool> condition, String message, Operand<TInt64> value) {
    AssertThat check =
        AssertThat.create(
            scope,
            ReduceAll.create(scope, condition, Axes.allAxes(scope, condition)),
            Collections.singletonList(Constant.scalarOf(scope, message)));
    return Identity.create(scope.withControlDependencies(Collections.singletonList(check)), value);
  }