public List&lt;Op&gt; updateStateList(Ops tf, Operand&lt;? extends TNumber&gt; values, Operand&lt;? extends TNumber&gt; sampleWeights)

in tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java [125:203]


  /**
   * Accumulates statistics for computing the metric's reduction.
   *
   * <p>Casts {@code values} (and {@code sampleWeights}, when present) to the metric's internal
   * type, applies the weights, and returns the ops that add the weighted sum into {@code total}
   * and — for reductions with a denominator — the element/weight count into {@code count}.
   *
   * @param tf the TensorFlow Ops used to build the update graph
   * @param values the values to accumulate; must not be {@code null}
   * @param sampleWeights optional per-value weights, may be {@code null} (treated as all ones)
   * @return the list of control operations that update the metric state variables
   * @throws IllegalArgumentException if {@code values} is {@code null}
   * @throws UnsupportedOperationException if the configured reduction is not implemented
   */
  public List<Op> updateStateList(
      Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) {

    if (values == null) {
      throw new IllegalArgumentException("values is required.");
    }
    init(tf);
    List<Op> updateOperations = new ArrayList<>();
    // Cast everything to the internal type so arithmetic matches the state variables.
    Operand<T> tSampleWeights = null;
    Operand<T> tValues = cast(tf, values, getInternalType());

    if (sampleWeights != null) {
      tSampleWeights = cast(tf, sampleWeights, getInternalType());
      // Update dimensions of weights to match with values if possible.
      LossTuple<T> tuple =
          LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights);
      tValues = tuple.getTarget();
      tSampleWeights = tuple.getSampleWeights();
      try {
        // Broadcast weights if possible
        tSampleWeights = MetricsHelper.broadcastWeights(tf, tSampleWeights, tValues);
      } catch (IllegalArgumentException ignored) {
        // Broadcast failed: we have static shapes with either different ranks or
        // different dimension sizes. Fall back to reducing the values down to the
        // rank of the weights so the multiply below is well defined.
        int valuesRank = tValues.shape().numDimensions();
        int weightsRank = tSampleWeights.shape().numDimensions();
        // BUG FIX: was Math.min(0, ...), which is never positive, so this fallback
        // was dead code; Math.max yields the number of excess trailing axes.
        int numAxes = Math.max(0, valuesRank - weightsRank);
        if (numAxes > 0) {
          // values rank is greater than weights rank: reduce the trailing axes
          // [weightsRank, valuesRank) so values match the weights' rank.
          int[] axes = new int[numAxes];
          for (int i = 0; i < numAxes; i++) {
            axes[i] = i + weightsRank;
          }
          if (reduction == MetricReduction.SUM) {
            tValues = tf.reduceSum(tValues, tf.constant(axes));
          } else {
            tValues = tf.math.mean(tValues, tf.constant(axes));
          }
        }
      }
      tValues = tf.math.mul(tValues, tSampleWeights);
    }

    // Weighted sum over all axes feeds the numerator (`total`).
    Operand<? extends TNumber> weightedValueSum =
        tf.reduceSum(tValues, LossesHelper.allAxes(tf, tValues));
    Operand<T> totalUpdate = tf.assignAdd(total, cast(tf, weightedValueSum, total.type()));
    updateOperations.add(totalUpdate);
    Operand<T> numValues;
    // Exit early if the reduction doesn't have a denominator.
    if (reduction != MetricReduction.SUM) {
      // Update `count` for reductions that require a denominator.
      switch (reduction) {
        case SUM_OVER_BATCH_SIZE:
          // NOTE(review): shape().size() is -1 for unknown static shapes; assumes
          // values have a fully-known static shape here — TODO confirm upstream.
          numValues = cast(tf, tf.constant(tValues.shape().size()), getInternalType());
          break;
        case WEIGHTED_MEAN:
          if (tSampleWeights == null) {
            // No weights: denominator is simply the number of elements.
            numValues = cast(tf, tf.constant(tValues.shape().size()), getInternalType());
          } else {
            // Denominator is the total weight.
            numValues =
                cast(
                    tf,
                    tf.reduceSum(tSampleWeights, LossesHelper.allAxes(tf, tSampleWeights)),
                    getInternalType());
          }
          break;
        default:
          throw new UnsupportedOperationException(
              String.format("reduction [%s] not implemented", reduction));
      }

      Operand<T> totalCount = tf.assignAdd(this.count, numValues);

      updateOperations.add(totalCount);
    }

    return updateOperations;
  }