in python/tvm/relax/backend/adreno/clml.py [0:0]
def clml_pattern_table():
    """Get the CLML pattern table.

    Returns
    -------
    list
        Tuples ``(name, pattern, annotations, check)``. ``name`` is the
        composite name prefixed with ``"openclml."``; ``annotations`` maps
        keys to sub-patterns whose matched expressions ``check`` reads from
        ``PatternCheckContext.annotated_expr``.
    """

    # ------------------------------------------------------------------
    # Helpers shared by the check functions below.
    # ------------------------------------------------------------------
    def _op_name(expr):
        """Return the operator name of call ``expr``, or None if it is not a named-op call."""
        if not isinstance(expr, relax.Call):
            return None
        return getattr(expr.op, "name", None)

    def _const_int(value):
        """Return ``value`` as a Python int if it is an integer constant, else None.

        Plain ints are returned unchanged; other objects are accepted when they
        expose an integer ``value`` field (as constant integer expressions do).
        Symbolic dimensions therefore yield None instead of raising when they
        would otherwise be compared against a Python int.
        """
        if isinstance(value, int):
            return value
        inner = getattr(value, "value", None)
        return inner if isinstance(inner, int) else None

    def _const_ints(values):
        """Apply ``_const_int`` to every entry of ``values``.

        Returns None when ``values`` is None or any entry is not constant.
        """
        if values is None:
            return None
        ints = [_const_int(v) for v in values]
        return None if any(i is None for i in ints) else ints

    def _ints_at_least(values, length, minimum):
        """True if ``values`` holds exactly ``length`` integer constants, each >= ``minimum``."""
        ints = _const_ints(values)
        return ints is not None and len(ints) == length and all(i >= minimum for i in ints)

    def _tensor_dims(expr):
        """Return the shape dims of ``expr`` as a list, or None if its shape is not a ShapeExpr.

        Individual dims in the returned list may still be symbolic.
        """
        shape = getattr(getattr(expr, "struct_info", None), "shape", None)
        if not isinstance(shape, relax.ShapeExpr):
            return None
        return list(shape.values)

    def _tensor_dtype(expr):
        """Return the dtype recorded on ``expr``'s struct info, or None if there is none."""
        return getattr(getattr(expr, "struct_info", None), "dtype", None)

    # ------------------------------------------------------------------
    # Convolution
    # ------------------------------------------------------------------
    conv_op_names = ("relax.nn.conv2d", "relax.nn.conv2d_transpose")
    float_dtypes = ("float32", "float16")

    def _find_conv_call(expr):
        """Descend from a matched conv-pattern root to the convolution call it wraps.

        The conv patterns wrap the convolution in an optional bias ``add``, an
        optional ``batch_norm`` whose output is read via ``TupleGetItem(.., 0)``,
        and an optional ``relu``/``clip``. In every wrapper the convolution
        chain is the first operand, so follow ``tuple_value`` / ``args[0]``
        until a convolution is reached. Returns None if none is found.
        """
        while True:
            if isinstance(expr, relax.TupleGetItem):
                expr = expr.tuple_value
            elif isinstance(expr, relax.Call):
                if _op_name(expr) in conv_op_names:
                    return expr
                if len(expr.args) == 0:
                    return None
                expr = expr.args[0]
            else:
                return None

    def _check_conv2d(context: PatternCheckContext) -> bool:
        """Accept NCHW/OIHW convolutions with float32/float16 data and weights."""
        annotated = context.annotated_expr
        if "root" in annotated:
            # "root" is the outermost op of the fused chain (e.g. relu, add, or a
            # TupleGetItem over batch_norm). Inspecting it directly skipped the
            # layout check for fused patterns and raised for TupleGetItem roots,
            # which have no ``op``; check the convolution itself instead.
            conv_call = _find_conv_call(annotated["root"])
            if conv_call is None:
                return False
            if conv_call.attrs.data_layout != "NCHW" or conv_call.attrs.kernel_layout != "OIHW":
                return False
        for key in ("data", "weight"):
            if key in annotated and _tensor_dtype(annotated[key]) not in float_dtypes:
                return False
        return True

    def populate_patterns(patterns, name, op, annotations, *args):
        """Wrap every entry of ``patterns`` with ``op``.

        Entry ``k`` yields ``name + "." + k`` whose pattern is
        ``op(pattern, *args)`` and whose annotations are the entry's own
        annotations extended with ``annotations``.
        """
        ret = {}
        for k, v in patterns.items():
            ret_ann = v["annotation"].copy()
            ret_ann.update(annotations)
            ret[name + "." + k] = {"pattern": op(v["pattern"], *args), "annotation": ret_ann}
        return ret

    def conv_pattern():
        """Create the convolution patterns and their fused variants.

        Base patterns (``conv2d``, ``pad`` + ``conv2d``, ``conv2d_transpose``)
        are successively wrapped with a bias ``add``, a ``batch_norm``, a
        ``TupleGetItem(.., 0)``, and finally ``relu`` / ``clip``.
        """
        data = wildcard()
        weight = wildcard()
        bias = is_const()
        bn_scale = is_const()
        bn_bias = is_const()
        bn_mean = is_const()
        bn_var = is_const()
        annotations = {
            "data": data,
            "weight": weight,
        }
        patterns = {}
        patterns["nn.conv2d"] = {
            "pattern": is_op("relax.nn.conv2d")(data, weight),
            "annotation": annotations.copy(),
        }
        patterns["pad.nn.conv2d"] = {
            "pattern": is_op("relax.nn.conv2d")(is_op("relax.nn.pad")(data), weight),
            "annotation": annotations.copy(),
        }
        patterns["nn.conv2d_transpose"] = {
            "pattern": is_op("relax.nn.conv2d_transpose")(data, weight),
            "annotation": annotations.copy(),
        }
        patterns.update(
            populate_patterns(patterns, "bias", is_op("relax.add"), {"bias": bias}, bias)
        )
        patterns.update(
            populate_patterns(
                patterns,
                "bn",
                is_op("relax.nn.batch_norm"),
                {
                    "bn_scale": bn_scale,
                    "bn_bias": bn_bias,
                    "bn_mean": bn_mean,
                    "bn_var": bn_var,
                },
                bn_scale,
                bn_bias,
                bn_mean,
                bn_var,
            )
        )
        tuple_patterns = {}
        for k, v in patterns.items():
            tuple_patterns["tuple." + k] = {
                "pattern": is_tuple_get_item(v["pattern"], 0),
                "annotation": v["annotation"].copy(),
            }
        patterns.update(tuple_patterns)
        relu_patterns = populate_patterns(patterns, "relu", is_op("relax.nn.relu"), {})
        # NOTE(review): relax.clip appears to take (x, min, max) as call
        # arguments; if so, this one-argument pattern cannot match. Confirm
        # against the op definition and the CLML codegen before widening it.
        clip_patterns = populate_patterns(patterns, "clip", is_op("relax.clip"), {})
        patterns.update(relu_patterns)
        patterns.update(clip_patterns)
        conv_patterns = []
        for k, v in patterns.items():
            ret_annotations = v["annotation"].copy()
            ret_annotations["root"] = v["pattern"]
            conv_patterns.append(("openclml." + k, v["pattern"], ret_annotations, _check_conv2d))
        # Reverse insertion order so the most-wrapped variants come first
        # (presumably so they are preferred over their sub-patterns; confirm).
        return conv_patterns[::-1]

    # ------------------------------------------------------------------
    # Pooling
    # ------------------------------------------------------------------
    def _check_pool2d(context, op_name, check_dilation):
        """Validate a 2-D pooling root: static positive 4-D input and sane attributes."""
        root = context.annotated_expr.get("root")
        if _op_name(root) != op_name:
            return False
        if "data" not in context.annotated_expr:
            return False
        # Unknown shapes and symbolic dims are rejected rather than raising.
        if not _ints_at_least(_tensor_dims(context.annotated_expr["data"]), 4, 1):
            return False
        attrs = root.attrs
        if not _ints_at_least(attrs.pool_size, 2, 1):
            return False
        if not _ints_at_least(attrs.strides, 2, 1):
            return False
        if check_dilation and not _ints_at_least(attrs.dilation, 2, 1):
            return False
        return _ints_at_least(attrs.padding, 4, 0)

    def _check_maxpool2d(context: PatternCheckContext) -> bool:
        """Validate a ``relax.nn.max_pool2d`` match (dilation included)."""
        return _check_pool2d(context, "relax.nn.max_pool2d", check_dilation=True)

    def maxpool_pattern():
        """Create Pool Pattern"""
        data = wildcard()
        pattern = is_op("relax.nn.max_pool2d")(data)
        annotations = {"data": data, "root": pattern}
        return [("openclml.nn.max_pool2d", pattern, annotations, _check_maxpool2d)]

    def _check_avgpool2d(context: PatternCheckContext) -> bool:
        """Validate a ``relax.nn.avg_pool2d`` match (dilation is not inspected)."""
        return _check_pool2d(context, "relax.nn.avg_pool2d", check_dilation=False)

    def avgpool_pattern():
        """Create the average-pool pattern."""
        data = wildcard()
        pattern = is_op("relax.nn.avg_pool2d")(data)
        annotations = {"data": data, "root": pattern}
        return [("openclml.nn.avg_pool2d", pattern, annotations, _check_avgpool2d)]

    def _check_global_avgpool(context: PatternCheckContext) -> bool:
        """Validate ``relax.mean`` over axes (2, 3) of a 4-D tensor."""
        root = context.annotated_expr.get("root")
        if _op_name(root) != "relax.mean":
            return False
        if "data" not in context.annotated_expr:
            return False
        dims = _tensor_dims(context.annotated_expr["data"])
        if dims is None or len(dims) != 4:
            return False
        # Dims 1..3 must be static and positive; dim 0 is not inspected.
        if not _ints_at_least(dims[1:], 3, 1):
            return False
        return _const_ints(getattr(root.attrs, "axis", None)) == [2, 3]

    def global_avgpool_pattern():
        """Create Pool Pattern"""
        data = wildcard()
        pattern = is_op("relax.mean")(data).has_attr({"axis": [2, 3]})
        annotations = {
            "data": data,
            "root": pattern,
        }
        return [
            ("openclml.nn.global_avg_pool2d", pattern, annotations, _check_global_avgpool),
        ]

    # ------------------------------------------------------------------
    # Reshape / batch norm
    # ------------------------------------------------------------------
    def _check_reshape(context: PatternCheckContext) -> bool:
        """Validate a ``relax.reshape`` root whose shape operand is a relax expression."""
        root = context.annotated_expr.get("root")
        if _op_name(root) != "relax.reshape":
            return False
        return isinstance(root.args[1], relax.Expr)

    def reshape_pattern():
        """Create Reshape Pattern"""
        pattern = is_op("relax.reshape")(wildcard(), wildcard())
        annotations = {
            "root": pattern,
        }
        return [("openclml.reshape", pattern, annotations, _check_reshape)]

    def _check_batchnorm(context: PatternCheckContext) -> bool:
        """Require constant float32 batch-norm parameters that all share one shape."""
        annotated = context.annotated_expr
        # The batch_norm pattern is rooted at the trailing reshape.
        if _op_name(annotated.get("root")) != "relax.reshape":
            return False
        param_keys = ("moving_var", "gamma", "moving_mean", "beta")
        if any(key not in annotated for key in param_keys):
            return False
        params = [annotated[key] for key in param_keys]
        if not all(isinstance(param, relax.expr.Constant) for param in params):
            return False
        base_shape = None
        for param in params:
            if param.struct_info.dtype != "float32":
                return False
            shape = param.struct_info.shape
            if base_shape is None:
                base_shape = shape
                continue
            # All parameters should have the same shape.
            if len(shape) != len(base_shape):
                return False
            if any(s1 != s2 for s1, s2 in zip(shape, base_shape)):
                return False
        return True

    def batch_norm_pattern():
        """Create a batch norm pattern."""
        data = wildcard()
        bn_scale = is_const()
        bn_bias = is_const()
        bn_mean = is_const()
        bn_var = is_const()
        pattern = is_op("relax.nn.batch_norm")(data, bn_scale, bn_bias, bn_mean, bn_var)
        pattern = is_tuple_get_item(pattern, 0)
        pattern = is_op("relax.reshape")(pattern, wildcard())
        annotations = {
            "gamma": bn_scale,
            "beta": bn_bias,
            "moving_mean": bn_mean,
            "moving_var": bn_var,
            "root": pattern,
        }
        return [
            ("openclml.nn.batch_norm", pattern, annotations, _check_batchnorm),
        ]

    # ------------------------------------------------------------------
    # Element-wise binary / unary ops
    # ------------------------------------------------------------------
    def _check_binary_op(context: PatternCheckContext) -> bool:
        """Validate operands of element-wise patterns (unary patterns annotate only "lhs")."""

        def _check_arg(input_expr):
            dims = _tensor_dims(input_expr)
            # Unknown shapes and rank-0 tensors are not supported.
            if not dims:
                return False
            # Avoid any operators with dtype Int64
            if _tensor_dtype(input_expr) == "int64":
                return False
            # No support for batch > 1; a symbolic batch dim may exceed 1, so reject it.
            batch = _const_int(dims[0])
            return batch is not None and batch <= 1

        def _same_shape(lhs_dims, rhs_dims):
            if len(lhs_dims) != len(rhs_dims):
                return False
            for lhs_dim, rhs_dim in zip(lhs_dims, rhs_dims):
                if lhs_dim != rhs_dim:
                    return False
            return True

        annotated = context.annotated_expr
        operands = [annotated[key] for key in ("lhs", "rhs") if key in annotated]
        if not all(_check_arg(operand) for operand in operands):
            return False
        # Binary ops (both operands annotated) additionally need identical shapes.
        if len(operands) == 2:
            return _same_shape(_tensor_dims(operands[0]), _tensor_dims(operands[1]))
        return True

    def binary_op_pattern():
        """Create a binary op pattern."""

        def make_pattern(op):
            lhs = wildcard()
            rhs = wildcard()
            pattern = is_op(op)(lhs, rhs)
            annotations = {"lhs": lhs, "rhs": rhs}
            return ("openclml." + op, pattern, annotations, _check_binary_op)

        binary_ops = [
            "relax.add",
            "relax.subtract",
            "relax.multiply",
            "relax.divide",
            "relax.maximum",
            "relax.minimum",
        ]
        return [make_pattern(op) for op in binary_ops]

    def unary_op_pattern():
        """Create a unary op pattern."""

        def make_pattern(op):
            lhs = wildcard()
            pattern = is_op(op)(lhs)
            annotations = {"lhs": lhs}
            return ("openclml." + op, pattern, annotations, _check_binary_op)

        # NOTE(review): as in conv_pattern, a one-argument "relax.clip" pattern
        # may never match if the op takes (x, min, max); confirm.
        unary_ops = [
            "relax.nn.softmax",
            "relax.nn.relu",
            "relax.clip",
        ]
        return [make_pattern(op) for op in unary_ops]

    return [
        *conv_pattern(),
        *batch_norm_pattern(),
        *binary_op_pattern(),
        *unary_op_pattern(),
        *maxpool_pattern(),
        *avgpool_pattern(),
        *global_avgpool_pattern(),
        *reshape_pattern(),
    ]