in tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java [68:159]
  @SuppressWarnings("unchecked")
  public static <T extends TNumber, U extends TNumber> Operand<T> sparseSoftmaxCrossEntropyWithLogits(
      Scope scope, Operand<U> labels, Operand<T> logits) {
    scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits");
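    // Cast half-precision (fp16/bf16) logits up to fp32 for numerical stability;
    // the loss is cast back to the original logits type before returning.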
    Operand<? extends TNumber> preciseLogits;
    if (logits.asOutput().type() == TFloat16.class
        || logits.asOutput().type() == TBfloat16.class) {
      preciseLogits = Cast.create(scope, logits, TFloat32.class);
    } else {
      preciseLogits = logits;
    }
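    // Capture the labels' static shape for build-time validation, and their
    // dynamic shape tensor so the cost can be reshaped back after flattening.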
    Shape labelsStaticShape = labels.shape();
    org.tensorflow.op.core.Shape<TInt32> labelsShape =
        org.tensorflow.op.core.Shape.create(scope, labels);
    Shape logitsShape = logits.shape();
    Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1);
    boolean staticShapesFullyDefined =
        !labelsStaticShape.hasUnknownDimension() && !logitsShortened.hasUnknownDimension();
    if (logitsShape.numDimensions() == 0) {
      throw new IllegalArgumentException(
          String.format("Logits cannot be scalars - received shape %s.", logitsShape));
    }
    if (!logitsShape.hasUnknownDimension()
        && !labelsStaticShape.hasUnknownDimension()
        && labelsStaticShape.numDimensions() != logitsShape.numDimensions() - 1) {
      throw new IllegalArgumentException(
          String.format(
              "Rank mismatch: Rank of labels (received shape %s) should equal rank of logits minus 1 (received shape %s).",
              labelsStaticShape, logitsShape));
    }
    if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) {
      throw new IllegalArgumentException(
          String.format(
              "Shape mismatch: The shape of labels (received %s) "
                  + "should equal the shape of logits except for the last "
                  + "dimension (received %s).",
              labelsStaticShape, logitsShape));
    }
    // Fast path: rank-2 logits and rank-1 labels need no reshaping.
    if (logitsShape.numDimensions() == 2) {
      org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits<? extends TNumber> smax =
          org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create(
              scope, preciseLogits, labels);
      Operand<? extends TNumber> cost = smax.loss();
      if (cost.type() != logits.type()) {
        return Cast.create(scope, cost, logits.type());
      } else {
        // Unchecked cast is safe: type equality was verified by the preceding check.
        return (Operand<T>) cost;
      }
    }
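    // Static shapes were not fully known, so defer the labels/logits shape
    // comparison to graph execution via an assertion on the dynamic shapes.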
    List<Op> shapeChecks = new ArrayList<>();
    if (!staticShapesFullyDefined) {
      shapeChecks.add(
          AssertThat.create(
              scope,
              Equal.create(
                  scope,
                  org.tensorflow.op.core.Shape.create(scope, labels),
                  Shapes.take(
                      scope,
                      org.tensorflow.op.core.Shape.create(scope, logits),
                      Constant.scalarOf(scope, -1))),
              Collections.singletonList(
                  Constant.scalarOf(
                      scope,
                      "Shape mismatch: The shape of labels "
                          + "should equal the shape of logits except for the last "
                          + "dimension "))));
    }
    // Reshape logits to 2 dims, labels to 1 dim.
    long numClasses = logitsShape.size(-1);
    preciseLogits =
        Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClasses));
    labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1));
    // Scope is immutable: capture the scope returned by withControlDependencies,
    // otherwise the shape assertions above would never gate the op below.
    scope = scope.withControlDependencies(shapeChecks);
    // Call the raw op on the flattened inputs.
    org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits<? extends TNumber> smax =
        org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create(
            scope, preciseLogits, labels);
    Operand<? extends TNumber> cost = smax.loss();
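    // Restore the flattened per-example cost to the original labels shape.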
    cost = Reshape.create(scope, cost, labelsShape);
    if (cost.type() != logits.type()) {
      return Cast.create(scope, cost, logits.type());
    } else {
      // Unchecked cast is safe: type equality was verified by the preceding check.
      return (Operand<T>) cost;
    }
  }
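
A minimal usage sketch, assuming an Ops instance named tf and a batch of two examples over three classes; the tensor values and variable names are illustrative assumptions, not taken from the file above:

    Scope scope = tf.scope();
    Operand<TInt32> labels = tf.constant(new int[] {1, 0});
    Operand<TFloat32> logits =
        tf.constant(new float[][] {{4.0f, 2.0f, 1.0f}, {0.0f, 5.0f, 1.0f}});
    // Per-example cross-entropy loss, shape [2], same type as logits.
    Operand<TFloat32> loss = sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits);

Because the logits here are rank 2, this call takes the fast path and invokes the raw op directly with no reshapes or runtime shape assertions.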