in tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java [75:149]
public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
    Scope scope, Operand<U> labels, Operand<T> logits, int axis) {
  // Computes the softmax cross entropy of `logits` against `labels` along `axis`. The returned
  // loss has the shape of `logits` with `axis` removed.
  //
  // Float16/BFloat16 logits are evaluated in float32 and the result is cast back to the logits
  // type; labels of a different type are cast to the logits type. When `axis` is not the last
  // dimension it is moved to the end with moveDimToEnd, the outer dimensions are flattened with
  // flattenOuterDims, the raw org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits op is applied,
  // and its loss is reshaped back to the outer dimensions.
  //
  // Throws IllegalArgumentException if `logits` is a scalar or has an unknown static rank.
  scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits");
  int rank = logits.shape().numDimensions();
  // Without this check a scalar fails with an ArithmeticException (axis % 0), and an unknown rank
  // (numDimensions() == -1) silently normalizes every axis to 0 and then fails further down.
  if (rank < 1) {
    throw new IllegalArgumentException(
        "logits must have a known rank of at least 1, but has shape " + logits.shape());
  }
  // Normalize axis into [0, rank); out-of-range values wrap around via the modulo.
  axis = axis % rank;
  if (axis < 0) {
    axis += rank;
  }
  if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) {
    Operand<TFloat32> result =
        softmaxCrossEntropyWithLogits(
            scope,
            Cast.create(scope, labels, TFloat32.class),
            Cast.create(scope, logits, TFloat32.class),
            axis);
    return Cast.create(scope, result, logits.asOutput().type());
  }
  if (logits.asOutput().type() != labels.asOutput().type()) {
    return softmaxCrossEntropyWithLogits(
        scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis);
  }
  // Static shape of the logits before any transposition; used for the final static reshape.
  Shape shape = logits.shape();
  // Move the class dimension to the end if it is not already the last one. The rank operand is
  // only consumed by moveDimToEnd, so it is created only on this path.
  if (axis != rank - 1) {
    Operand<TInt64> inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class);
    logits = moveDimToEnd(scope, logits, axis, inputRank);
    labels = moveDimToEnd(scope, labels, axis, inputRank);
  }
  Operand<T> tLabels;
  if (labels.type() != logits.type()) {
    tLabels = Cast.create(scope, labels, logits.type());
  } else {
    // Unchecked cast is safe: the element types were just compared.
    tLabels = (Operand<T>) labels;
  }
  // Full shape of the (possibly transposed) logits, captured before flattening. A constant keeps
  // the static dimensions and works when at most one outer dimension is unknown, because Reshape
  // can infer a single -1. With two or more unknown outer dimensions (or an unknown static rank
  // after the transpose) such a constant would hold several -1 entries or could not be built at
  // all, so the runtime shape is used instead.
  Shape movedShape = logits.shape();
  boolean useRuntimeShape = movedShape.numDimensions() != rank;
  if (!useRuntimeShape) {
    int unknownOuterDims = 0;
    for (int i = 0; i < rank - 1; i++) {
      if (movedShape.size(i) < 0) {
        unknownOuterDims++;
      }
    }
    useRuntimeShape = unknownOuterDims > 1;
  }
  Operand<TInt64> fullShape;
  if (useRuntimeShape) {
    fullShape = org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class);
  } else {
    fullShape = Constant.tensorOf(scope, movedShape);
  }
  logits = flattenOuterDims(scope, logits);
  tLabels = flattenOuterDims(scope, tLabels);
  org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits<T> smax =
      org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(scope, logits, tLabels);
  // Loss output of the raw op, computed on the flattened inputs.
  Operand<T> cost = smax.loss();
  // Reshape the loss back to the outer dimensions: every dimension except the last (class) one.
  Operand<TInt64> outputShape =
      Slice.create(
          scope, fullShape, Constant.arrayOf(scope, 0L), Constant.arrayOf(scope, rank - 1L));
  cost = Reshape.create(scope, cost, outputShape);
  if (scope.env().isGraph() && !shape.hasUnknownDimension()) {
    // The original shape is fully known: reshape to it with `axis` removed, as a constant, so
    // the graph carries the static output shape.
    long[] outputDims = new long[rank - 1];
    int next = 0;
    for (int i = 0; i < rank; i++) {
      if (i != axis) {
        outputDims[next++] = shape.size(i);
      }
    }
    cost = Reshape.create(scope, cost, Constant.vectorOf(scope, outputDims));
  }
  return cost;
}