in python/tvm/relax/backend/cuda/cublas.py [0:0]
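# The names used below are assumed to come from the module's top-level imports, roughly
# (see the top of cublas.py; `_is_supported_dtype` is a helper defined in the same file):
#
#     import operator
#     from functools import reduce
#
#     import tvm
#     from tvm import DataType
#     from tvm.arith import Analyzer
#     from tvm.relax.transform import PatternCheckContext
#     from tvm.relax.backend.utils import has_leaking_intermediate_variables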
def _check_matmul(context: PatternCheckContext) -> bool:
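    """Check whether a matched matmul (with optional scale/zero-point, transposed rhs,
    and bias) satisfies the dtype, alignment, and batching constraints of cuBLAS/cuBLASLt."""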
if has_leaking_intermediate_variables(context):
return False
lhs = context.annotated_expr["lhs"]
rhs = context.annotated_expr["rhs"]
matmul_call = context.annotated_expr["root"]
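    # Optional scale / zero-point bound by the quantized-matmul pattern, if present.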
if "scale" in context.annotated_expr and "zp" in context.annotated_expr:
scale = context.annotated_expr["scale"]
zero_point = context.annotated_expr["zp"]
# Only scalar values for scale and zero_point are supported.
if scale.struct_info.ndim != 0 or zero_point.struct_info.ndim != 0:
return False
# Only zero_point == 0.0 is supported.
if zero_point.data.numpy()[()].item() != 0.0:
return False
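    # Only dtype combinations that cuBLAS can execute are accepted (see _is_supported_dtype).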
lhs_dtype = lhs.struct_info.dtype
rhs_dtype = rhs.struct_info.dtype
out_dtype = matmul_call.struct_info.dtype
if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
return False
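    # Shape values (possibly symbolic) drive the reduction-axis, alignment, and batching checks.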
lhs_shape = lhs.struct_info.shape.values
rhs_shape = rhs.struct_info.shape.values
if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
# Reduction axis must be constant
return False
if lhs_dtype == "int8" and rhs_dtype == "int8":
if lhs_shape[-1] % 4 != 0:
            # Reduction axis must be a multiple of 4 for IGEMM
return False
if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0:
            # The number of rows must be a multiple of 4 for IGEMM
return False
elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
matmul_rhs_var = matmul_call.args[1]
rhs_transposed = False
if matmul_rhs_var in context.matched_bindings:
matmul_rhs_call = context.matched_bindings[matmul_rhs_var]
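            # If rhs comes from a matched binding, the pattern guarantees it is a
            # relax.permute_dims, i.e. a transposed weight.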
assert (
isinstance(matmul_rhs_call, tvm.relax.Call)
and matmul_rhs_call.op.name == "relax.permute_dims"
)
rhs_transposed = True
if not rhs_transposed:
            # cuBLAS FP8 operations require rhs to be transposed
return False
        # cuBLAS FP8 operations require all tensors to be aligned to 16 bytes.
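        # 16 bytes correspond to 16 // itemsize elements along the checked dimension.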
if (
not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int))
or rhs_shape[-1] % (16 // DataType(lhs_dtype).itemsize) != 0
):
return False
if (
not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int))
or rhs_shape[-2] % (16 // DataType(out_dtype).itemsize) != 0
):
return False
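    # Collapse all leading (batch) dimensions of each operand into a single batch count.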
lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
if "bias" in context.annotated_expr:
if lhs_dtype == "int8" and rhs_dtype == "int8":
# Non-default epilogue not supported for IGEMM
return False
bias = context.annotated_expr["bias"]
bias_shape = bias.struct_info.shape.values
bias_batches = reduce(operator.mul, bias_shape[:-1], 1)
if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1:
            # cuBLAS only supports a plain bias vector (no batch dimensions on the bias)
return False
analyzer = Analyzer()
    # cuBLASLt does not seem to support batched GEMM when one of the matrices has a
    # single batch (with batch_stride 0), so for batched GEMM the two batch counts must
    # be equal. If lhs is batched but rhs is not, we can still use a regular GEMM by
    # flattening all batch axes into the M axis.
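    # For example (batch handling only): lhs (8, 32, 64) with rhs (8, 64, 16) or
    # rhs (64, 16) is accepted here, while lhs (32, 64) with rhs (8, 64, 16) is rejected.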
return (
isinstance(lhs_batches, tvm.tir.Var)
or isinstance(rhs_batches, tvm.tir.Var)
or (analyzer.can_prove_equal(lhs_batches, rhs_batches))
or (analyzer.can_prove(lhs_batches >= 1) and analyzer.can_prove(rhs_batches == 1))
)