in tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java [193:289]
public static <T extends TNumber> Operand<T> matmul(
    Scope scope,
    Operand<T> a,
    Operand<T> b,
    boolean transposeA,
    boolean transposeB,
    boolean adjointA,
    boolean adjointB,
    boolean aIsSparse,
    boolean bIsSparse) {
  // Multiplies a by b, dispatching (in the style of Python's tf.linalg.matmul) to:
  //  - BatchMatMul when neither operand is flagged sparse and either operand has rank > 2
  //    or an unknown rank;
  //  - SparseMatMul when a sparse hint is given and both types are TFloat32/TBfloat16, or
  //    when exactly one operand is TBfloat16; its TFloat32 result is cast back to a's type;
  //  - plain MatMul otherwise.
  // Throws IllegalArgumentException when transpose and adjoint are both requested for the
  // same operand, when an operand is neither floating point nor TInt32, or when both ranks
  // are known and differ.
  Scope lscope = scope.withSubScope("MatMul");
  if (transposeA && adjointA) {
    throw new IllegalArgumentException("Only one of transposeA and adjointA can be true.");
  }
  if (transposeB && adjointB) {
    throw new IllegalArgumentException("Only one of transposeB and adjointB can be true.");
  }
  if (!(TFloating.class.isAssignableFrom(a.type()) || a.type().equals(TInt32.class))) {
    throw new IllegalArgumentException(
        String.format(
            "Operand 'a' must be of type 'TBfloat16','TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
            a.type().getSimpleName()));
  }
  // Validate b's own type (this check must not look at a).
  if (!(TFloating.class.isAssignableFrom(b.type()) || b.type().equals(TInt32.class))) {
    throw new IllegalArgumentException(
        String.format(
            "Operand 'b' must be of type 'TBfloat16','TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
            b.type().getSimpleName()));
  }
  Shape aShape = a.shape();
  Shape bShape = b.shape();
  int aRank = aShape.numDimensions();
  int bRank = bShape.numDimensions();
  // An unknown rank cannot be compared, so only reject when both ranks are known.
  if (aRank != Shape.UNKNOWN_SIZE && bRank != Shape.UNKNOWN_SIZE && aRank != bRank) {
    throw new IllegalArgumentException(
        String.format(
            "Parameters 'a' and 'b' must have the same rank: found a rank = %d, b rank = %d",
            aRank, bRank));
  }
  boolean outputMayHaveNonEmptyBatchShape =
      aRank == Shape.UNKNOWN_SIZE
          || aRank > 2
          || bRank == Shape.UNKNOWN_SIZE
          || bRank > 2;

  // Only real types (floating point or TInt32) get past the checks above, and for real
  // matrices the adjoint (conjugate transpose) is just the transpose. Fold the adjoint
  // flags into the transpose flags rather than emitting Conj ops: the raw Conj op only
  // accepts complex/variant inputs and therefore cannot be built for these types.
  boolean transA = transposeA || adjointA;
  boolean transB = transposeB || adjointB;

  if (!aIsSparse && !bIsSparse && outputMayHaveNonEmptyBatchShape) {
    // MatMul requires rank-2 inputs; BatchMatMul handles batched operands and expresses
    // transposition through its adjoint options (equivalent for real types).
    return org.tensorflow.op.linalg.BatchMatMul.create(
        lscope,
        a,
        b,
        org.tensorflow.op.linalg.BatchMatMul.adjX(transA),
        org.tensorflow.op.linalg.BatchMatMul.adjY(transB));
  }

  boolean useSparseMatmul = false;
  if (aIsSparse || bIsSparse) {
    useSparseMatmul =
        (a.type().equals(TBfloat16.class) || a.type().equals(TFloat32.class))
            && (b.type().equals(TBfloat16.class) || b.type().equals(TFloat32.class));
  }
  // Mixed TBfloat16 / non-TBfloat16 inputs are routed to SparseMatMul, as in Python.
  if ((a.type().equals(TBfloat16.class) || b.type().equals(TBfloat16.class))
      && !a.type().equals(b.type())) {
    useSparseMatmul = true;
  }
  if (useSparseMatmul) {
    Operand<TFloat32> result =
        SparseMatMul.create(
            lscope,
            a,
            b,
            SparseMatMul.transposeA(transA),
            SparseMatMul.transposeB(transB),
            SparseMatMul.aIsSparse(aIsSparse),
            SparseMatMul.bIsSparse(bIsSparse));
    if (a.type().equals(TFloat32.class)) {
      // T is TFloat32 on this branch, so the result already has the requested type.
      @SuppressWarnings("unchecked")
      Operand<T> typed = (Operand<T>) result;
      return typed;
    }
    // SparseMatMul yields TFloat32; cast back to a's type (e.g. TBfloat16).
    return Cast.create(lscope, result, a.type());
  }
  return org.tensorflow.op.linalg.MatMul.create(
      lscope,
      a,
      b,
      org.tensorflow.op.linalg.MatMul.transposeA(transA),
      org.tensorflow.op.linalg.MatMul.transposeB(transB));
}