in tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java [60:142]
public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask,
    Operand<T> updates,
    Options... options) {
  // How the update is built (k = mask rank):
  //   1. `tensor` is reshaped to [prod(shape[:axis + k]), inner...], collapsing the outer dimensions (those before
  //      `axis`) together with the masked dimensions into a single row dimension.
  //   2. The mask is tiled over the outer dimensions and flattened, so Where() yields 1-D row offsets into that
  //      collapsed dimension (row-major, matching the reshape in step 1).
  //   3. `updates` is laid out as [rows, inner...], scattered into those rows, and the result is reshaped back.
  scope = scope.withNameAsSubScope("BooleanMaskUpdate");

  int axis = 0;
  boolean broadcast = true;
  if (options != null) {
    for (Options opts : options) {
      if (opts.axis != null) {
        axis = opts.axis;
      }
      if (opts.broadcast != null) {
        broadcast = opts.broadcast;
      }
    }
  }

  if (axis < 0) {
    // Negative axes count back from the static rank of `tensor`.
    axis += tensor.rank();
  }

  Shape maskShape = mask.shape();
  Shape tensorShape = tensor.shape();

  if (maskShape.numDimensions() == 0) {
    throw new IllegalArgumentException("Mask cannot be a scalar.");
  }
  // Rejects any mask whose static shape is not fully known (unknown rank or any unknown dimension size).
  if (maskShape.hasUnknownDimension()) {
    throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
  }

  // Index one past the last dimension of `tensor` covered by the mask.
  int maskEnd = axis + maskShape.numDimensions();

  Shape requiredMaskShape = tensorShape.subShape(axis, maskEnd);
  if (!requiredMaskShape.isCompatibleWith(maskShape)) {
    throw new IllegalArgumentException(
        "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
  }

  Operand<TInt32> liveShape = org.tensorflow.op.core.Shape.create(scope, tensor);
  // Runtime dims of `tensor` up to and including the masked ones: [outer..., masked...].
  Operand<TInt32> maskedShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(maskEnd));
  // Runtime dims of `tensor` after the masked ones; carried through the scatter unchanged.
  Operand<TInt32> innerShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(maskEnd));

  Operand<TInt32> leadingSize = ReduceProd.create(scope, maskedShape, Constant.arrayOf(scope, 0));
  Operand<T> reshaped = Reshape.create(scope, tensor, Concat.create(
      scope,
      Arrays.asList(
          Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)),
          innerShape
      ),
      Constant.scalarOf(scope, 0)
  ));

  // Where() on the raw mask would return k-dimensional coordinates. Those cannot address the single collapsed
  // row dimension of `reshaped` when k > 1, and would only cover the first outer index when axis > 0.
  // Tiling the mask over the outer dimensions and flattening it makes every index a row offset into `reshaped`.
  Operand<TBool> fullMask = mask;
  if (axis > 0) {
    fullMask = BroadcastTo.create(scope, mask, maskedShape);
  }
  Operand<TInt64> indices = Where.create(scope, Reshape.create(scope, fullMask, Constant.arrayOf(scope, -1)));

  if (broadcast || axis > 0) {
    // The scatter expects updates of shape [rows, inner...], where rows is the number of selected positions.
    Operand<TInt32> indicesShape = org.tensorflow.op.core.Shape.create(scope, indices);
    Operand<TInt32> rowsShape = StridedSliceHelper.stridedSlice(scope, indicesShape, Indices.sliceTo(-1));
    Operand<TInt32> scatterUpdatesShape = Concat.create(
        scope,
        Arrays.asList(rowsShape, innerShape),
        Constant.scalarOf(scope, 0)
    );
    if (axis > 0) {
      if (broadcast) {
        // Broadcast against [outer..., numTrue, inner...], the layout produced by boolean-masking `tensor` along
        // `axis`. This way updates shaped [numTrue, inner...] (or smaller) are reused for every outer index.
        Operand<TInt32> outerShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis));
        Operand<TInt64> maskIndices =
            Where.create(scope, Reshape.create(scope, mask, Constant.arrayOf(scope, -1)));
        // [numTrue]: the number of true values in the (untiled) mask
        Operand<TInt32> trueCountShape = StridedSliceHelper.stridedSlice(scope,
            org.tensorflow.op.core.Shape.create(scope, maskIndices), Indices.sliceTo(-1));
        Operand<TInt32> maskedValuesShape = Concat.create(
            scope,
            Arrays.asList(outerShape, trueCountShape, innerShape),
            Constant.scalarOf(scope, 0)
        );
        updates = BroadcastTo.create(scope, updates, maskedValuesShape);
      }
      // Collapse [outer..., numTrue, inner...] into [rows, inner...], where rows = prod(outer) * numTrue.
      updates = Reshape.create(scope, updates, scatterUpdatesShape);
    } else {
      // Reached only when broadcasting with axis == 0: the layout is already [numTrue, inner...].
      updates = BroadcastTo.create(scope, updates, scatterUpdatesShape);
    }
  }

  Operand<T> newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, updates);
  return Reshape.create(scope, newValue, liveShape);
}