# tinynn/prune/oneshot_pruner.py
import copy
import typing
import torch
import torch.nn as nn
import torch.distributed as dist
from tinynn.graph import modifier
from tinynn.graph.tracer import ignore_mod_param_update_warning
from tinynn.util.util import get_logger
from tinynn.prune.base_pruner import BasePruner
from tinynn.graph.modifier import is_dw_conv
log = get_logger(__name__)
class OneShotChannelPruner(BasePruner):
required_params = ('sparsity', 'metrics')
metrics: str
sparsity: typing.Union[typing.Dict[str, float], float]
default_sparsity: float
metric_func: typing.Callable[[torch.Tensor, torch.nn.Module], float]
skip_last_fc: bool
bn_compensation: bool
exclude_ops: list
multiple: int
def __init__(self, model, dummy_input, config):
"""Constructs a new OneShotPruner (including random, l1_norm, l2_norm, fpgm)
Args:
model: The model to be pruned
dummy_input: A viable input to the model
config (dict, str): Configuration of the pruner (could be path to the json file)
Raises:
Exception: If a model without parameters or prunable ops is given, the exception will be thrown
"""
super(OneShotChannelPruner, self).__init__(model, dummy_input, config)
self.center_nodes = []
self.sparsity = {}
self.bn_compensation = False
self.parse_config()
for n in self.graph.forward_nodes:
            # Only prune specific types of operators
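            # (depthwise convolutions are excluded because their channel count is tied to the
            # preceding layer; RNN-family nodes qualify only when they take a single input)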
if (
(
n.type() in [nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d, nn.ConvTranspose1d]
and not is_dw_conv(n.module)
)
or (n.type() in [nn.Linear])
or (n.type() in [nn.RNN, nn.GRU, nn.LSTM] and len(n.prev_tensors) == 1)
):
self.center_nodes.append(n)
if n.unique_name not in self.sparsity:
self.sparsity[n.unique_name] = self.default_sparsity
last_center_node = self.center_nodes[-1] if len(self.center_nodes) > 0 else None
if last_center_node is None:
raise Exception("No operations to prune")
        # Unless a pruning rate has been specified manually, do not prune the last center node when it is a
        # Linear layer (the final fully-connected layer is usually tied to the number of classes)
if self.skip_last_fc and last_center_node.type() in [nn.Linear]:
if self.sparsity[last_center_node.unique_name] == self.default_sparsity:
self.sparsity[last_center_node.unique_name] = 0.0
self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)
        # TODO: To preserve pruning accuracy, temporarily disable all pruning subgraphs that involve constants
self.exclude_ops.append('weight')
for sub_graph in self.graph_modifier.sub_graphs.values():
exclude = False
for m in sub_graph.modifiers:
if m.node.type() in self.exclude_ops:
exclude = True
break
if m.node.type() in (nn.RNN, nn.GRU, nn.LSTM) and len(m.node.prev_tensors) > 1:
exclude = True
break
if exclude:
sub_graph.skip = True
def parse_config(self):
"""Parses the context and copy the needed items to the pruner"""
super().parse_config()
all_param_keys = list(self.required_params) + list(self.default_values.keys())
for param_key in all_param_keys:
            if param_key not in ['sparsity', 'metrics', 'skip_last_fc', 'exclude_ops', 'bn_compensation']:
setattr(self, param_key, self.config[param_key])
sparsity = self.config['sparsity']
metrics = self.config['metrics']
self.skip_last_fc = self.config.get('skip_last_fc', True)
self.multiple = self.config.get('multiple', None)
self.bn_compensation = self.config.get('bn_compensation', self.bn_compensation)
self.exclude_ops = self.config.get('exclude_ops', [])
if isinstance(sparsity, float):
self.default_sparsity = sparsity
elif isinstance(sparsity, dict):
if 'default' in sparsity:
self.default_sparsity = sparsity['default']
del sparsity['default']
self.sparsity.update(sparsity)
else:
raise Exception(f'The type of `sparsity` should either be float or dict for {type(self).__name__}')
if not hasattr(self, 'default_sparsity'):
raise Exception(f'Please specify a default sparsity using the key `default` for {type(self).__name__}')
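        # For example (illustrative values): `sparsity=0.5` prunes every center node by 50%,
        # while `sparsity={'default': 0.5, 'fc_1': 0.0}` leaves the hypothetical node `fc_1`
        # untouched and prunes the rest by 50%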
if hasattr(modifier, metrics):
self.metric_func = getattr(modifier, metrics)
else:
            raise Exception(f'{metrics} is not a known metric for {type(self).__name__}')
def prune(self):
"""The main function for pruning.
As for oneshot pruning, it is simply calculating the masks (`register_mask`) and them apply them (`apply_mask`).
"""
self.register_mask()
with ignore_mod_param_update_warning():
self.apply_mask()
def register_mask(self):
"""Computes the mask for the parameters in the model and register them through the maskers"""
log.info("Register a mask for each operator")
importance = {}
for sub_graph in self.graph_modifier.sub_graphs.values():
if sub_graph.skip:
log.info(f"skip subgraph {sub_graph.center}")
continue
for m in sub_graph.modifiers:
if m.node in self.center_nodes and m in sub_graph.dependent_centers:
modifier.update_weight_metric(importance, self.metric_func, m.module(), m.unique_name())
sub_graph.calc_prune_idx(importance, self.sparsity)
log.info(f"subgraph [{sub_graph.center}] compute over")
for m in self.graph_modifier.modifiers.values():
m.register_mask(self.graph_modifier.modifiers, importance, self.sparsity)
def apply_mask(self):
"""Applies the masks for the parameters and updates the shape and properties of the tensors and modules"""
log.info("Apply the mask of each operator")
for m in self.graph_modifier.modifiers.values():
m.apply_mask(self.graph_modifier.modifiers)
self.graph_modifier.unregister_modifier()
# Sync parameters after the size of the tensors has been shrunk
if dist.is_available() and dist.is_initialized():
for ps in self.model.parameters():
dist.broadcast(ps, 0)
def generate_config(self, path: str) -> None:
"""Generates a new copy the updated configuration with the given path"""
config = copy.deepcopy(self.config)
config['sparsity'] = dict()
config['sparsity']['default'] = self.default_sparsity
config['sparsity'].update(self.sparsity)
super().generate_config(path, config)
def reset(self) -> None:
"""Regenerate the TraceGraph and the Graph Modifier when they are invalidated"""
self.graph = self.trace()
self.center_nodes.clear()
for n in self.graph.forward_nodes:
if (n.type() in [nn.Conv2d, nn.ConvTranspose2d] and not is_dw_conv(n.module)) or (n.type() in [nn.Linear]):
self.center_nodes.append(n)
if n.unique_name not in self.sparsity:
self.sparsity[n.unique_name] = self.default_sparsity
self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)
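
# Minimal usage sketch (assumed, not part of this file; `model` is user-provided and the
# input shape and config values are illustrative):
#   pruner = OneShotChannelPruner(model, torch.rand((1, 3, 224, 224)), {'sparsity': 0.25, 'metrics': 'l2_norm'})
#   pruner.prune()                         # compute and apply the channel masks
#   pruner.generate_config('pruned.json')  # persist the per-node sparsity for later runs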