def get_patterns()

in python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py [0:0]


def get_patterns(target) -> List[Pattern]:
    """Build the complete list of TensorRT offload patterns for ``target``.

    Every entry is named ``<target>.<op>``. It is a tuple holding that name,
    then the items produced by a pattern builder, then a check function, then
    an attribute getter. The pattern builders are ``basic_pattern``,
    ``elemwise_pattern``, ``argmaxmin_pattern`` or the msc conv-bias builder.
    Entries are emitted in this fixed order:

    1. basic ops;
    2. single-input ops (activation, reduce, unary);
    3. binary ops (elementwise, compare);
    4. special ops (take, argmax, argmin, reshape);
    5. the fused conv2d+bias op;
    6. the plugin pattern.

    Parameters
    ----------
    target: str
        The target name for tensorrt patterns; used as the prefix of every
        pattern name.

    Returns
    -------
    patterns: list<Pattern>
        The patterns, in the order described above.
    """

    def _getter(inputs, anchor="out"):
        # msc attribute getter bound to an anchor node and named input slots.
        return partial(msc_pattern.msc_attrs_getter, anchor=anchor, inputs=inputs)

    def _entry(op, pattern_items, check, attrs_getter):
        # One pattern entry: prefixed name, unpacked builder items, check, getter.
        return (target + "." + op, *pattern_items, check, attrs_getter)

    # op name -> kind of each positional input ("input" or "constant")
    basic_ops = {
        "nn.adaptive_avg_pool2d": ["input"],
        "nn.avg_pool2d": ["input"],
        "nn.conv2d": ["input", "constant"],
        "nn.max_pool2d": ["input"],
        "astype": ["input"],
        "concat": ["input"],
        "clip": ["input", "input", "input"],
        "image.resize2d": ["input", "input"],
        "matmul": ["input", "input"],
        "permute_dims": ["input"],
        "strided_slice": ["input", "input", "input", "input", "input"],
        "topk": ["input"],
    }
    activation_ops = ("nn.relu", "nn.softmax", "sigmoid", "tanh")
    reduce_ops = ("max", "min", "mean", "sum")
    unary_ops = ("cos", "erf", "exp", "negative", "round", "sin", "square", "sqrt", "tan")
    elemwise_ops = (
        "add",
        "divide",
        "floor_divide",
        "maximum",
        "minimum",
        "multiply",
        "power",
        "subtract",
    )
    compare_ops = ("greater", "less")

    # basic ops: one named input slot ("input_<idx>") per declared input kind
    patterns = [
        _entry(
            op,
            basic_pattern("relax." + op, in_types),
            _basic_check,
            _getter([f"input_{idx}" for idx, _ in enumerate(in_types)]),
        )
        for op, in_types in basic_ops.items()
    ]

    # activation, reduce and unary ops all share the single-input pattern
    patterns += [
        _entry(op, basic_pattern("relax." + op, ["input"]), _basic_check, _getter(["input_0"]))
        for op in activation_ops + reduce_ops + unary_ops
    ]

    # elementwise and compare ops share the binary pattern; only the check differs
    binary_groups = ((elemwise_ops, _elemwise_check), (compare_ops, _compare_check))
    for binary_ops, binary_check in binary_groups:
        patterns += [
            _entry(
                op,
                elemwise_pattern("relax." + op),
                binary_check,
                _getter(["input_0", "input_1"]),
            )
            for op in binary_ops
        ]

    # special ops, each with its own builder and check
    patterns.append(
        _entry(
            "take",
            basic_pattern("relax.take", ["input", "input"]),
            _take_check,
            _getter(["input_0", "input_1"]),
        )
    )
    for arg_op in ("argmax", "argmin"):
        patterns.append(
            _entry(
                arg_op,
                argmaxmin_pattern("relax." + arg_op),
                _argmaxmin_check,
                _getter(["input"]),
            )
        )
    patterns.append(
        _entry(
            "reshape",
            basic_pattern("relax.reshape", ["input", "input"]),
            _reshape_check,
            _getter(["input_0"]),
        )
    )

    # fusable ops: conv2d followed by a bias, anchored on the conv node
    patterns.append(
        _entry(
            "msc.conv2d_bias",
            msc_pattern.make_opt_relax_conv_bias_pattern("relax.nn.conv2d"),
            wrap_basic_check(msc_pattern._check_opt_relax_conv_bias),
            _getter(["data", "weight", "bias"], anchor="conv"),
        )
    )

    # plugin ops: packed calls, using the plugin-specific attribute getter
    patterns.append(
        _entry(
            "plugin",
            basic_pattern("relax.call_dps_packed", ["input", "input"]),
            _plugin_check,
            plugin_attrs_getter,
        )
    )

    return patterns