tinynn/prune/base_pruner.py (178 lines of code) (raw):

import collections
import typing

from abc import ABC, abstractmethod
from inspect import getsource
from pprint import pformat

import torch
import torch.distributed as dist
import torch.nn as nn

from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel

from tinynn.graph.modifier import is_dw_conv
from tinynn.graph.tracer import TraceGraph, trace
from tinynn.util.train_util import DLContext, get_module_device
from tinynn.util.util import conditional, get_actual_type, get_logger

try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

# True when the source of `yaml.load` mentions "error_deprecation"; in that case the
# `yaml.YAML(typ='rt')` object API is used below instead of the module-level load/dump helpers.
NEW_YAML_FLAG = "error_deprecation" in getsource(yaml.load)

log = get_logger(__name__)


class BasePruner(ABC):
    """Abstract base class for pruners.

    Subclasses describe their configuration through the class-level tables below and
    implement `prune`, `register_mask` and `apply_mask`.
    """

    # Keys that must be present in `config` (checked in `parse_config`).
    required_params = ()
    # Attribute names that must be set to a truthy value on the context (checked in `parse_context`).
    required_context_params = ()
    # `{param_key: default}` entries copied into `config` when the key is absent.
    default_values = {}
    # `{context_attr: [param_key, ...]}`; the first param key found in `config` is copied to the context.
    context_from_params_dict = {}
    # `{param_key: predicate}`; `parse_config` raises when a predicate returns exactly `False`.
    condition_dict = {}

    model: nn.Module
    dummy_input: torch.Tensor
    config: typing.Union[dict, collections.OrderedDict]
    graph: TraceGraph
    context: DLContext

    def __init__(self, model, dummy_input, config):
        """Stores the arguments and traces the model to build the computation graph.

        Args:
            model: The model to be pruned
            dummy_input: The input passed to the tracer
            config: A dict-like configuration, or a path accepted by `load_config`
                (strings are resolved in `parse_config`)
        """
        self.model = model
        self.config = config
        self.dummy_input = dummy_input

        self.graph = self.trace()

    @abstractmethod
    def prune(self):
        """The main function for pruning"""
        pass

    @abstractmethod
    def register_mask(self):
        """Computes the mask for the parameters in the model and register them through the maskers"""
        pass

    @abstractmethod
    def apply_mask(self):
        """Applies the masks for the parameters and updates the shape and properties of the tensors and modules"""
        pass

    def parse_config(self):
        """Parses the config and init the parameters of the pruner

        Steps: load the config from a file when a path is given, check that all
        `required_params` are present, fill in `default_values`, type-check the values
        against the class annotations of the subclasses, then evaluate `condition_dict`.

        Raises:
            Exception: When the config is not a dict, a required param is missing, a value
                has an unexpected type, or a condition predicate returns `False`
        """
        if isinstance(self.config, str):
            self.config = self.load_config(self.config)

        if not isinstance(self.config, dict):
            raise Exception('The `config` argument requires a parsed json object (e.g. dict or OrderedDict)')

        missing_params = set(self.required_params) - set(self.config)
        if len(missing_params) != 0:
            # Sorted so that the error message is deterministic
            missing_params_str = ', '.join(sorted(missing_params))
            raise Exception(f'Missing param {missing_params_str} for {type(self).__name__}')

        for param_key, default_value in self.default_values.items():
            if param_key not in self.config:
                self.config[param_key] = default_value

        # Collect the annotations from the base-most class to the most derived one, so that
        # the annotations of a derived class override those of its parents.
        type_dict = {}
        for cls in reversed(type(self).__mro__):
            if hasattr(cls, '__annotations__') and cls != BasePruner:
                type_dict.update(cls.__annotations__)

        for param_key, param_type in type_dict.items():
            if param_key not in self.config:
                continue
            param_types = get_actual_type(param_type)
            if not any(isinstance(self.config[param_key], type_cand) for type_cand in param_types):
                raise Exception(f'The type of `{param_key}` in {type(self).__name__} should be {param_type}')

        for param_key, predicate in self.condition_dict.items():
            # Optional params without a default may be absent; skip them like the type check
            # above does, instead of failing with a bare KeyError.
            if param_key not in self.config:
                continue
            if predicate(self.config[param_key]) is False:
                raise Exception(f'The value of `{param_key}` doesn\'t meet the requirement: {getsource(predicate)}')

    def parse_context(self, context: DLContext):
        """Parses the context and copy the needed items to the pruner

        Args:
            context: The context object; it is updated in place with the values taken from
                the config and then stored as `self.context`

        Raises:
            Exception: When an item in `required_context_params` is unset or falsy on the context
        """
        for context_key, param_keys in self.context_from_params_dict.items():
            # Only the first param key that exists in the config is used
            for param_key in param_keys:
                if param_key in self.config:
                    setattr(context, context_key, self.config[param_key])
                    break

        # Items with falsy values are treated as missing
        filtered_context = dict(filter(lambda x: x[1], context.__dict__.items()))
        missing_context_items = sorted(set(self.required_context_params) - set(filtered_context))
        if len(missing_context_items) != 0:
            missing_context_items_str = ', '.join(missing_context_items)
            raise Exception(f'Missing context items {missing_context_items_str} for {type(self).__name__}')

        self.context = context

    def summary(self):
        """Dumps the parameters and possibly the related context items of the pruner"""
        if len(self.required_params) > 0:
            log.info('-' * 80)
            log.info(f'params ({type(self).__name__}):')
            # Only instance attributes named after a required or defaulted param are printed
            for k, v in self.__dict__.items():
                if k in self.required_params or k in self.default_values:
                    log.info(f'{k}: {pformat(v)}')

        if len(self.required_context_params) > 0:
            log.info('-' * 80)
            log.info(f'context ({type(self).__name__}):')
            log.info('\n'.join((f'{k}: {pformat(v)}' for k, v in self.context.__dict__.items())))

    @classmethod
    def load_config(cls, path: str) -> dict:
        """Loads the configuration file and returns it as a dictionary

        Args:
            path: The path of the YAML configuration file

        Returns:
            The object produced by the round-trip YAML loader
        """
        with open(path, 'r') as f:
            if NEW_YAML_FLAG:
                yaml_ = yaml.YAML(typ='rt')
                config = yaml_.load(f)
            else:
                config = yaml.load(f, Loader=yaml.RoundTripLoader)

        return config

    # The predicate holds when torch.distributed is unavailable or uninitialized, or when the
    # current rank is 0. NOTE(review): presumably `conditional` skips the call otherwise so
    # that only one process writes the file - confirm against `tinynn.util.util.conditional`.
    @conditional(lambda: not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0)
    def generate_config(self, path: str, config: dict = None) -> None:
        """Generates a new copy the updated configuration with the given path

        Args:
            path: The path of the YAML file to write
            config: The configuration to dump. Defaults to `self.config` when `None`
        """
        if config is None:
            config = self.config

        with open(path, 'w') as f:
            if NEW_YAML_FLAG:
                yaml_ = yaml.YAML(typ='rt')
                yaml_.default_flow_style = False
                yaml_.dump(config, f)
            else:
                yaml.dump(config, f, default_flow_style=False, Dumper=yaml.RoundTripDumper)

    def trace(self) -> TraceGraph:
        """Traces the model on CPU with `self.dummy_input` and returns the resulting graph

        The model wrapped by `DataParallel` or `DistributedDataParallel` is unwrapped first.
        The model is moved back to its previous device afterwards, even if tracing fails.
        """
        with torch.no_grad():
            if isinstance(self.model, (DataParallel, DistributedDataParallel)):
                model = self.model.module
            else:
                model = self.model

            old_device = get_module_device(model)
            model.cpu()
            try:
                graph = trace(model, self.dummy_input)
            finally:
                # Restore the device even when tracing raises, so that a failed trace
                # doesn't leave the caller's model on CPU
                if old_device is not None:
                    model.to(device=old_device)

            return graph

    def reset(self):
        """Regenerate the TraceGraph when it is invalidated"""
        self.graph = self.trace()

    def calc_flops(self) -> int:
        """Calculate the flops of the given model

        Only the nodes whose module type is exactly one of `nn.Conv2d`, `nn.ConvTranspose2d`,
        `nn.BatchNorm2d`, `nn.AvgPool2d`, `nn.AdaptiveAvgPool2d`, `nn.ReLU` or `nn.Linear`
        are counted. When a module has a `masker` attribute, the channels listed in its
        `in_remove_idx` / `ot_remove_idx` are excluded from the counts.

        Returns:
            The accumulated number of operations over the forward nodes of the graph
        """
        # If graph is invalidated, then we need to regenerate the graph
        if not self.graph.inited:
            self.reset()

        graph: TraceGraph = self.graph
        total_ops = 0
        for node in graph.forward_nodes:
            m = node.module

            remove_in_channel_count = 0
            remove_out_channel_count = 0
            if hasattr(m, 'masker'):
                if m.masker.in_remove_idx is not None:
                    remove_in_channel_count = len(m.masker.in_remove_idx)
                if m.masker.ot_remove_idx is not None:
                    remove_out_channel_count = len(m.masker.ot_remove_idx)

            if type(m) in (nn.Conv2d, nn.ConvTranspose2d):
                # Number of elements in the kernel window (the dims after the two channel
                # dims), computed from the shape alone without allocating a tensor
                kernel_ops = m.weight.size()[2:].numel()
                bias_ops = 1 if m.bias is not None else 0
                in_channels = m.in_channels - remove_in_channel_count
                out_channels = m.out_channels - remove_out_channel_count
                # Rescale the element count of the output to the remaining output channels
                out_elements = node.next_tensors[0].nelement() // m.out_channels * out_channels
                if is_dw_conv(m):
                    groups = in_channels
                else:
                    groups = m.groups
                total_ops += out_elements * (in_channels // groups * kernel_ops + bias_ops)
            elif type(m) in (nn.BatchNorm2d,):
                channels = m.num_features - remove_in_channel_count
                nelements = node.prev_tensors[0].numel() // m.num_features * channels
                total_ops += 2 * nelements
            elif type(m) in (nn.AvgPool2d, nn.AdaptiveAvgPool2d):
                channels = node.prev_tensors[0].size(1) - remove_in_channel_count
                kernel_ops = 1
                num_elements = node.prev_tensors[0].numel() // node.prev_tensors[0].size(1) * channels
                total_ops += kernel_ops * num_elements
            elif type(m) in (nn.ReLU,):
                channels = node.prev_tensors[0].size(1) - remove_in_channel_count
                kernel_ops = 1
                num_elements = node.prev_tensors[0].numel() // node.prev_tensors[0].size(1) * channels
                total_ops += kernel_ops * num_elements
            elif type(m) in (nn.Linear,):
                in_channels = m.in_features - remove_in_channel_count
                out_channels = m.out_features - remove_out_channel_count
                total_mul = in_channels
                num_elements = node.next_tensors[0].numel() // m.out_features * out_channels
                total_ops += total_mul * num_elements

        return total_ops