public static &lt;T extends TNumber&gt; List&lt;Op&gt; updateConfusionMatrixVariables(...)

in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java [357:611]


  /**
   * Builds the ops that update the given confusion-matrix variables (true/false
   * positives/negatives) by comparing {@code predictions} to every threshold and accumulating
   * (optionally weighted) counts into each requested variable.
   *
   * <p>Mirrors the Keras {@code metrics_utils.update_confusion_matrix_variables} logic: the
   * predictions/labels are tiled across thresholds so a single elementwise comparison produces the
   * per-threshold counts.
   *
   * @param tf the TensorFlow Ops builder
   * @param variablesToUpdate map from confusion-matrix entry to the variable that accumulates it;
   *     entries absent from the map are not computed
   * @param labels ground-truth labels, broadcast-compatible with {@code predictions}
   * @param predictions prediction values, each expected to lie in {@code [0, 1]} (asserted at
   *     graph-execution time)
   * @param thresholds thresholds to compare predictions against
   * @param topK if non-null, only the top-K predictions per example are kept (others zeroed)
   * @param classIndex if non-null, restrict the update to this class along the last dimension
   * @param sampleWeight optional per-example weights, broadcastable to the prediction shape
   * @param multiLabel whether each label dimension is treated as an independent binary problem
   * @param labelWeights optional per-label weights; must be {@code null} when {@code multiLabel}
   * @return the list of update {@link Op}s to run
   * @throws IllegalArgumentException if {@code multiLabel} is set together with {@code
   *     labelWeights}, or if the label and prediction shapes are incompatible
   */
  public static <T extends TNumber> List<Op> updateConfusionMatrixVariables(
      Ops tf,
      Map<ConfusionMatrixEnum, Variable<T>> variablesToUpdate,
      Operand<T> labels,
      Operand<T> predictions,
      Operand<TFloat32> thresholds,
      Integer topK,
      Integer classIndex,
      Operand<T> sampleWeight,
      boolean multiLabel,
      Operand<T> labelWeights) {
    if (multiLabel && labelWeights != null) {
      throw new IllegalArgumentException(
          "labelWeights for multilabel data should be handled outside of updateConfusionMatrixVariables when multiLabel is true.");
    }

    if (variablesToUpdate == null || variablesToUpdate.isEmpty()) {
      // Typed empty list instead of the raw Collections.EMPTY_LIST constant.
      return Collections.emptyList();
    }

    Operand<T> tLabels = labels;
    Operand<T> tPredictions = predictions;
    Operand<T> tSampleWeight = sampleWeight;

    // We will tile data for threshold comparisons. We want a cross product of thresholds and
    // predictions/labels:
    //   In the multilabel case, we want a data shape of (T, N, D).
    //                                              else (T, ND).
    //   where
    //     T is numThresholds (the size of the 0th dimension of thresholds)
    //     N is the number of examples (the 0th dimension of labels and predictions)
    //     Dx == Cx except that if classIndex != null,
    //                          then the last dimension of Dx is size 1
    //     D is the product of all Dx
    //     ND is N * D

    // size of the 0th dimension of thresholds
    // reshape to scalar for operations later.
    Operand<TInt32> numThresholds =
        tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));

    // if multilabel, then (rank(thresholds) == 1)
    //                else true
    Operand<TBool> oneThresh;
    if (multiLabel) {
      oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds));
    } else {
      // TODO handle Ragged Tensors????
      // [y_pred,
      //    y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
      //                                                   sampleWeights)
      oneThresh = tf.constant(true);
    }

    List<Op> controlOps = new ArrayList<>();
    Operand<TInt32> axes = allAxes(tf, tPredictions);
    // Runtime sanity checks: predictions must fall inside [0, 1].
    controlOps.add(
        tf.withSubScope("updateConfusionMatrixVariables-1")
            .assertThat(
                tf.reduceAll(
                    tf.math.greaterEqual(
                        tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
                    axes),
                Collections.singletonList(tf.constant("predictions must be >= 0"))));
    controlOps.add(
        tf.withSubScope("updateConfusionMatrixVariables-2")
            .assertThat(
                tf.reduceAll(
                    tf.math.lessEqual(tPredictions, cast(tf, tf.constant(1), tPredictions.type())),
                    axes),
                Collections.singletonList(tf.constant("predictions must be <= 1"))));

    LossTuple<T> result =
        LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight);
    tPredictions = result.getTarget(); // shape (N, Cx)
    tLabels = result.getLabels(); // shape (N, Cx)
    tSampleWeight = result.getSampleWeights(); // broadcastable to (N, Dx)

    if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) {
      throw new IllegalArgumentException(
          String.format(
              "Shapes %s and %s are incompatible",
              tPredictions.shape().toString(), tLabels.shape().toString()));
    }

    if (topK != null) {
      tPredictions = filterTopK(tf, tPredictions, topK);
    }

    if (classIndex != null) {
      // Slice to new shapes (N, Dx)
      // NOTE(review): the gather inserts a size-1 dim at the last axis; squeezing axis 1 is only
      // equivalent to squeezing the last axis for rank-2 inputs — confirm for higher ranks.
      tLabels =
          tf.squeeze(
              tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(-1)),
              Squeeze.axis(Collections.singletonList(1L)));
      tPredictions =
          tf.squeeze(
              tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(-1)),
              Squeeze.axis(Collections.singletonList(1L)));
    }
    org.tensorflow.op.core.Shape<TInt32> predShape = tf.shape(tPredictions);

    Operand<TInt32> numExamples =
        tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar()));

    // number of labels (and predictions) per example (after possibly slicing by classIndex)
    // In the notation we are using for comments, this is D.
    Operand<TInt32> numLabels =
        tf.select(
            tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)),
            tf.constant(1),
            tf.reduceProd(
                // take all but the first dimension
                tf.shape.takeLast(
                    predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))),
                tf.constant(0)));

    // threshLabelTile == numLabels except in one case:
    //    if multilabel and rank(thresholds) != 1, then threshLabelTile is 1
    Operand<TInt32> threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1));

    // if multilabel, then shape (1, N, Dx)
    // else shape (1, ND),
    Operand<T> predictionsExtraDim;
    Operand<TBool> labelsExtraDim;

    if (multiLabel) {
      predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0));
      labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0));
    } else {
      predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1)));
      labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1)));
    }

    // the shape of each thresholds tile
    // if multilabel, then [T, 1, -1]
    //                else [T, -1]
    List<Operand<TInt32>> threshPretileShape;

    // the tiling multiples for thresholds
    // We want to repeat the thresholds for each data position.
    // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels)
    //                else [1, ND]
    List<Operand<TInt32>> threshTiles;

    // tiling multiples for predictionsExtraDim and labelsExtraDim
    // We want to repeat the predictions and labels for each threshold.
    // If multilabel, then [T, 1, 1]
    //                else [T, 1]
    List<Operand<TInt32>> dataTiles;

    if (multiLabel) {
      threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1));
      threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile);
      dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1));
    } else {
      // numThresholds is already scalar-shaped (reshaped above), so no extra reshape is needed.
      threshPretileShape = Arrays.asList(numThresholds, tf.constant(-1));
      Operand<TInt32> mul = tf.math.mul(numExamples, numLabels);
      threshTiles = Arrays.asList(tf.constant(1), mul);
      dataTiles = Arrays.asList(numThresholds, tf.constant(1));
    }

    // if multilabel, then shape (T, 1, T*)
    //                else shape (T, T*)
    // where T* is the product of all threshold dimension sizes beyond 0
    Operand<T> thresholdsReshaped =
        tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape));

    Operand<TInt32> threshTilesShape = tf.stack(threshTiles);

    // if multilabel, then
    //     if thresholds has rank > 1, then shape (T, N, T*)
    //                                 else shape (T, N, D)
    // else shape (T, ND)
    Operand<T> threshTiled = tf.tile(thresholdsReshaped, threshTilesShape);

    Operand<TInt32> dataTilesShape = tf.stack(dataTiles);

    // if multilabel, then shape (T, N, D)
    //                      else (T, ND)
    Operand<T> predsTiled = tf.tile(predictionsExtraDim, dataTilesShape);

    // Compare predictions and threshold.
    Operand<TBool> predIsPos = tf.math.greater(predsTiled, threshTiled);
    // Tile labels by number of thresholds (reuse the already-stacked multiples).
    Operand<TBool> labelIsPos = tf.tile(labelsExtraDim, dataTilesShape);
    Operand<T> weightsTiled;
    if (tSampleWeight != null) {
      tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions));
      // if multilabel, then
      //     reshape tSampleWeight to (1, N, threshLabelTile)
      //     tile the result into shape (T, N, threshLabelTile)
      //     where threshLabelTile is typically D
      // else
      //     reshape tSampleWeight to (1, ND)
      //     tile the result into shape (T, ND)
      weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape);
    } else {
      weightsTiled = null;
    }

    if (labelWeights != null) {
      // Change shape to (1, Dx).
      Operand<T> lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0));

      // Broadcast to shape (N, Dx).
      // BUG FIX: broadcastTo takes a *shape* operand; the original passed the predictions tensor
      // itself. Use tf.shape(tPredictions), matching the sample-weight broadcast above.
      lLabelWeights = tf.broadcastTo(lLabelWeights, tf.shape(tPredictions));

      // If multilabel: shape (T, N, D)
      //          else: shape (T, ND)
      Operand<T> labelWeightsTiled =
          tf.tile(tf.reshape(lLabelWeights, threshTilesShape), dataTilesShape);

      if (weightsTiled == null) {
        weightsTiled = labelWeightsTiled;
      } else {
        weightsTiled = tf.math.mul(weightsTiled, labelWeightsTiled);
      }
    }

    // For each confusion-matrix entry, the (label, prediction) boolean masks whose conjunction
    // counts toward that entry. Typed lists replace the original raw Operand[] values.
    Map<ConfusionMatrixEnum, List<Operand<TBool>>> loopVars = new HashMap<>();
    loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, Arrays.asList(labelIsPos, predIsPos));
    Variable<T> updateTN = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES);
    Variable<T> updateFP = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES);
    Variable<T> updateFN = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES);

    // Negations are computed lazily — only when some requested variable needs them.
    Operand<TBool> predIsNeg = null;
    Operand<TBool> labelIsNeg;
    if (updateFN != null || updateTN != null) {
      predIsNeg = tf.math.logicalNot(predIsPos);
      loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, Arrays.asList(labelIsPos, predIsNeg));
    }

    if (updateFP != null || updateTN != null) {
      labelIsNeg = tf.math.logicalNot(labelIsPos);
      loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, Arrays.asList(labelIsNeg, predIsPos));
      if (updateTN != null) {
        loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, Arrays.asList(labelIsNeg, predIsNeg));
      }
    }

    // Emit one weighted accumulation op per requested confusion-matrix variable.
    for (Map.Entry<ConfusionMatrixEnum, List<Operand<TBool>>> entry : loopVars.entrySet()) {
      ConfusionMatrixEnum c = entry.getKey();
      if (variablesToUpdate.containsKey(c)) {
        List<Operand<TBool>> masks = entry.getValue();
        // masks.get(0) = label mask, masks.get(1) = prediction mask
        controlOps.add(
            weightedAssignAdd(
                tf, masks.get(0), masks.get(1), weightsTiled, variablesToUpdate.get(c)));
      }
    }

    return controlOps;
  }