def _parse_strategys()

in python/tvm/contrib/msc/core/tools/tool.py [0:0]


    def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]:
        """Parse the strategy to get valid strategy

        Each strategy dict must contain ``methods``, a mapping from tool type
        (e.g. ``"tensor"``, ``"mark"``) to a method define. A method define is
        either a string ``"method"`` / ``"cls.method"`` or a dict holding
        ``method_name`` plus extra keyword arguments for the method.

        The marks a strategy is attached to are chosen by the first of these
        keys present in the strategy dict:

        - ``marks``: used as given (only ``mark`` methods are allowed).
        - ``tensor_names`` / ``tensor_ids``: the entries found in the graphs
          (only ``tensor`` methods are allowed).
        - ``op_types`` / ``op_names``: the entries found in the graphs, turned
          into ``"<op>.<tool_type>"`` marks.
        - otherwise the single mark ``"default.<tool_type>"``.

        ``stages`` (default ``["default"]``) lists the stages an executor is
        added for on every selected mark.

        Parameters
        ----------
        strategy_list: list<dict>
            The given strategys.

        Returns
        -------
        strategys: dict<str, ToolStrategy>
            The parsed strategy, keyed by mark.

        Raises
        ------
        AssertionError
            If the inputs are malformed, no graphs are loaded, or a method
            can not be found.
        TypeError
            If a method define is neither a string nor a dict.
        """

        assert isinstance(strategy_list, list) and all(
            isinstance(s, dict) for s in strategy_list
        ), "ToolStrategy should be given as list of dict"
        assert self._graphs, "graphs are needed to parse strategys"
        all_tensor_names = {t.name for t in self.get_tensors()}
        all_tensor_ids = set(self.get_tensor_ids())
        # Materialize the nodes once so both lookups share a single traversal
        # (and stay correct even if get_nodes returns a one-shot iterator).
        nodes = list(self.get_nodes())
        all_op_types = {n.optype for n in nodes}
        all_op_names = {n.name for n in nodes}
        strategys = {}

        def _get_method(method_name):
            """Resolve ``"method"`` or ``"cls.method"`` to a callable.

            Lookup order: the tool-method class registered for this framework,
            then the one registered for MSCFramework.MSC, then a registered func.
            """
            if "." in method_name:
                name_parts = method_name.split(".")
                # Guard the unpack below: "a.b.c" would otherwise fail with an
                # opaque "too many values to unpack" error.
                assert len(name_parts) == 2, (
                    "method should be given as name or cls.name, get " + str(method_name)
                )
                method_cls_name, method_name = name_parts
            else:
                method_cls_name = "default"
            method_cls = msc_utils.get_registered_tool_method(
                self.framework(), self.tool_type(), method_cls_name
            )
            if hasattr(method_cls, method_name):
                return getattr(method_cls, method_name)
            default_cls = msc_utils.get_registered_tool_method(
                MSCFramework.MSC, self.tool_type(), method_cls_name
            )
            if hasattr(default_cls, method_name):
                return getattr(default_cls, method_name)
            method = msc_utils.get_registered_func(method_name)
            assert method, "Can not find method with " + str(method_name)
            return method

        def _parse_method_def(method_def):
            """Split a method define into (method_name, method_kwargs)."""
            if isinstance(method_def, str):
                return method_def, {}
            if isinstance(method_def, dict):
                assert "method_name" in method_def, "Can not find method_name"
                method_kwargs = {k: v for k, v in method_def.items() if k != "method_name"}
                return method_def["method_name"], method_kwargs
            raise TypeError(
                "Only support string and dict as method define, get " + str(method_def)
            )

        for strategy in strategy_list:
            assert "methods" in strategy, "methods should be given in strategy, get " + str(
                strategy
            )
            meta_strategy = msc_utils.copy_dict(strategy)
            stages = strategy.get("stages", ["default"])
            for t_type, method_def in meta_strategy["methods"].items():
                method_name, method_kwargs = _parse_method_def(method_def)
                method = _get_method(method_name)
                # The first selector key present decides which marks this
                # method is attached to.
                if "marks" in strategy:
                    assert t_type == "mark", "mark strategy only support mark method, get " + str(
                        meta_strategy
                    )
                    marks = strategy["marks"]
                elif "tensor_names" in strategy:
                    assert (
                        t_type == "tensor"
                    ), "tensor strategy only support tensor method, get " + str(meta_strategy)
                    marks = [t for t in strategy["tensor_names"] if t in all_tensor_names]
                elif "tensor_ids" in strategy:
                    assert (
                        t_type == "tensor"
                    ), "tensor strategy only support tensor method, get " + str(meta_strategy)
                    marks = [t for t in strategy["tensor_ids"] if t in all_tensor_ids]
                elif "op_types" in strategy:
                    op_types = [t for t in strategy["op_types"] if t in all_op_types]
                    marks = ["{}.{}".format(t, t_type) for t in op_types]
                elif "op_names" in strategy:
                    op_names = [t for t in strategy["op_names"] if t in all_op_names]
                    marks = ["{}.{}".format(t, t_type) for t in op_names]
                else:
                    marks = ["default." + str(t_type)]
                for mark, stage in product(marks, stages):
                    if mark not in strategys:
                        strategys[mark] = ToolStrategy(mark, t_type, self._stage)
                    # Deep-copy kwargs so executors never share mutable state.
                    strategys[mark].add_executor(
                        stage, ToolExecutor(method_name, method, copy.deepcopy(method_kwargs))
                    )
        return strategys