in python/tvm/contrib/msc/core/tools/tool.py [0:0]
def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]:
    """Build ToolStrategy objects, keyed by mark, from a list of strategy configs.

    Each config must carry a ``"methods"`` mapping from a type name to either a
    method name string or a dict holding ``"method_name"`` plus extra keyword
    arguments for that method. The targets of a config come from the first of
    these keys that is present: ``"marks"``, ``"tensor_names"``, ``"tensor_ids"``,
    ``"op_types"``, ``"op_names"``. When none is present, a single
    ``"default.<type>"`` mark is used. Tensor names/ids and op types/names that
    are not found in the graphs are dropped. One executor is added for every
    (mark, stage) pair, with stages read from ``"stages"`` (default
    ``["default"]``).

    Parameters
    ----------
    strategy_list: list<dict>
        The strategy configs to parse.

    Returns
    -------
    strategys: dict<str, ToolStrategy>
        The parsed strategies, keyed by mark. A mark produced by several configs
        or methods shares one ToolStrategy that accumulates all their executors.

    Raises
    ------
    AssertionError
        If the input is not a list of dicts, no graphs are loaded, a dict method
        definition lacks ``"method_name"``, a method cannot be found, or a
        ``marks``/``tensor_*`` config is paired with an incompatible type.
    TypeError
        If a method definition is neither a string nor a dict.
    """
    assert isinstance(strategy_list, list) and all(
        isinstance(s, dict) for s in strategy_list
    ), "ToolStrategy should be given as list of dict"
    assert self._graphs, "graphs are needed to parse strategys"

    # Lookup sets used to discard configured names that do not exist in the graphs.
    known_tensor_names = {tensor.name for tensor in self.get_tensors()}
    known_tensor_ids = set(self.get_tensor_ids())
    known_op_types = {node.optype for node in self.get_nodes()}
    known_op_names = {node.name for node in self.get_nodes()}

    def _lookup_method(full_name):
        # "<cls>.<method>" picks a registered method class; a bare name uses "default".
        if "." in full_name:
            cls_name, short_name = full_name.split(".")
        else:
            cls_name, short_name = "default", full_name
        # Prefer this tool's framework, then fall back to the MSC framework.
        for framework in (self.framework(), MSCFramework.MSC):
            holder = msc_utils.get_registered_tool_method(
                framework, self.tool_type(), cls_name
            )
            if hasattr(holder, short_name):
                return getattr(holder, short_name)
        # Last resort: a globally registered function with the bare method name.
        found = msc_utils.get_registered_func(short_name)
        assert found, "Can not find method with " + str(short_name)
        return found

    def _split_method_def(method_def):
        # A method is either a bare name or {"method_name": ..., **method_kwargs}.
        if isinstance(method_def, str):
            return method_def, {}
        if isinstance(method_def, dict):
            assert "method_name" in method_def, "Can not find method_name"
            kwargs = {key: val for key, val in method_def.items() if key != "method_name"}
            return method_def["method_name"], kwargs
        raise TypeError(
            "Only support string and dict as method define, get " + str(method_def)
        )

    def _resolve_marks(config, target_type, meta):
        # Only the first selector key found in the config is honoured.
        if "marks" in config:
            assert target_type == "mark", "mark strategy only support mark method, get " + str(
                meta
            )
            return config["marks"]
        if "tensor_names" in config:
            assert (
                target_type == "tensor"
            ), "tensor strategy only support tensor method, get " + str(meta)
            return [name for name in config["tensor_names"] if name in known_tensor_names]
        if "tensor_ids" in config:
            assert (
                target_type == "tensor"
            ), "tensor strategy only support tensor method, get " + str(meta)
            return [t_id for t_id in config["tensor_ids"] if t_id in known_tensor_ids]
        if "op_types" in config:
            return [
                f"{op_type}.{target_type}"
                for op_type in config["op_types"]
                if op_type in known_op_types
            ]
        if "op_names" in config:
            return [
                f"{op_name}.{target_type}"
                for op_name in config["op_names"]
                if op_name in known_op_names
            ]
        return ["default." + str(target_type)]

    parsed = {}
    for config in strategy_list:
        meta = msc_utils.copy_dict(config)
        for target_type, method_def in meta["methods"].items():
            method_name, method_kwargs = _split_method_def(method_def)
            method = _lookup_method(method_name)
            marks = _resolve_marks(config, target_type, meta)
            for mark, stage in product(marks, config.get("stages", ["default"])):
                if mark not in parsed:
                    parsed[mark] = ToolStrategy(mark, target_type, self._stage)
                # Each executor gets its own kwargs copy so executors never share state.
                executor = ToolExecutor(method_name, method, copy.deepcopy(method_kwargs))
                parsed[mark].add_executor(stage, executor)
    return parsed