public static Operand create()

in tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java [60:142]


  /**
   * Builds ops that produce a copy of {@code tensor} in which the entries selected by {@code mask} are replaced by
   * {@code updates}. This is numpy's {@code tensor[mask] = updates}, or {@code tensor[:, ..., mask] = updates} when an
   * axis is given. The scatter is done with {@code TensorScatterNdUpdate}, which returns a new value rather than
   * writing into {@code tensor}.
   *
   * <p>A rank-{@code K} {@code mask} is matched against dimensions {@code [axis, axis + K)} of {@code tensor}. It is
   * applied identically for every index of the dimensions before {@code axis}. Let {@code n} be the number of
   * {@code true} entries in {@code mask}; the masked selection then has shape
   * {@code tensor.shape[:axis] + [n] + tensor.shape[axis + K:]}.
   *
   * <p>When broadcasting is enabled (the default), {@code updates} is broadcast to that shape. Otherwise it must
   * already have that shape. For {@code axis > 0}, the leading {@code tensor.shape[:axis] + [n]} dimensions may also be
   * given already flattened, in row-major order.
   *
   * @param scope current scope
   * @param tensor the tensor to update
   * @param mask the boolean mask; must not be a scalar and every dimension must be statically known
   * @param updates the values to write at the masked positions
   * @param options optional {@code axis} (default 0; negative values are offset by the rank of {@code tensor}) and
   *     {@code broadcast} (default {@code true}); when several options set the same value, the last non-null one wins
   * @return the updated tensor, reshaped back to the runtime shape of {@code tensor}
   * @throws IllegalArgumentException if {@code mask} is a scalar, has an unknown dimension, or its shape is not
   *     compatible with dimensions {@code [axis, axis + K)} of {@code tensor}
   */
  public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask,
      Operand<T> updates,
      Options... options) {

    scope = scope.withNameAsSubScope("BooleanMaskUpdate");

    int axis = 0;
    boolean broadcast = true;
    if (options != null) {
      for (Options opts : options) {
        if (opts.axis != null) {
          axis = opts.axis;
        }
        if (opts.broadcast != null) {
          broadcast = opts.broadcast;
        }
      }
    }

    if (axis < 0) {
      axis += tensor.rank();
    }

    Shape maskShape = mask.shape();
    Shape tensorShape = tensor.shape();

    if (maskShape.numDimensions() == 0) {
      throw new IllegalArgumentException("Mask cannot be a scalar.");
    }
    // NOTE: this rejects any unknown mask dimension, not only an unknown rank.
    if (maskShape.hasUnknownDimension()) {
      throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
    }

    // First dimension of `tensor` after the masked dimensions [axis, axis + K).
    int maskEnd = axis + maskShape.numDimensions();

    Shape requiredMaskShape = tensorShape.subShape(axis, maskEnd);
    if (!requiredMaskShape.isCompatibleWith(maskShape)) {
      throw new IllegalArgumentException(
          "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
    }

    Operand<TInt32> flatShape = Constant.arrayOf(scope, -1);
    Operand<TInt32> liveShape = org.tensorflow.op.core.Shape.create(scope, tensor);
    Operand<TInt32> leadingShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(maskEnd));
    Operand<TInt32> innerShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(maskEnd));

    // Collapse dimensions [0, axis + K) into one leading dimension: reshaped = [prod(leadingShape), inner...].
    Operand<TInt32> leadingSize = ReduceProd.create(scope, leadingShape, Constant.arrayOf(scope, 0));
    Operand<T> reshaped = Reshape.create(scope, tensor, Concat.create(
        scope,
        Arrays.asList(
            Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)),
            innerShape
        ),
        Constant.scalarOf(scope, 0)
    ));

    // Positions of the true values of the row-major flattened mask, shape [n, 1].
    Operand<TInt64> maskIndices = Where.create(scope, Reshape.create(scope, mask, flatShape));

    // The scatter indices must address the single collapsed leading dimension of `reshaped`, so the mask is
    // flattened the same (row-major) way. For axis > 0 it is first replicated over dimensions [0, axis); the
    // resulting indices then have shape [prod(tensor.shape[:axis]) * n, 1].
    Operand<TInt64> indices;
    if (axis == 0) {
      indices = maskIndices;
    } else {
      Operand<TBool> leadingMask = BroadcastTo.create(scope, mask, leadingShape);
      indices = Where.create(scope, Reshape.create(scope, leadingMask, flatShape));
    }

    if (broadcast) {
      // [n]: the number of true values in the mask.
      Operand<TInt32> batchShape = StridedSliceHelper.stridedSlice(scope,
          org.tensorflow.op.core.Shape.create(scope, maskIndices), Indices.sliceTo(-1));
      if (axis > 0) {
        // Prepend the dimensions the mask is replicated over: tensor.shape[:axis] + [n].
        Operand<TInt32> outerShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis));
        batchShape = Concat.create(scope, Arrays.asList(outerShape, batchShape), Constant.scalarOf(scope, 0));
      }

      // Shape of the masked selection: tensor.shape[:axis] + [n] + tensor.shape[axis + K:].
      Operand<TInt32> updateShape = Concat.create(
          scope,
          Arrays.asList(
              batchShape,
              innerShape
          ),
          Constant.scalarOf(scope, 0)
      );

      updates = BroadcastTo.create(scope, updates, updateShape);
    }

    if (axis > 0) {
      // Merge the [outer..., n] dimensions of the updates into one, matching the row-major order of `indices`.
      updates = Reshape.create(scope, updates, Concat.create(
          scope,
          Arrays.asList(flatShape, innerShape),
          Constant.scalarOf(scope, 0)
      ));
    }

    Operand<T> newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, updates);
    return Reshape.create(scope, newValue, liveShape);
  }