in python/tvm/contrib/cutlass/gen_tensor_op.py [0:0]
def instantiate_template(func_name, annotations, func_args):
"""Return CUTLASS host code based on a template and the provided annotations.
Parameters
----------
func_name: str
A string to identify the type of the kernel (e.g. dense/matmul, batched_matmul, conv2d, attention, layer_norm, rms_norm, or decode_matmul).
annotations: container.Map
Key and value pairs annotated during kernel selection.
func_args: list
Names of the function arguments.
Returns
-------
codegen_result : CodegenResult
Generated CUTLASS host code and required header-file names.
"""
attrs = {}
for k in ["lda", "ldb", "ldc", "cutlass_op_def", "cutlass_op_name", "op_type"]:
if k in annotations:
attrs[k] = annotations[k]
headers = ["tvm/runtime/registry.h"]
if "relu" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h")
elif "gelu" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_gelu.h")
elif "sigmoid" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_sigmoid.h")
elif "silu" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_silu.h")
elif "hardswish" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_hardswish.h")
else:
headers.append("cutlass/epilogue/thread/linear_combination.h")
if "residual" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_residual_block.h")
def get_dim(shape_annot, var_name, axis_idx, batched_offset=0):
if isinstance(shape_annot, IntImm):
return str(int(shape_annot))
return f"{var_name}->shape[{batched_offset + axis_idx}]"
def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_idx):
if isinstance(stride_annot, IntImm):
return str(int(stride_annot))
dim1 = func_args[arg0_idx] + f"->shape[{arg0_axis_idx}]"
dim2 = func_args[arg1_idx] + f"->shape[{arg1_axis_idx}]"
return dim1 + " * " + dim2
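# Collapse the leading batch_rank dimensions of an argument into a single runtime extent.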
def get_flattened_batch_dim(arg_name, batch_rank):
return " * ".join(["{}->shape[{}]".format(arg_name, i) for i in range(batch_rank)])
if "decode_matmul" in func_name:
headers.append("cutlass_kernels/fpA_intB_gemm.h")
lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0)
rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1)
scales_arg_idx = _get_optional_int_annotation(annotations, "scales_arg_idx", 2)
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None)
residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None)
attrs["A_arg"] = func_args[lhs_arg_idx]
attrs["B_arg"] = func_args[rhs_arg_idx]
attrs["scales_arg"] = func_args[scales_arg_idx]
attrs["activation"] = annotations.get("activation", "identity")
attrs["bias_stride"] = annotations["bias_stride"]
attrs["M"] = annotations["M"]
attrs["group_size"] = annotations["group_size"]
if not isinstance(attrs["M"], tvm.tir.IntImm):
attrs["M"] = get_flattened_batch_dim(
func_args[lhs_arg_idx], int(annotations["batch_rank"])
)
if bias_arg_idx is not None:
attrs["bias_arg"] = func_args[bias_arg_idx]
if residual_arg_idx is not None:
attrs["residual_arg"] = func_args[residual_arg_idx]
attrs["binary_op"] = annotations["binary_op"]
attrs["unary_op"] = annotations["unary_op"]
if annotations["weight_nbit"] == 4:
attrs["weight_dtype"] = "cutlass::uint4b_t"
attrs["float_per_int"] = 2
else:
assert annotations["weight_nbit"] == 8
attrs["weight_dtype"] = "uint8_t"
attrs["float_per_int"] = 1
code = emit_fp16A_intB_matmul(attrs)
return CodegenResult(code, headers)
elif "dense" in func_name or "matmul" in func_name:
batched = "batch" in annotations
# dense is equivalent to matmul with the second operand transposed
transposed = "transposed" in func_name or "dense" in func_name
lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0)
rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1)
if "bias" in func_name:
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", 2)
else:
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None)
residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None)
lhs_arg = func_args[lhs_arg_idx]
rhs_arg = func_args[rhs_arg_idx]
lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"]
rhs_shape = annotations[f"arg{rhs_arg_idx}_shape"]
lhs_batched_offset = len(lhs_shape) - 2
rhs_batched_offset = len(rhs_shape) - 2
attrs["lhs_arg"] = lhs_arg
attrs["rhs_arg"] = rhs_arg
if bias_arg_idx is not None:
attrs["bias_arg"] = func_args[bias_arg_idx]
if residual_arg_idx is not None:
attrs["residual_arg"] = func_args[residual_arg_idx]
attrs["ElementInputA"] = DataTypeTag[dtype_map[annotations[f"arg{lhs_arg_idx}_dtype"]]]
attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]]
attrs["K"] = lhs_shape[lhs_batched_offset + 1]
attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, lhs_batched_offset)
if transposed:
attrs["N"] = get_dim(rhs_shape[rhs_batched_offset], rhs_arg, 0, rhs_batched_offset)
else:
attrs["N"] = get_dim(rhs_shape[rhs_batched_offset + 1], rhs_arg, 1, rhs_batched_offset)
if batched:
headers.append("cutlass/gemm/device/gemm_batched.h")
def get_batch_on_arg(arg_name, arg_shape):
return " * ".join(
"{}->shape[{}]".format(arg_name, i) for i in range(len(arg_shape) - 2)
)
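# Use the static batch count when available; otherwise derive it from whichever operand carries the batch dims (batch_stride_A == 0 implies a 2D LHS).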
if isinstance(annotations["batch"], IntImm):
attrs["batch"] = str(int(annotations["batch"]))
elif annotations["batch_stride_A"] == 0:
# 2D x ND
attrs["batch"] = get_batch_on_arg(rhs_arg, rhs_shape)
else:
# ND x 2D or ND x ND
attrs["batch"] = get_batch_on_arg(lhs_arg, lhs_shape)
attrs["batch_stride_A"] = get_batch_stride(
annotations["batch_stride_A"],
lhs_arg_idx,
lhs_arg_idx,
lhs_batched_offset,
lhs_batched_offset + 1,
)
attrs["batch_stride_B"] = get_batch_stride(
annotations["batch_stride_B"],
rhs_arg_idx,
rhs_arg_idx,
rhs_batched_offset,
rhs_batched_offset + 1,
)
if transposed:
attrs["batch_stride_C"] = get_batch_stride(
annotations["batch_stride_C"],
lhs_arg_idx,
rhs_arg_idx,
lhs_batched_offset,
rhs_batched_offset,
)
else:
attrs["batch_stride_C"] = get_batch_stride(
annotations["batch_stride_C"],
lhs_arg_idx,
rhs_arg_idx,
lhs_batched_offset,
rhs_batched_offset + 1,
)
else:
headers.append("cutlass/gemm/device/gemm.h")
if "residual" in func_name:
headers.append("cutlass/gemm/device/gemm_universal_with_broadcast.h")
code = instantiate_gemm_template(attrs)
return CodegenResult(code, headers)
elif "conv2d" in func_name:
data_arg_idx = _get_optional_int_annotation(annotations, "data_arg_idx", 0)
weight_arg_idx = _get_optional_int_annotation(annotations, "weight_arg_idx", 1)
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None)
residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None)
attrs["data_arg"] = func_args[data_arg_idx]
attrs["weight_arg"] = func_args[weight_arg_idx]
if bias_arg_idx is not None:
attrs["bias_arg"] = func_args[bias_arg_idx]
if residual_arg_idx is not None:
attrs["residual_arg"] = func_args[residual_arg_idx]
activation_shape = annotations[f"arg{data_arg_idx}_shape"]
weight_shape = annotations[f"arg{weight_arg_idx}_shape"]
output_shape = annotations["ret_shape"]
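# For conv2d_transpose (dgrad) and backward-weight (wgrad), the activation / weight / output roles are swapped relative to fprop.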
if "conv2d_transpose" in func_name:
headers.append("cutlass/conv/kernel/default_conv2d_dgrad.h")
activation_shape = output_shape
output_shape = annotations["arg0_shape"]
elif "backward" in func_name:
headers.append("cutlass/conv/kernel/default_conv2d_wgrad.h")
activation_shape = annotations["arg1_shape"]
weight_shape = output_shape
output_shape = annotations["arg0_shape"]
elif "residual" in func_name:
headers.append("cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h")
else:
headers.append("cutlass/conv/kernel/default_conv2d_fprop.h")
headers.append("cutlass/conv/device/implicit_gemm_convolution.h")
op_name = attrs["cutlass_op_name"]
if "splitk" in op_name:
headers += [
"cutlass/reduction/device/reduce_split_k.h",
"cutlass/reduction/thread/reduction_operators.h",
]
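# Conv2d problem size; the indexing below assumes NHWC activation/output and a KRSC (OHWI) weight layout.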
data_arg = attrs["data_arg"]
attrs["N"] = get_dim(activation_shape[0], data_arg, 0)
attrs["H"] = get_dim(activation_shape[1], data_arg, 1)
attrs["W"] = get_dim(activation_shape[2], data_arg, 2)
attrs["C"] = activation_shape[3]
attrs["P"] = get_dim(output_shape[1], "out0", 1)
attrs["Q"] = get_dim(output_shape[2], "out0", 2)
attrs["K"] = output_shape[3]
attrs["R"] = weight_shape[1]
attrs["S"] = weight_shape[2]
attrs["pad_h"] = annotations["padding"][0]
attrs["pad_w"] = annotations["padding"][1]
attrs["stride_h"] = annotations["strides"][0]
attrs["stride_w"] = annotations["strides"][1]
attrs["dilation_h"] = annotations["dilation"][0]
attrs["dilation_w"] = annotations["dilation"][1]
if "splitk" in op_name:
attrs["split_k_mode"] = "kParallel"
attrs["split_k_slices"] = str(re.search(r"splitk(\d+)", op_name).group(1))
else:
attrs["split_k_mode"] = "kSerial"
attrs["split_k_slices"] = 1
if "residual_shape" in annotations:
attrs["residual_shape"] = annotations["residual_shape"]
code = instantiate_conv2d_template(attrs)
return CodegenResult(code, headers)
elif "attention" in func_name:
is_var_len = "var_len" in func_name
data_type = dtype_map[annotations["arg0_dtype"]]
attrs["qkv_layout"] = annotations["qkv_layout"]
if attrs["qkv_layout"] == "default":
attrs["query"] = func_args[0]
attrs["key"] = func_args[1]
attrs["value"] = func_args[2]
attrs["num_queries"] = s = get_dim(annotations["num_queries"], func_args[0], 1)
attrs["num_keys"] = get_dim(annotations["num_keys"], func_args[1], 1)
if len(func_args) > 4 and not is_var_len: # +1 for workspace, the last arg
attrs["bias"] = func_args[3]
elif attrs["qkv_layout"] == "qkv_stacked":
attrs["qkv"] = func_args[0]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
if len(func_args) > 2 and not is_var_len: # +1 for workspace, the last arg
attrs["bias"] = func_args[1]
else:
raise NotImplementedError()
attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)
if is_var_len:
attrs["seqstart_q"] = func_args[int(annotations["seqstart_q_idx"])]
attrs["seqstart_k"] = func_args[int(annotations["seqstart_k_idx"])]
attrs["max_seqlen_q"] = func_args[int(annotations["max_seqlen_q_idx"])]
attrs["max_seqlen_k"] = func_args[int(annotations["max_seqlen_k_idx"])]
is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]
use_flash = (
annotations["ret_dtype"] == "float16"
and "bias" not in attrs
and int(attrs["head_dim"]) <= 256
and int(attrs["head_dim"]) % 8 == 0
and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
# For the causal case (custom mask = "BottomRight"), only use flash for multi-query
# attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
# with a single query.
# In addition, sliding-window attention is only supported by flash.
and (
int(annotations["custom_mask_type"]) == 0
or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
or (int(annotations["custom_mask_type"]) == 2 and "window_size" in annotations)
)
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
)
# See https://github.com/Dao-AILab/flash-attention/blob/
# 92dd5703ecdb99aa4a4aee9817f28557907403a2/csrc/flash_attn/flash_api.cpp#L111-L116
if "window_size" in annotations:
assert use_flash, "Sliding-window attention is supported only by Flash Attention."
assert (
int(annotations["custom_mask_type"]) == 2
), "Sliding-window attention is only supported for causal with bottom right mask."
attrs["window_size_left"] = int(annotations["window_size"]) - 1
attrs["window_size_right"] = 0
attrs["is_causal"] = False
else:
if int(annotations["custom_mask_type"]) == 2:
attrs["window_size_left"] = attrs["num_keys"]
attrs["window_size_right"] = 0
attrs["is_causal"] = True
else:
attrs["window_size_left"] = -1
attrs["window_size_right"] = -1
attrs["is_causal"] = False
if use_flash:
headers.append("flash.h")
attrs["num_q_heads"] = annotations["num_q_heads"]
attrs["num_kv_heads"] = annotations["num_kv_heads"]
if is_var_len:
code = instantiate_flash_attention_var_len_template(attrs)
else:
code = instantiate_flash_attention_template(attrs)
else:
headers.append("kernel_forward.h")
assert (
not is_mqa
), "The number of query and KV heads need to be the same for CUTLASS fMHA."
attrs["num_heads"] = n = annotations["num_q_heads"]
data_type_size = DataTypeSize[data_type]
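# The aligned fMHA kernel requires the per-head row size (in bytes) to be a multiple of 16.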
if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
attrs["kIsAligned"] = True
elif (h % 4 == 0) and (h_v % 4 == 0):
attrs["kIsAligned"] = False
else:
raise NotImplementedError()
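# Thread-block tile heuristic; kSingleValueIteration applies when the value head dim fits in a single key block.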
if h_v > 64:
attrs["kQueriesPerBlock"] = 32
attrs["kKeysPerBlock"] = 128
attrs["kSingleValueIteration"] = h_v <= 128
else:
attrs["kQueriesPerBlock"] = 64
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
assert (
attrs["scale"] > 0 or attrs["scale"] < 0
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
attrs["custom_mask_type"] = annotations["custom_mask_type"]
for arg in func_args:
if "workspace" in arg:
attrs["workspace"] = arg
if "bias" in attrs:
attrs["kSupportsBias"] = True
if len(annotations["bias_shape"]) == 4:
strides = "p.num_keys"
if annotations["bias_shape"][2] == 1:
attrs["bias_strideM"] = 0
else:
attrs["bias_strideM"] = strides
strides = f"p.num_queries * {strides}"
if annotations["bias_shape"][1] == 1:
attrs["bias_strideH"] = 0
else:
attrs["bias_strideH"] = strides
strides = f"p.num_heads * {strides}"
if annotations["bias_shape"][0] == 1:
attrs["bias_strideB"] = 0
else:
attrs["bias_strideB"] = strides
else:
raise NotImplementedError()
else:
# To support a negative scale with the current CUTLASS implementation,
# kSupportsBias must be set to True; otherwise the result contains NaNs.
attrs["kSupportsBias"] = attrs["scale"] < 0
code = instantiate_attention_template(attrs)
return CodegenResult(code, headers)
elif "layer_norm" in func_name:
headers.append("cutlass/util/device_layernorm.h")
headers.append("cutlass/layout/matrix.h")
attrs = {"input": func_args[0], "gamma": func_args[1], "beta": func_args[2]}
attrs.update(dict(annotations))
if not isinstance(attrs["M"], tvm.tir.IntImm):
attrs["M"] = get_flattened_batch_dim(func_args[0], int(attrs["batch_rank"]))
code = instantiate_layer_norm_template(attrs)
return CodegenResult(code, headers)
elif "rms_norm" in func_name:
headers.append("cutlass/util/device_rmsnorm.h")
headers.append("cutlass/layout/matrix.h")
attrs = {"input": func_args[0], "weight": func_args[1]}
attrs.update(dict(annotations))
if not isinstance(attrs["M"], tvm.tir.IntImm):
attrs["M"] = get_flattened_batch_dim(func_args[0], int(attrs["batch_rank"]))
code = instantiate_rms_norm_template(attrs)
return CodegenResult(code, headers)
raise ValueError(f"Do not have a template for {func_name}")