in tinynn/prune/oneshot_pruner.py [0:0]
def __init__(self, model, dummy_input, config):
    """Constructs a new OneShotPruner (including random, l1_norm, l2_norm, fpgm)

    Args:
        model: The model to be pruned
        dummy_input: A viable input to the model
        config (dict, str): Configuration of the pruner (could be a dict or the path to a JSON file)

    Raises:
        Exception: If a model without parameters or prunable ops is given, the exception will be thrown
    """
    super(OneShotChannelPruner, self).__init__(model, dummy_input, config)

    self.center_nodes = []  # prunable operators collected from the traced graph
    self.sparsity = {}  # per-operator sparsity, keyed by the node's unique name
    self.bn_compensation = False
    self.parse_config()

    for n in self.graph.forward_nodes:
        # Only prune the specific operators: non-depthwise convolutions, linear layers,
        # and single-input recurrent layers
        if (
            (
                n.type() in [nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d, nn.ConvTranspose1d]
                and not is_dw_conv(n.module)
            )
            or (n.type() in [nn.Linear])
            or (n.type() in [nn.RNN, nn.GRU, nn.LSTM] and len(n.prev_tensors) == 1)
        ):
            self.center_nodes.append(n)
            if n.unique_name not in self.sparsity:
                self.sparsity[n.unique_name] = self.default_sparsity

    last_center_node = self.center_nodes[-1] if len(self.center_nodes) > 0 else None
    if last_center_node is None:
        raise Exception("No operations to prune")

    # Unless a pruning rate was specified manually, do not prune the last center node when it is a
    # linear layer (the final fully-connected layer is usually tied to the number of classes)
    if self.skip_last_fc and last_center_node.type() in [nn.Linear]:
        if self.sparsity[last_center_node.unique_name] == self.default_sparsity:
            self.sparsity[last_center_node.unique_name] = 0.0

    self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)

    # TODO: to preserve pruning accuracy, temporarily exclude all subgraphs that involve constants
    self.exclude_ops.append('weight')

    for sub_graph in self.graph_modifier.sub_graphs.values():
        exclude = False

        for m in sub_graph.modifiers:
            if m.node.type() in self.exclude_ops:
                exclude = True
                break

            if m.node.type() in (nn.RNN, nn.GRU, nn.LSTM) and len(m.node.prev_tensors) > 1:
                exclude = True
                break

        if exclude:
            sub_graph.skip = True
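
For reference, a minimal usage sketch follows. It is an illustration under stated assumptions, not part of the file above: the import path and constructor signature come from this file, while the "sparsity"/"metrics" config keys and the prune() call are assumed from the surrounding TinyNeuralNetwork API and may need adjusting.

# Hypothetical usage sketch: the config keys and the prune() call are assumed,
# not taken from this file; adjust them to the actual TinyNeuralNetwork API.
import torch
import torchvision

from tinynn.prune.oneshot_pruner import OneShotChannelPruner

model = torchvision.models.mobilenet_v2()
dummy_input = torch.randn(1, 3, 224, 224)

# The constructor also accepts a path to a JSON file with the same content.
config = {"sparsity": 0.5, "metrics": "l2_norm"}  # assumed keys: global sparsity and ranking metric

pruner = OneShotChannelPruner(model, dummy_input, config)
pruner.prune()  # assumed entry point that computes the masks and rewrites the graph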