public static Operand confusionMatrix()

in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java [801:871]


  /**
   * Builds a dense confusion matrix from integer-valued labels and predictions.
   *
   * <p>Labels and predictions are cast to {@code TInt64} and used as (row, column) indices: each
   * pair contributes its weight (or {@code 1} when {@code weights} is {@code null}) to the cell at
   * {@code [label, prediction]} of a {@code [numClasses, numClasses]} matrix of type {@code type}.
   *
   * <p>Runtime assertions are attached as control dependencies to verify that every label and
   * every prediction is non-negative and, when {@code numClasses} is supplied, strictly less than
   * {@code numClasses}.
   *
   * @param tf the TensorFlow Ops used to create the operations
   * @param labels the class indices used as the row index of each entry
   * @param predictions the class indices used as the column index of each entry; its static shape
   *     must be compatible with that of {@code labels}
   * @param numClasses the number of classes (the size of each matrix dimension); if {@code null},
   *     it is inferred as one plus the largest value found in {@code labels} and {@code
   *     predictions}
   * @param weights optional per-entry weights; if {@code null}, every entry counts as {@code 1}.
   *     When provided, its static shape must be compatible with that of the predictions
   * @param type the data type of the resulting matrix
   * @param <T> the data type of the labels, predictions, weights and result
   * @return an operand of shape {@code [numClasses, numClasses]} holding the confusion matrix
   * @throws IllegalArgumentException if the static shape of {@code predictions} is incompatible
   *     with that of {@code labels}, or with that of {@code weights} when weights are given
   */
  public static <T extends TNumber> Operand<T> confusionMatrix(
      Ops tf,
      Operand<T> labels,
      Operand<T> predictions,
      Operand<TInt64> numClasses,
      Operand<T> weights,
      Class<T> type) {
    if (!predictions.shape().isCompatibleWith(labels.shape())) {
      throw new IllegalArgumentException(
          String.format(
              "Prediction shape %s is not compatible with labels shape %s",
              predictions.shape().toString(), labels.shape().toString()));
    }
    tf = tf.withSubScope("confusionMatrix");
    // NOTE(review): predictions is passed before labels here, and the results are read back as
    // getTarget() -> predictions, getLabels() -> labels. Confirm this matches the parameter order
    // of LossesHelper.squeezeOrExpandDimensions; if not, the resulting matrix is transposed.
    LossTuple<T> ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null);
    Operand<TInt64> tPredictions = cast(tf, ops.getTarget(), TInt64.class);
    Operand<TInt64> tLabels = cast(tf, ops.getLabels(), TInt64.class);

    List<Op> labelControls = new ArrayList<>();
    List<Op> predictionControls = new ArrayList<>();

    // Every element must satisfy the condition, hence reduceAll (reduceAny would let a tensor
    // with a single valid element pass the check).
    labelControls.add(
        tf.assertThat(
            tf.reduceAll(tf.math.greaterEqual(tLabels, tf.constant(0L)), allAxes(tf, tLabels)),
            Collections.singletonList(tf.constant("`labels` contains negative values"))));

    predictionControls.add(
        tf.assertThat(
            tf.reduceAll(
                tf.math.greaterEqual(tPredictions, tf.constant(0L)), allAxes(tf, tPredictions)),
            Collections.singletonList(tf.constant("`predictions` contains negative values"))));
    if (numClasses == null) {
      // Class ids are zero-based, so the number of classes is the largest id plus one.
      numClasses =
          tf.math.add(
              tf.math.maximum(
                  tf.reduceMax(tPredictions, allAxes(tf, tPredictions)),
                  tf.reduceMax(tLabels, allAxes(tf, tLabels))),
              tf.constant(1L));
    } else {
      labelControls.add(
          tf.assertThat(
              tf.reduceAll(tf.math.less(tLabels, numClasses), allAxes(tf, tLabels)),
              Collections.singletonList(tf.constant("`labels` out of bounds"))));
      predictionControls.add(
          tf.assertThat(
              tf.reduceAll(tf.math.less(tPredictions, numClasses), allAxes(tf, tPredictions)),
              Collections.singletonList(tf.constant("`predictions` out of bounds"))));
    }

    if (weights != null) {
      if (!tPredictions.shape().isCompatibleWith(weights.shape())) {
        throw new IllegalArgumentException(
            String.format(
                "Prediction shape %s is not compatible with weights shape %s",
                tPredictions.shape().toString(), weights.shape().toString()));
      }
    }

    // Route labels and predictions through identity ops that depend on the assertions, so the
    // checks are evaluated before the values are used as indices.
    Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls);
    tLabels = tfc.identity(tLabels);

    tfc = tf.withSubScope("confusionMatrixPredictions").withControlDependencies(predictionControls);
    tPredictions = tfc.identity(tPredictions);

    Operand<TInt64> shape = tf.stack(Arrays.asList(numClasses, numClasses));
    // Stacking along axis 1 pairs each label with its prediction: one [label, prediction] row
    // per entry.
    Operand<TInt64> indices = tf.stack(Arrays.asList(tLabels, tPredictions), Stack.axis(1L));
    Operand<T> values =
        weights == null ? cast(tf, tf.onesLike(tPredictions), type) : cast(tf, weights, type);
    SparseTensor<T> cmSparse = new SparseTensor<>(indices, values, shape);
    Operand<T> zeroMatrix = tf.zeros(shape, type);

    // Adding the sparse entries onto a dense zero matrix materializes the confusion matrix.
    return tf.sparse.sparseTensorDenseAdd(
        cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix);
  }