def clml_pattern_table()

in python/tvm/relax/backend/adreno/clml.py
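The pattern helpers below come from the Relax pattern-matching DSL. A sketch of the imports this excerpt relies on (module paths as in mainline TVM; verify against your checkout):

from tvm import relax
from tvm.relax.dpl.pattern import is_const, is_op, is_tuple_get_item, wildcard
from tvm.relax.transform import PatternCheckContext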


def clml_pattern_table():
    """Get the CLML pattern table."""

    def _check_conv2d(context: PatternCheckContext) -> bool:
        if "root" in context.annotated_expr:
            root_call = context.annotated_expr["root"]
            # CLML supports only NCHW data with OIHW weights for both conv2d variants.
            if root_call.op.name in ("relax.nn.conv2d", "relax.nn.conv2d_transpose"):
                input_layout = root_call.attrs.data_layout
                weight_layout = root_call.attrs.kernel_layout
                if input_layout != "NCHW" or weight_layout != "OIHW":
                    return False

        if "data" in context.annotated_expr:
            input_expr = context.annotated_expr["data"]
            input_dtype = input_expr.struct_info.dtype
            if input_dtype not in ["float32", "float16"]:
                return False

        if "weight" in context.annotated_expr:
            weight_expr = context.annotated_expr["weight"]
            weight_dtype = weight_expr.struct_info.dtype
            if weight_dtype not in ["float32", "float16"]:
                return False

        return True

    def populate_patterns(patterns, name, op, annotations, *args):
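        """Compose every existing pattern with `op`, returning new entries keyed
        "<name>.<key>" whose annotations are extended by `annotations`."""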
        ret = {}
        for k, v in patterns.items():
            ret_ann = v["annotation"].copy()
            ret_ann.update(annotations)
            ret[name + "." + k] = {"pattern": op(v["pattern"], *args), "annotation": ret_ann.copy()}

        return ret

    def conv_pattern():
        """Create a convolution pattern."""
        data = wildcard()
        weight = wildcard()
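        # Bias and batch-norm parameters are matched only when bound as constants.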
        bias = is_const()
        bn_scale = is_const()
        bn_bias = is_const()
        bn_mean = is_const()
        bn_var = is_const()

        annotations = {
            "data": data,
            "weight": weight,
        }

        patterns = {}
        patterns["nn.conv2d"] = {
            "pattern": is_op("relax.nn.conv2d")(data, weight),
            "annotation": annotations.copy(),
        }

        patterns["pad.nn.conv2d"] = {
            "pattern": is_op("relax.nn.conv2d")(is_op("relax.nn.pad")(data), weight),
            "annotation": annotations.copy(),
        }

        patterns["nn.conv2d_transpose"] = {
            "pattern": is_op("relax.nn.conv2d_transpose")(data, weight),
            "annotation": annotations.copy(),
        }
        patterns.update(
            populate_patterns(patterns, "bias", is_op("relax.add"), {"bias": bias}, bias)
        )
        patterns.update(
            populate_patterns(
                patterns,
                "bn",
                is_op("relax.nn.batch_norm"),
                {
                    "bn_scale": bn_scale,
                    "bn_bias": bn_bias,
                    "bn_mean": bn_mean,
                    "bn_var": bn_var,
                },
                bn_scale,
                bn_bias,
                bn_mean,
                bn_var,
            )
        )
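        # batch_norm produces a tuple, so also register TupleGetItem(pattern, 0)
        # variants that extract the normalized output.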
        tuple_patterns = {}
        for k, v in patterns.items():
            tuple_annotation = v["annotation"].copy()
            tuple_patterns["tuple" + "." + k] = {
                "pattern": is_tuple_get_item(v["pattern"], 0),
                "annotation": tuple_annotation,
            }
        patterns.update(tuple_patterns)

        relu_patterns = populate_patterns(patterns, "relu", is_op("relax.nn.relu"), {})
        clip_patterns = populate_patterns(patterns, "clip", is_op("relax.clip"), {})
        patterns.update(relu_patterns)
        patterns.update(clip_patterns)

        conv_patterns = []
        for k, v in patterns.items():
            ret_annotations = v["annotation"]
            ret_annotations["root"] = v["pattern"]
            conv_patterns.append(
                ("openclml." + (k), v["pattern"], ret_annotations.copy(), _check_conv2d)
            )
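        # Reverse so that the most heavily fused patterns are matched first.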
        return conv_patterns[::-1]

    def _check_maxpool2d(context: PatternCheckContext) -> bool:
        root = context.annotated_expr.get("root")
        if not root or not isinstance(root, relax.Call):
            return False

        if root.op.name != "relax.nn.max_pool2d":
            return False

        if "data" not in context.annotated_expr:
            return False

        data = context.annotated_expr["data"]
        input_shape = data.struct_info.shape

        if len(input_shape) != 4:
            return False

        if any(dim <= 0 for dim in input_shape):
            return False

        pool_size = root.attrs.pool_size
        if len(pool_size) != 2:
            return False
        if any(size <= 0 for size in pool_size):
            return False

        strides = root.attrs.strides
        if len(strides) != 2:
            return False
        if any(stride <= 0 for stride in strides):
            return False

        dilation = root.attrs.dilation
        if len(dilation) != 2:
            return False
        if any(d <= 0 for d in dilation):
            return False

        padding = root.attrs.padding
        if len(padding) != 4:
            return False
        if any(p < 0 for p in padding):
            return False

        return True

    def maxpool_pattern():
        """Create Pool Pattern"""
        data = wildcard()
        annotations = {
            "data": data,
        }
        patterns = {}
        patterns["nn.max_pool2d"] = {
            "pattern": is_op("relax.nn.max_pool2d")(data),
            "annotation": annotations.copy(),
        }

        pool_patterns = []
        for k, v in patterns.items():
            ret_annotations = v["annotation"]
            ret_annotations["root"] = v["pattern"]
            pool_patterns.append(
                ("openclml." + (k), v["pattern"], ret_annotations.copy(), _check_maxpool2d)
            )
        return pool_patterns

    def _check_avgpool2d(context: PatternCheckContext) -> bool:
        root = context.annotated_expr.get("root")
        if not root or not isinstance(root, relax.Call):
            return False

        if root.op.name != "relax.nn.avg_pool2d":
            return False

        if "data" not in context.annotated_expr:
            return False

        data = context.annotated_expr["data"]
        input_shape = data.struct_info.shape

        if len(input_shape) != 4:
            return False

        if any(dim <= 0 for dim in input_shape):
            return False

        pool_size = root.attrs.pool_size
        if len(pool_size) != 2:
            return False
        if any(size <= 0 for size in pool_size):
            return False

        strides = root.attrs.strides
        if len(strides) != 2:
            return False
        if any(stride <= 0 for stride in strides):
            return False

        padding = root.attrs.padding
        if len(padding) != 4:
            return False
        if any(p < 0 for p in padding):
            return False

        return True

    def avgpool_pattern():
        """Create an average pool pattern."""
        data = wildcard()
        annotations = {
            "data": data,
        }
        patterns = {}
        patterns["nn.avg_pool2d"] = {
            "pattern": is_op("relax.nn.avg_pool2d")(data),
            "annotation": annotations.copy(),
        }

        pool_patterns = []
        for k, v in patterns.items():
            ret_annotations = v["annotation"]
            ret_annotations["root"] = v["pattern"]
            pool_patterns.append(
                ("openclml." + (k), v["pattern"], ret_annotations.copy(), _check_avgpool2d)
            )
        return pool_patterns

    def _check_global_avgpool(context: PatternCheckContext) -> bool:
        root = context.annotated_expr.get("root")
        if not root or not isinstance(root, relax.Call):
            return False

        if root.op.name != "relax.mean":
            return False

        if "data" not in context.annotated_expr:
            return False

        data = context.annotated_expr["data"]
        input_shape = data.struct_info.shape

        if len(input_shape) != 4:
            return False

        if input_shape[1] <= 0 or input_shape[2] <= 0 or input_shape[3] <= 0:
            return False

        if not hasattr(root.attrs, "axis"):
            return False

        axis = root.attrs.axis
        if not (len(axis) == 2 and axis[0] == 2 and axis[1] == 3):
            return False

        return True

    def global_avgpool_pattern():
        """Create Pool Pattern"""
        data = wildcard()
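        # Mean over the spatial axes (H, W) of an NCHW tensor is a global average pool.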
        pattern = is_op("relax.mean")(data).has_attr({"axis": [2, 3]})

        annotations = {
            "data": data,
            "root": pattern,
        }

        return [
            ("openclml.nn.global_avg_pool2d", pattern, annotations, _check_global_avgpool),
        ]

    def _check_reshape(context: PatternCheckContext) -> bool:
        root = context.annotated_expr.get("root")
        if not root or not isinstance(root, relax.Call):
            return False

        if root.op.name != "relax.reshape":
            return False

        # Offload only reshapes whose target shape is static (a ShapeExpr).
        shape_arg = root.args[1]
        if not isinstance(shape_arg, relax.ShapeExpr):
            return False

        return True

    def reshape_pattern():
        """Create Reshape Pattern"""

        pattern = is_op("relax.reshape")(wildcard(), wildcard())
        annotations = {
            "root": pattern,
        }
        return [("openclml.reshape", pattern, annotations, _check_reshape)]

    def _check_batchnorm(context: PatternCheckContext) -> bool:
        root = context.annotated_expr.get("root")
        if not root or not isinstance(root, relax.Call):
            return False

        # The pattern root is the reshape that follows batch_norm + TupleGetItem(0).
        if root.op.name != "relax.reshape":
            return False

        required_params = ["moving_var", "gamma", "moving_mean", "beta"]
        for param in required_params:
            if param not in context.annotated_expr:
                return False

        params = {
            "moving_var": context.annotated_expr["moving_var"],
            "gamma": context.annotated_expr["gamma"],
            "moving_mean": context.annotated_expr["moving_mean"],
            "beta": context.annotated_expr["beta"],
        }

        for param in params.values():
            if not isinstance(param, relax.expr.Constant):
                return False

        base_shape = None
        for param in params.values():
            shape = param.struct_info.shape
            dtype = param.struct_info.dtype

            if dtype not in {"float32"}:
                return False

            # Initialize base_shape if not set
            if base_shape is None:
                base_shape = shape
                continue

            # All parameters should have same shape
            if len(shape) != len(base_shape):
                return False
            if any(s1 != s2 for s1, s2 in zip(shape, base_shape)):
                return False

        return True

    def batch_norm_pattern():
        """Create a batch norm pattern."""
        data = wildcard()
        bn_scale = is_const()
        bn_bias = is_const()
        bn_mean = is_const()
        bn_var = is_const()

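        # relax.nn.batch_norm returns a tuple; match the normalized output
        # (tuple index 0) followed by a reshape.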
        pattern = is_op("relax.nn.batch_norm")(data, bn_scale, bn_bias, bn_mean, bn_var)
        pattern = is_tuple_get_item(pattern, 0)
        pattern = is_op("relax.reshape")(pattern, wildcard())

        annotations = {
            "gamma": bn_scale,
            "beta": bn_bias,
            "moving_mean": bn_mean,
            "moving_var": bn_var,
            "root": pattern,
        }

        return [
            ("openclml.nn.batch_norm", pattern, annotations, _check_batchnorm),
        ]

    def _check_binary_op(context: PatternCheckContext) -> bool:
        def _check_arg(input_expr):
            input_dtype = input_expr.struct_info.dtype
            input_shape = input_expr.struct_info.shape
            if len(input_shape) == 0:
                return False

            # CLML has no support for int64 tensors
            if input_dtype == "int64":
                return False

            # No support for batch size > 1
            if input_shape[0] > 1:
                return False

            return True

        def compare_shapes(lhs_shape, rhs_shape):
            if len(lhs_shape) != len(rhs_shape):
                return False
            for lhs_dim, rhs_dim in zip(lhs_shape, rhs_shape):
                if lhs_dim != rhs_dim:
                    return False
            return True

        lhs_shape = None
        rhs_shape = None
        if "lhs" in context.annotated_expr:
            lhs = context.annotated_expr["lhs"]
            lhs_shape = lhs.struct_info.shape
            if not _check_arg(lhs):
                return False

        if "rhs" in context.annotated_expr:
            rhs = context.annotated_expr["rhs"]
            rhs_shape = rhs.struct_info.shape
            if not _check_arg(rhs):
                return False

        # Binary ops require matching operand shapes; unary ops annotate only
        # "lhs", so this check does not apply to them.
        if (
            "lhs" in context.annotated_expr
            and "rhs" in context.annotated_expr
            and not compare_shapes(lhs_shape, rhs_shape)
        ):
            return False

        return True

    def binary_op_pattern():
        """Create a binary op pattern."""

        def make_pattern(op):
            lhs = wildcard()
            rhs = wildcard()
            pattern = is_op(op)(lhs, rhs)
            annotations = {"lhs": lhs, "rhs": rhs}
            return ("openclml." + op, pattern, annotations, _check_binary_op)

        binary_ops = [
            "relax.add",
            "relax.subtract",
            "relax.multiply",
            "relax.divide",
            "relax.maximum",
            "relax.minimum",
        ]

        return [make_pattern(op) for op in binary_ops]

    def unary_op_pattern():
        """Create a unary op pattern."""

        def make_pattern(op):
            lhs = wildcard()
            pattern = is_op(op)(lhs)
            annotations = {"lhs": lhs}
            return ("openclml." + op, pattern, annotations, _check_binary_op)

        unary_ops = [
            "relax.nn.softmax",
            "relax.nn.relu",
            "relax.clip",
        ]

        return [make_pattern(op) for op in unary_ops]

    return [
        *conv_pattern(),
        *batch_norm_pattern(),
        *binary_op_pattern(),
        *unary_op_pattern(),
        *maxpool_pattern(),
        *avgpool_pattern(),
        *global_avgpool_pattern(),
        *reshape_pattern(),
    ]
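
A minimal sketch of how a pattern table like this is typically consumed, assuming a Relax IRModule `mod` and a TVM build with CLML enabled (the surrounding pipeline may differ):

patterns = clml_pattern_table()
mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
mod = relax.transform.RunCodegen()(mod)

Each matched region becomes a composite function named by the first tuple element (e.g. "openclml.nn.conv2d"), and RunCodegen hands the annotated functions to the CLML backend.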