in python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py [0:0]
def get_patterns(target) -> List[Pattern]:
    """Collect every TensorRT fusion pattern registered for ``target``.

    Each entry is a tuple ``(name, *pattern_parts, check_func, attrs_getter)``.
    ``name`` is ``target + "." + <op suffix>``. ``pattern_parts`` is whatever
    the pattern builder returned, spliced in via unpacking.

    Parameters
    ----------
    target: str
        The target name, used as the prefix of every pattern name.

    Returns
    -------
    patterns: list<Pattern>
        The patterns, in registration order.
        NOTE(review): the order is presumably significant to whoever consumes
        this list (earlier patterns tried first); confirm before reordering.
    """

    def _entry(suffix, parts, check, attrs_getter):
        # Build one pattern record. ``parts`` is unpacked so that the builder's
        # result lands between the pattern name and its check function.
        return (target + "." + suffix, *parts, check, attrs_getter)

    def _out_getter(*input_keys):
        # Attribute getter anchored at the "out" annotation that reports the
        # given annotation keys as the op's inputs.
        return partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=list(input_keys))

    # Ops matched with ``basic_pattern``; each value lists the kind of every argument.
    basic_op_args = {
        "nn.adaptive_avg_pool2d": ["input"],
        "nn.avg_pool2d": ["input"],
        "nn.conv2d": ["input", "constant"],
        "nn.max_pool2d": ["input"],
        "astype": ["input"],
        "concat": ["input"],
        "clip": ["input", "input", "input"],
        "image.resize2d": ["input", "input"],
        "matmul": ["input", "input"],
        "permute_dims": ["input"],
        "strided_slice": ["input", "input", "input", "input", "input"],
        "topk": ["input"],
    }
    # Single-input ops that all share the same pattern shape and check.
    activation_names = ("nn.relu", "nn.softmax", "sigmoid", "tanh")
    reduce_names = ("max", "min", "mean", "sum")
    unary_names = ("cos", "erf", "exp", "negative", "round", "sin", "square", "sqrt", "tan")
    # Two-input ops built with ``elemwise_pattern``.
    elemwise_names = (
        "add",
        "divide",
        "floor_divide",
        "maximum",
        "minimum",
        "multiply",
        "power",
        "subtract",
    )
    compare_names = ("greater", "less")

    result = []

    # Basic ops: one reported input per pattern argument ("input_0", "input_1", ...).
    for op_name, arg_kinds in basic_op_args.items():
        input_keys = [f"input_{idx}" for idx, _ in enumerate(arg_kinds)]
        result.append(
            _entry(
                op_name,
                basic_pattern("relax." + op_name, arg_kinds),
                _basic_check,
                _out_getter(*input_keys),
            )
        )

    # Activation, reduce and unary ops, in that order.
    for op_name in (*activation_names, *reduce_names, *unary_names):
        result.append(
            _entry(
                op_name,
                basic_pattern("relax." + op_name, ["input"]),
                _basic_check,
                _out_getter("input_0"),
            )
        )

    # Elementwise then comparison ops: same pattern shape, different checks.
    for op_group, group_check in ((elemwise_names, _elemwise_check), (compare_names, _compare_check)):
        for op_name in op_group:
            result.append(
                _entry(
                    op_name,
                    elemwise_pattern("relax." + op_name),
                    group_check,
                    _out_getter("input_0", "input_1"),
                )
            )

    # Special ops with their own checks.
    result.extend(
        [
            _entry(
                "take",
                basic_pattern("relax.take", ["input", "input"]),
                _take_check,
                _out_getter("input_0", "input_1"),
            ),
            # NOTE(review): "input" (not "input_0") presumably matches the
            # annotation key produced by argmaxmin_pattern; confirm there.
            _entry(
                "argmax",
                argmaxmin_pattern("relax.argmax"),
                _argmaxmin_check,
                _out_getter("input"),
            ),
            _entry(
                "argmin",
                argmaxmin_pattern("relax.argmin"),
                _argmaxmin_check,
                _out_getter("input"),
            ),
            # Only the first argument is reported as an input; the second
            # pattern argument is not listed.
            _entry(
                "reshape",
                basic_pattern("relax.reshape", ["input", "input"]),
                _reshape_check,
                _out_getter("input_0"),
            ),
        ]
    )

    # Fused conv2d + bias, anchored on the "conv" annotation.
    result.append(
        _entry(
            "msc.conv2d_bias",
            msc_pattern.make_opt_relax_conv_bias_pattern("relax.nn.conv2d"),
            wrap_basic_check(msc_pattern._check_opt_relax_conv_bias),
            partial(msc_pattern.msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]),
        )
    )

    # Plugin calls via relax.call_dps_packed, using the plugin-specific getter.
    result.append(
        _entry(
            "plugin",
            basic_pattern("relax.call_dps_packed", ["input", "input"]),
            _plugin_check,
            plugin_attrs_getter,
        )
    )
    return result