public static Operand softmaxCrossEntropyWithLogits()

in tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java [75:149]


  public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
      Scope scope, Operand<U> labels, Operand<T> logits, int axis) {
    scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits");

    // The class axis is resolved against the static rank of logits, so that rank must be known
    // and at least 1. Without this check a scalar fails with an ArithmeticException (modulo by
    // zero), and an unknown rank (-1) silently maps every axis to 0.
    int rank = logits.shape().numDimensions();
    if (rank < 1) {
      throw new IllegalArgumentException(
          "logits must have a known rank of at least 1, but has shape " + logits.shape());
    }
    // Normalize axis into [0, rank); negative values count back from the last dimension.
    axis = Math.floorMod(axis, rank);

    // Half-precision inputs are computed in float32, and the loss is cast back to the logits type.
    if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) {
      Operand<TFloat32> result =
          softmaxCrossEntropyWithLogits(
              scope,
              Cast.create(scope, labels, TFloat32.class),
              Cast.create(scope, logits, TFloat32.class),
              axis);
      return Cast.create(scope, result, logits.asOutput().type());
    }

    // Bring labels to the logits type so the underlying op receives matching dtypes.
    if (logits.asOutput().type() != labels.asOutput().type()) {
      return softmaxCrossEntropyWithLogits(
          scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis);
    }

    Operand<TInt64> inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class);
    // Static shape before any transpose. It is used at the end to restore static shape
    // information that the dynamic reshape erases in graph mode.
    Shape shape = logits.shape();

    // Move the class dimension to the end if it is not already the last dimension.
    if (axis != rank - 1) {
      logits = moveDimToEnd(scope, logits, axis, inputRank);
      labels = moveDimToEnd(scope, labels, axis, inputRank);
    }

    // The dtype check above guarantees labels already has type T; the Cast branch is a safeguard.
    @SuppressWarnings("unchecked")
    Operand<T> tLabels =
        labels.type() == logits.type()
            ? (Operand<T>) labels
            : Cast.create(scope, labels, logits.type());

    // Use the runtime shape, not the static one: the static shape may contain several unknown
    // (-1) sizes, and Reshape accepts at most one -1 in its target shape.
    Operand<TInt64> inputShape = org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class);
    logits = flattenOuterDims(scope, logits);
    tLabels = flattenOuterDims(scope, tLabels);

    org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits<T> smax =
        org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(scope, logits, tLabels);
    Operand<T> cost = smax.loss();

    // The loss shape is the (moved) input shape without its last, class, dimension.
    Operand<TInt64> outputShape =
        Slice.create(
            scope, inputShape, Constant.arrayOf(scope, 0L), Constant.arrayOf(scope, rank - 1L));
    cost = Reshape.create(scope, cost, outputShape);

    // In graph mode, when the original static shape is fully known, reshape once more with the
    // constant shape minus the class axis. This lets downstream static shape inference see it.
    if (scope.env().isGraph() && !shape.hasUnknownDimension()) {
      long[] dims = shape.asArray();
      long[] outerDims = new long[dims.length - 1];
      System.arraycopy(dims, 0, outerDims, 0, axis);
      System.arraycopy(dims, axis + 1, outerDims, axis, dims.length - axis - 1);
      cost = Reshape.create(scope, cost, Constant.vectorOf(scope, outerDims));
    }

    return cost;
  }