tinynn/prune/identity_pruner.py (25 lines of code) (raw):

import copy
import json

from tinynn.prune import OneShotChannelPruner
from tinynn.util.util import get_logger

log = get_logger(__name__)


class IdentityChannelPruner(OneShotChannelPruner):
    """One-shot channel pruner that registers masks without computing importance scores.

    The configured ``sparsity`` is always forced to ``1.0``. ``metrics`` defaults to
    ``"l2_norm"`` when absent. In ``register_mask``, ``None`` is passed as the importance
    argument to each subgraph's ``calc_prune_idx`` and each modifier's ``register_mask``.
    """

    required_params = ('sparsity', 'metrics')

    bn_compensation: bool
    exclude_ops: list

    def __init__(self, model, dummy_input, config=None):
        """Constructs a new IdentityChannelPruner

        Args:
            model: The model to be pruned
            dummy_input: A viable input to the model
            config (dict, str, optional): Configuration of the pruner (could be path to the json file).
                When omitted, ``{"metrics": "l2_norm", "sparsity": 1.0}`` is used. A dict passed in
                by the caller is copied before defaults are applied, so it is not modified.

        Raises:
            Exception: If a model without parameters or prunable ops is given, the exception will be thrown
        """
        # NOTE(review): assigned before the base constructor runs; if the base class re-assigns
        # ``bn_compensation`` during construction this value may be overwritten -- confirm.
        self.bn_compensation = True
        config = self._prepare_config(config)
        super(IdentityChannelPruner, self).__init__(model, dummy_input, config)

    @staticmethod
    def _prepare_config(config):
        """Return a private copy of ``config`` with the identity-pruner defaults applied.

        Args:
            config (dict, str, None): A configuration mapping, a path to a JSON file, or ``None``.

        Returns:
            A new mapping in which ``metrics`` defaults to ``"l2_norm"`` and ``sparsity`` is ``1.0``.
        """
        if config is None:
            config = {}
        elif isinstance(config, str):
            # A path string cannot be indexed like a dict; load the JSON file it points to.
            with open(config, encoding="utf-8") as f:
                config = json.load(f)
        else:
            # Shallow copy (same mapping type) so the caller's object is left untouched.
            config = copy.copy(config)

        config.setdefault("metrics", "l2_norm")

        # This pruner always runs with sparsity 1.0; tell the user if their value is replaced.
        if "sparsity" in config and config["sparsity"] != 1.0:
            log.warning(f"IdentityChannelPruner ignores the configured sparsity {config['sparsity']!r}; using 1.0")
        config["sparsity"] = 1.0
        return config

    def register_mask(self):
        """Computes the mask for the parameters in the model and register them through the maskers

        No importance scores are computed here: ``None`` is passed as the importance argument
        to every subgraph's ``calc_prune_idx`` and every modifier's ``register_mask``.
        Subgraphs whose ``skip`` flag is set are excluded from the index computation.
        """
        log.info("Register a mask for each operator")

        for sub_graph in self.graph_modifier.sub_graphs.values():
            if sub_graph.skip:
                log.info(f"skip subgraph {sub_graph.center}")
                continue

            sub_graph.calc_prune_idx(None, self.sparsity, multiple=self.multiple)
            log.info(f"subgraph [{sub_graph.center}] compute over")

        # No skip filtering here: every entry in graph_modifier.modifiers is visited.
        for m in self.graph_modifier.modifiers.values():
            m.register_mask(self.graph_modifier.modifiers, None, self.sparsity)