def _check_matmul()

in python/tvm/relax/backend/cuda/cublas.py
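
Pattern check used when deciding whether a matched matmul can be offloaded to cuBLAS. It returns True only if the matched expressions describe a GEMM that the cuBLAS/cuBLASLt backend can execute: supported input/output dtypes, a constant reduction axis, the alignment rules for int8 IGEMM and FP8, a transposed rhs for FP8, at most a plain bias vector as epilogue, and compatible batch counts.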


def _check_matmul(context: PatternCheckContext) -> bool:
    if has_leaking_intermediate_variables(context):
        return False
    lhs = context.annotated_expr["lhs"]
    rhs = context.annotated_expr["rhs"]
    matmul_call = context.annotated_expr["root"]

    if "scale" in context.annotated_expr and "zp" in context.annotated_expr:
        scale = context.annotated_expr["scale"]
        zero_point = context.annotated_expr["zp"]
        # Only scalar values for scale and zero_point are supported.
        if scale.struct_info.ndim != 0 or zero_point.struct_info.ndim != 0:
            return False
        # Only zero_point == 0.0 is supported.
        if zero_point.data.numpy()[()].item() != 0.0:
            return False

    lhs_dtype = lhs.struct_info.dtype
    rhs_dtype = rhs.struct_info.dtype
    out_dtype = matmul_call.struct_info.dtype
    if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
        return False

    lhs_shape = lhs.struct_info.shape.values
    rhs_shape = rhs.struct_info.shape.values

    if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
        # Reduction axis must be constant
        return False

    if lhs_dtype == "int8" and rhs_dtype == "int8":
        if lhs_shape[-1] % 4 != 0:
            # The reduction axis must be a multiple of 4 for IGEMM
            return False
        if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0:
            # The number of rows must be a multiple of 4 for IGEMM
            return False
    elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
        matmul_rhs_var = matmul_call.args[1]
        rhs_transposed = False
        if matmul_rhs_var in context.matched_bindings:
            matmul_rhs_call = context.matched_bindings[matmul_rhs_var]
            assert (
                isinstance(matmul_rhs_call, tvm.relax.Call)
                and matmul_rhs_call.op.name == "relax.permute_dims"
            )
            rhs_transposed = True

        if not rhs_transposed:
            # cuBLAS FP8 operations require the rhs to be transposed
            return False

        # cuBLAS FP8 operations require all tensors to be 16-byte aligned.
        if (
            not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int))
            or rhs_shape[-1] % (16 // DataType(lhs_dtype).itemsize) != 0
        ):
            return False
        if (
            not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int))
            or rhs_shape[-2] % (16 // DataType(out_dtype).itemsize) != 0
        ):
            return False

    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)

    if "bias" in context.annotated_expr:
        if lhs_dtype == "int8" and rhs_dtype == "int8":
            # Non-default epilogue not supported for IGEMM
            return False
        bias = context.annotated_expr["bias"]
        bias_shape = bias.struct_info.shape.values
        bias_batches = reduce(operator.mul, bias_shape[:-1], 1)
        if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1:
            # cuBLAS only supports a bias vector (no batched bias)
            return False

    analyzer = Analyzer()

    # cuBLASLt does not seem to support batched GEMM where only one of the matrices
    # has a single batch (with batch_stride 0). So for batched GEMM, the two batch
    # counts must be equal. If lhs is batched but rhs is not, we can use a regular
    # GEMM by flattening all batch axes into the M axis.
    return (
        isinstance(lhs_batches, tvm.tir.Var)
        or isinstance(rhs_batches, tvm.tir.Var)
        or (analyzer.can_prove_equal(lhs_batches, rhs_batches))
        or (analyzer.can_prove(lhs_batches >= 1) and analyzer.can_prove(rhs_batches == 1))
    )
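
The shape arithmetic above can be hard to read inline, so the following is a minimal, self-contained sketch of just those rules in plain Python. The helper name shape_rules_ok and the hard-coded ITEMSIZE table are illustrative only and not part of the TVM API; the dtype whitelist (_is_supported_dtype), the FP8 transposed-rhs check, and the bias epilogue rules are intentionally omitted.

from functools import reduce
import operator

# Byte sizes assumed for the dtypes that appear in the checks above (illustrative only).
ITEMSIZE = {"int8": 1, "int32": 4, "float8_e4m3fn": 1, "float16": 2, "float32": 4}

def shape_rules_ok(lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, out_dtype):
    """Hypothetical helper mirroring the shape arithmetic of _check_matmul."""
    if lhs_dtype == rhs_dtype == "int8":
        # IGEMM: the reduction axis and the checked rhs extent must be multiples of 4.
        if lhs_shape[-1] % 4 != 0 or rhs_shape[-1] % 4 != 0:
            return False
    elif lhs_dtype == rhs_dtype == "float8_e4m3fn":
        # FP8: the innermost rhs extents must be 16-byte aligned,
        # i.e. multiples of 16 // itemsize of the input / output dtype.
        if rhs_shape[-1] % (16 // ITEMSIZE[lhs_dtype]) != 0:
            return False
        if rhs_shape[-2] % (16 // ITEMSIZE[out_dtype]) != 0:
            return False
    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
    # Batched GEMM needs equal batch counts; a batched lhs against an
    # unbatched rhs also works because the batch axes fold into M.
    return lhs_batches == rhs_batches or rhs_batches == 1

print(shape_rules_ok((16, 32, 64), (16, 128, 64), "float8_e4m3fn", "float8_e4m3fn", "float16"))  # True
print(shape_rules_ok((16, 32, 64), (8, 128, 64), "float8_e4m3fn", "float8_e4m3fn", "float16"))   # False: 16 vs 8 batches
print(shape_rules_ok((32, 62), (62, 128), "int8", "int8", "int32"))                              # False: 62 % 4 != 0

In the compilation flow this predicate is not called directly by user code; it is registered as the check function for the cuBLAS matmul patterns and runs while a Relax module is partitioned for cuBLAS offload (commonly driven by this module's partition_for_cublas helper, depending on the TVM version).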