public static Operand matmul()

in tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java [193:289]


  /**
   * Multiplies matrix {@code a} by matrix {@code b}, optionally transposing (or taking the adjoint
   * of) either operand first.
   *
   * <p>All ops that compute the product are created under the {@code "MatMul"} sub-scope of
   * {@code scope}. Both operands must have the same rank and an element type that is either a
   * {@code TFloating} type or {@code TInt32}. These are all real-valued types, for which the
   * complex conjugate is the identity, so an adjoint request is carried out as a plain transpose
   * and no conjugation op is added.
   *
   * @param scope the scope used to create the ops
   * @param a the left operand
   * @param b the right operand
   * @param transposeA if true, {@code a} is transposed before multiplication
   * @param transposeB if true, {@code b} is transposed before multiplication
   * @param adjointA if true, {@code a} is conjugated and transposed before multiplication
   * @param adjointB if true, {@code b} is conjugated and transposed before multiplication
   * @param aIsSparse sparsity hint for {@code a}; forwarded to {@code SparseMatMul} when that op is
   *     selected
   * @param bIsSparse sparsity hint for {@code b}; forwarded to {@code SparseMatMul} when that op is
   *     selected
   * @param <T> the element type of the operands and of the result
   * @return the matrix product, with the same element type as {@code a}
   * @throws IllegalArgumentException if both the transpose and adjoint flag are set for the same
   *     operand, if either operand has an unsupported element type, or if the operands do not have
   *     the same rank
   */
  public static <T extends TNumber> Operand<T> matmul(
      Scope scope,
      Operand<T> a,
      Operand<T> b,
      boolean transposeA,
      boolean transposeB,
      boolean adjointA,
      boolean adjointB,
      boolean aIsSparse,
      boolean bIsSparse) {
    Scope lscope = scope.withSubScope("MatMul");

    if (transposeA && adjointA) {
      throw new IllegalArgumentException("Only one of transposeA and adjointA can be true.");
    }
    if (transposeB && adjointB) {
      throw new IllegalArgumentException("Only one of transposeB and adjointB can be true.");
    }
    if (!(TFloating.class.isAssignableFrom(a.type()) || a.type().equals(TInt32.class))) {
      throw new IllegalArgumentException(
          String.format(
              "Operand 'a' must be of type 'TBfloat16','TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
              a.type().getSimpleName()));
    }
    // Validate b against its own type (this check previously inspected a.type() by mistake).
    if (!(TFloating.class.isAssignableFrom(b.type()) || b.type().equals(TInt32.class))) {
      throw new IllegalArgumentException(
          String.format(
              "Operand 'b' must be of type 'TBfloat16', 'TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
              b.type().getSimpleName()));
    }

    Shape aShape = a.shape();
    Shape bShape = b.shape();
    if (aShape.numDimensions() != bShape.numDimensions()) {
      throw new IllegalArgumentException(
          String.format(
              "Parameters 'a' and 'b' must have the same rank: found a rank = %d, b rank = %d",
              aShape.numDimensions(), bShape.numDimensions()));
    }

    // The operands were validated above as real-valued, so conjugation is a no-op and an adjoint
    // is exactly a transpose. Fold both flags into one so that neither code path below can drop
    // an adjoint request.
    boolean effectiveTransposeA = transposeA || adjointA;
    boolean effectiveTransposeB = transposeB || adjointB;

    // Ranks are known to be equal here, so checking a's rank for "> 2" covers b as well.
    boolean outputMayHaveNonEmptyBatchShape =
        aShape.numDimensions() == Shape.UNKNOWN_SIZE
            || aShape.numDimensions() > 2
            || bShape.numDimensions() == Shape.UNKNOWN_SIZE;

    if ((!aIsSparse && !bIsSparse) && outputMayHaveNonEmptyBatchShape) {
      // NOTE(review): the comment that used to live here referred to BatchMatMul, but the op that
      // is created is linalg.MatMul. Confirm that linalg.MatMul accepts operands of rank > 2;
      // otherwise this branch should dispatch to a batched matmul op instead.
      return org.tensorflow.op.linalg.MatMul.create(
          lscope,
          a,
          b,
          org.tensorflow.op.linalg.MatMul.transposeA(effectiveTransposeA),
          org.tensorflow.op.linalg.MatMul.transposeB(effectiveTransposeB));
    }

    // SparseMatMul is only selected for TBfloat16 / TFloat32 operands when a sparsity hint is set.
    boolean useSparseMatmul = false;
    if (aIsSparse || bIsSparse) {
      useSparseMatmul =
          (a.type().equals(TBfloat16.class) || a.type().equals(TFloat32.class))
              && (b.type().equals(TBfloat16.class) || b.type().equals(TFloat32.class));
    }
    // Mixed-precision operands that involve TBfloat16 are also routed through SparseMatMul.
    if ((a.type().equals(TBfloat16.class) || b.type().equals(TBfloat16.class))
        && !a.type().equals(b.type())) {
      useSparseMatmul = true;
    }

    if (useSparseMatmul) {
      // The result is declared as TFloat32 regardless of the operand types.
      Operand<TFloat32> result =
          SparseMatMul.create(
              lscope,
              a,
              b,
              SparseMatMul.transposeA(effectiveTransposeA),
              SparseMatMul.transposeB(effectiveTransposeB),
              SparseMatMul.aIsSparse(aIsSparse),
              SparseMatMul.bIsSparse(bIsSparse));
      if (a.type().equals(TFloat32.class)) {
        // Safe: T is TFloat32 on this branch, so the result already has the caller's type.
        @SuppressWarnings("unchecked")
        Operand<T> typedResult = (Operand<T>) result;
        return typedResult;
      }
      // Otherwise convert the TFloat32 result back to the element type of 'a'.
      return Cast.create(scope, result, a.type());
    }

    return org.tensorflow.op.linalg.MatMul.create(
        lscope,
        a,
        b,
        org.tensorflow.op.linalg.MatMul.transposeA(effectiveTransposeA),
        org.tensorflow.op.linalg.MatMul.transposeB(effectiveTransposeB));
  }