Operand sparseSoftmaxCrossEntropyWithLogits()

in tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java [68:159]


      Operand<T> sparseSoftmaxCrossEntropyWithLogits(
          Scope scope, Operand<U> labels, Operand<T> logits) {
    // Builds the sparse softmax cross-entropy loss between `labels` and `logits`.
    //
    // scope  - scope used to create the ops; everything is created under the
    //          "SparseSoftmaxCrossEntropyWithLogits" sub-scope.
    // labels - must have the shape of `logits` without its last dimension
    //          (presumably class indices - confirm against the raw op's contract).
    // logits - rank >= 1; the size of the last dimension is used as the number of classes.
    //
    // Returns the loss, reshaped to the shape of `labels` when reshaping was needed, and cast
    // back to the type of `logits` when the computation ran in a different type.
    //
    // Throws IllegalArgumentException if `logits` is a scalar, or if the static shapes of
    // `labels` and `logits` are fully known and inconsistent.
    scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits");

    // Run every static validation before creating any op, so that a rejected call does not
    // leave stray nodes behind.
    Shape logitsShape = logits.shape();
    int logitsRank = logitsShape.numDimensions();
    // This check has to precede take(logitsRank - 1): for a rank-0 shape that call would be
    // take(-1) and fail with an unrelated exception instead of the message below.
    if (logitsRank == 0) {
      throw new IllegalArgumentException(
          String.format("Logits cannot be scalars - received shape %s.", logitsShape));
    }

    Shape labelsStaticShape = labels.shape();
    Shape logitsShortened = logitsShape.take(logitsRank - 1);
    boolean staticShapesFullyDefined =
        !labelsStaticShape.hasUnknownDimension() && !logitsShortened.hasUnknownDimension();

    if (!logitsShape.hasUnknownDimension()
        && !labelsStaticShape.hasUnknownDimension()
        && labelsStaticShape.numDimensions() != logitsRank - 1) {
      throw new IllegalArgumentException(
          String.format(
              "Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).",
              labelsStaticShape, logitsShape));
    }

    if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) {
      throw new IllegalArgumentException(
          String.format(
              "Shape mismatch: The shape of labels (received %s) "
                  + "should equal the shape of logits except for the last "
                  + "dimension (received %s).",
              labelsStaticShape, logitsShape));
    }

    // Half-precision logits are computed in float32; the result is cast back at the end.
    Class<T> logitsType = logits.type();
    Operand<? extends TNumber> preciseLogits;
    if (logitsType == TFloat16.class || logitsType == TBfloat16.class) {
      preciseLogits = Cast.create(scope, logits, TFloat32.class);
    } else {
      preciseLogits = logits;
    }

    Operand<? extends TNumber> cost;
    if (logitsRank == 2) {
      // Rank-2 logits are passed to the raw op as they are; no reshapes are required.
      org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits<? extends TNumber> smax =
          org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create(
              scope, preciseLogits, labels);
      cost = smax.loss();
    } else {
      // Capture the runtime shape of the original labels before they are flattened below; it is
      // needed both for the runtime check and to restore the shape of the loss.
      org.tensorflow.op.core.Shape<TInt32> labelsShape =
          org.tensorflow.op.core.Shape.create(scope, labels);

      List<Op> shapeChecks = new ArrayList<>();
      if (!staticShapesFullyDefined) {
        // Compare against the first (rank - 1) dimensions of logits. Shapes.take expects the
        // NUMBER of leading dimensions, so the statically known rank is used rather than -1.
        Operand<TInt32> logitsLeadingDims =
            Shapes.take(
                scope,
                org.tensorflow.op.core.Shape.create(scope, logits),
                Constant.scalarOf(scope, logitsRank - 1));
        shapeChecks.add(
            AssertThat.create(
                scope,
                // Equal is element-wise; reduce it so the assertion receives a single boolean.
                org.tensorflow.op.core.ReduceAll.create(
                    scope,
                    Equal.create(scope, labelsShape, logitsLeadingDims),
                    Constant.arrayOf(scope, 0)),
                Collections.singletonList(
                    Constant.scalarOf(
                        scope,
                        "Shape mismatch: The shape of labels  "
                            + "should equal the shape of logits except for the last "
                            + "dimension "))));
      }

      // Reshape logits to 2 dims, labels to 1 dim.
      long numClasses = logitsShape.size(-1);
      preciseLogits =
          Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClasses));
      labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1));

      // withControlDependencies returns a NEW scope; the raw op must be created with that
      // returned scope, otherwise the shape assertion is never attached to it.
      Scope lossScope =
          shapeChecks.isEmpty() ? scope : scope.withControlDependencies(shapeChecks);
      org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits<? extends TNumber> smax =
          org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create(
              lossScope, preciseLogits, labels);
      // Give the loss the shape the labels had before flattening.
      cost = Reshape.create(scope, smax.loss(), labelsShape);
    }

    if (cost.type() != logitsType) {
      return Cast.create(scope, cost, logitsType);
    }
    // Unchecked cast is safe: the check above established that cost already has type T.
    return (Operand<T>) cost;
  }