def __init__()

in tinynn/prune/oneshot_pruner.py [0:0]


    def __init__(self, model, dummy_input, config):
        """Constructs a new OneShotChannelPruner (including random, l1_norm, l2_norm, fpgm)

        Collects the "center" nodes that drive channel pruning (non-depthwise Conv1d/Conv2d/
        ConvTranspose1d/ConvTranspose2d, Linear, and RNN/GRU/LSTM with exactly one input tensor),
        assigns each of them a sparsity, builds the channel graph modifier and marks the
        sub-graphs that must not be pruned as skipped.

        Args:
            model: The model to be pruned
            dummy_input: A viable input to the model
            config (dict, str): Configuration of the pruner (could be path to the json file)

        Raises:
            Exception: If a model without parameters or prunable ops is given, the exception will be thrown
        """

        super(OneShotChannelPruner, self).__init__(model, dummy_input, config)

        self.center_nodes = []
        self.sparsity = {}
        self.bn_compensation = False
        # NOTE(review): presumably populates sparsity / default_sparsity / skip_last_fc / exclude_ops
        # from the config -- they are read below but never assigned in this method.
        self.parse_config()

        conv_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d, nn.ConvTranspose1d)
        rnn_types = (nn.RNN, nn.GRU, nn.LSTM)

        for n in self.graph.forward_nodes:
            node_type = n.type()
            # Only prune the specific operators: depthwise convs are left out, and recurrent layers
            # qualify only when they consume a single input tensor.
            is_center = (
                (node_type in conv_types and not is_dw_conv(n.module))
                or node_type in (nn.Linear,)
                or (node_type in rnn_types and len(n.prev_tensors) == 1)
            )
            if is_center:
                self.center_nodes.append(n)
                # Keep a sparsity the config already gave this node; otherwise use the default.
                self.sparsity.setdefault(n.unique_name, self.default_sparsity)

        if not self.center_nodes:
            raise Exception("No operations to prune")
        last_center_node = self.center_nodes[-1]

        # Unless a sparsity was specified manually, do not prune the last center node when it is a
        # Linear layer (the final fully-connected layer of a network usually depends on the number
        # of classes).
        # NOTE(review): "specified manually" is approximated by "differs from default_sparsity", so an
        # explicit value equal to the default is still reset to 0.0 here.
        if self.skip_last_fc and last_center_node.type() in (nn.Linear,):
            if self.sparsity[last_center_node.unique_name] == self.default_sparsity:
                self.sparsity[last_center_node.unique_name] = 0.0

        self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)

        # TODO: to preserve pruning accuracy, temporarily skip every pruning sub-graph that involves a
        # constant (nodes whose type() is 'weight').
        # Guarded so that constructing several pruners does not keep appending duplicates when
        # exclude_ops is a list shared with the config.
        if 'weight' not in self.exclude_ops:
            self.exclude_ops.append('weight')

        for sub_graph in self.graph_modifier.sub_graphs.values():
            # A sub-graph is skipped when any participating node is an excluded op type, or is a
            # recurrent layer fed by more than one tensor (those are never center nodes above).
            if any(
                m.node.type() in self.exclude_ops
                or (m.node.type() in rnn_types and len(m.node.prev_tensors) > 1)
                for m in sub_graph.modifiers
            ):
                sub_graph.skip = True