# tinynn/prune/base_pruner.py
import collections
import typing
from abc import ABC, abstractmethod
from inspect import getsource
from pprint import pformat
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from tinynn.graph.modifier import is_dw_conv
from tinynn.graph.tracer import TraceGraph, trace
from tinynn.util.train_util import DLContext, get_module_device
from tinynn.util.util import conditional, get_actual_type, get_logger
try:
import ruamel_yaml as yaml
except ModuleNotFoundError:
import ruamel.yaml as yaml
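# Newer ruamel.yaml releases deprecate the module-level `yaml.load`/`yaml.dump`
# in favour of `yaml.YAML()` instances; their `load` source mentions
# `error_deprecation`, which the flag below uses to pick the right API at runtime.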
NEW_YAML_FLAG = "error_deprecation" in getsource(yaml.load)
log = get_logger(__name__)
class BasePruner(ABC):
    # Declarative hooks that subclasses override to drive `parse_config` and
    # `parse_context`:
    required_params = ()  # config keys that must be present
    required_context_params = ()  # DLContext attributes that must be set
    default_values = {}  # fallback values for optional config keys
    context_from_params_dict = {}  # context attribute -> candidate config keys
    condition_dict = {}  # config key -> predicate the value must satisfy
model: nn.Module
dummy_input: torch.Tensor
config: typing.Union[dict, collections.OrderedDict]
graph: TraceGraph
context: DLContext
def __init__(self, model, dummy_input, config):
self.model = model
self.config = config
self.dummy_input = dummy_input
self.graph = self.trace()
@abstractmethod
def prune(self):
"""The main function for pruning"""
pass
@abstractmethod
def register_mask(self):
"""Computes the mask for the parameters in the model and register them through the maskers"""
pass
@abstractmethod
def apply_mask(self):
"""Applies the masks for the parameters and updates the shape and properties of the tensors and modules"""
pass
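    # A concrete pruner is expected to implement `prune` roughly as
    # "compute masks, then apply them" -- e.g. call `register_mask()` followed
    # by `apply_mask()` -- though the exact sequence is left to the subclass.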
def parse_config(self):
"""Parses the config and init the parameters of the pruner"""
if isinstance(self.config, str):
self.config = self.load_config(self.config)
        if not isinstance(self.config, dict):
            raise Exception('The `config` argument expects a parsed mapping (e.g. dict or OrderedDict) or a path to a YAML config file')
missing_params = set(self.required_params) - set(self.config)
if len(missing_params) != 0:
missing_params_str = ', '.join(missing_params)
            raise Exception(f'Missing required param(s) {missing_params_str} for {type(self).__name__}')
for param_key, default_value in self.default_values.items():
if param_key not in self.config:
self.config[param_key] = default_value
        # Collect type annotations from every subclass in the MRO (BasePruner's
        # own annotations are skipped) to type-check the config values below.
        type_dict = {}
        for cls in reversed(type(self).__mro__):
            if hasattr(cls, '__annotations__') and cls != BasePruner:
                type_dict.update(cls.__annotations__)
for param_key, param_type in type_dict.items():
if param_key not in self.config:
continue
type_expected = False
param_types = get_actual_type(param_type)
for type_cand in param_types:
if isinstance(self.config[param_key], type_cand):
type_expected = True
break
if not type_expected:
raise Exception(f'The type of `{param_key}` in {type(self).__name__} should be {param_type}')
        # Validate config values against the predicates declared by the subclass;
        # keys absent from the config are skipped rather than raising a KeyError.
        for param_key, predicate in self.condition_dict.items():
            if param_key in self.config and predicate(self.config[param_key]) is False:
                raise Exception(f'The value of `{param_key}` doesn\'t meet the requirement: {getsource(predicate)}')
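    # A minimal sketch of how a subclass would feed `parse_config` (the class
    # name `L1Pruner` and the keys `sparsity`/`metrics` are illustrative only):
    #
    #     class L1Pruner(BasePruner):
    #         required_params = ('sparsity',)
    #         default_values = {'metrics': 'l1_norm'}
    #         condition_dict = {'sparsity': lambda x: 0.0 < x < 1.0}
    #         sparsity: float
    #         metrics: str
    #
    # `parse_config` then checks that `sparsity` is present, fills in `metrics`,
    # type-checks both against the annotations, and validates the predicate.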
    def parse_context(self, context: DLContext):
        """Parses the context and copies the needed items to the pruner"""
        # Copy selected config values into the context; for each context
        # attribute, the first matching config key wins.
        for context_key, param_keys in self.context_from_params_dict.items():
            for param_key in param_keys:
                if param_key in self.config:
                    setattr(context, context_key, self.config[param_key])
                    break
        # Treat only `None` as "unset" so that legitimate falsy values
        # (e.g. 0) are not reported as missing.
        filtered_context = dict(filter(lambda x: x[1] is not None, context.__dict__.items()))
missing_context_items = list(set(self.required_context_params) - set(filtered_context))
if len(missing_context_items) != 0:
missing_context_items_str = ', '.join(missing_context_items)
raise Exception(f'Missing context items {missing_context_items_str} for {type(self).__name__}')
self.context = context
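    # For example (hypothetical keys), `context_from_params_dict =
    # {'max_epoch': ('num_epoch', 'epochs')}` would copy the first of
    # `num_epoch`/`epochs` found in the config into `context.max_epoch`
    # before the required-items check runs.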
def summary(self):
"""Dumps the parameters and possibly the related context items of the pruner"""
if len(self.required_params) > 0:
log.info('-' * 80)
log.info(f'params ({type(self).__name__}):')
for k, v in self.__dict__.items():
if k in self.required_params or k in self.default_values:
log.info(f'{k}: {pformat(v)}')
if len(self.required_context_params) > 0:
log.info('-' * 80)
log.info(f'context ({type(self).__name__}):')
log.info('\n'.join((f'{k}: {pformat(v)}' for k, v in self.context.__dict__.items())))
@classmethod
def load_config(cls, path: str) -> dict:
"""Loads the configuration file and returns it as a dictionary"""
with open(path, 'r') as f:
if NEW_YAML_FLAG:
yaml_ = yaml.YAML(typ='rt')
config = yaml_.load(f)
else:
config = yaml.load(f, Loader=yaml.RoundTripLoader)
return config
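    # The round-trip (`rt`) loader preserves comments and key order in the YAML
    # file, so a config written back by `generate_config` keeps the user's
    # annotations. The `conditional` decorator below also makes
    # `generate_config` a no-op on all but rank 0 when torch.distributed is
    # initialized, avoiding concurrent writes to the same file.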
@conditional(lambda: not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0)
def generate_config(self, path: str, config: dict = None) -> None:
"""Generates a new copy the updated configuration with the given path"""
if config is None:
config = self.config
with open(path, 'w') as f:
if NEW_YAML_FLAG:
yaml_ = yaml.YAML(typ='rt')
yaml_.default_flow_style = False
yaml_.dump(config, f)
else:
yaml.dump(config, f, default_flow_style=False, Dumper=yaml.RoundTripDumper)
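    # A minimal usage sketch (the key `sparsity` and the file name are
    # illustrative only):
    #
    #     pruner.config['sparsity'] = 0.75
    #     pruner.generate_config('pruned_config.yml')  # written on rank 0 only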
    def trace(self) -> TraceGraph:
        with torch.no_grad():
            # Unwrap (Distributed)DataParallel so the underlying module is traced.
            if isinstance(self.model, (DataParallel, DistributedDataParallel)):
                model = self.model.module
            else:
                model = self.model
            # Trace on CPU so the parameters and the dummy input live on the
            # same device, then restore the model to its original device.
            old_device = get_module_device(model)
            model.cpu()
            graph = trace(model, self.dummy_input)
            if old_device is not None:
                model.to(device=old_device)
            return graph
def reset(self):
"""Regenerate the TraceGraph when it is invalidated"""
self.graph = self.trace()
    def calc_flops(self) -> int:
        """Calculates the FLOPs of the model, taking channels already scheduled for removal into account"""
# If graph is invalidated, then we need to regenerate the graph
if not self.graph.inited:
self.reset()
graph: TraceGraph = self.graph
total_ops = 0
for node in graph.forward_nodes:
m = node.module
            # Maskers record the input/output channel indices scheduled for
            # removal ("ot" = output), so the counts below reflect the
            # post-pruning shapes.
            remove_in_channel_count = 0
            remove_out_channel_count = 0
            if hasattr(m, 'masker'):
                if m.masker.in_remove_idx is not None:
                    remove_in_channel_count = len(m.masker.in_remove_idx)
                if m.masker.ot_remove_idx is not None:
                    remove_out_channel_count = len(m.masker.ot_remove_idx)
            if type(m) in (nn.Conv2d, nn.ConvTranspose2d):
                # Per output element: (in_channels / groups) * kH * kW
                # multiply-accumulates, plus one op if there is a bias.
                kernel_ops = m.weight.shape[2:].numel()  # kH * kW
                bias_ops = 1 if m.bias is not None else 0
                in_channels = m.in_channels - remove_in_channel_count
                out_channels = m.out_channels - remove_out_channel_count
                out_elements = node.next_tensors[0].nelement() // m.out_channels * out_channels
                # A depthwise conv stays depthwise after pruning, so its group
                # count follows the surviving channel count.
                if is_dw_conv(m):
                    groups = in_channels
                else:
                    groups = m.groups
                total_ops += out_elements * (in_channels // groups * kernel_ops + bias_ops)
            elif type(m) in (nn.BatchNorm2d,):
                # Two ops per surviving element: one multiply (scale) and one
                # add (shift), with the statistics folded in at inference time.
                channels = m.num_features - remove_in_channel_count
                nelements = node.prev_tensors[0].numel() // m.num_features * channels
                total_ops += 2 * nelements
            elif type(m) in (nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.ReLU):
                # Element-wise ops: count one operation per surviving element.
                # (The pooling and ReLU branches were identical, so they share
                # one branch here.)
                channels = node.prev_tensors[0].size(1) - remove_in_channel_count
                kernel_ops = 1
                num_elements = node.prev_tensors[0].numel() // node.prev_tensors[0].size(1) * channels
                total_ops += kernel_ops * num_elements
            elif type(m) in (nn.Linear,):
                # Per output element: one multiply-accumulate per surviving
                # input feature.
                in_channels = m.in_features - remove_in_channel_count
                out_channels = m.out_features - remove_out_channel_count
                total_mul = in_channels
                num_elements = node.next_tensors[0].numel() // m.out_features * out_channels
                total_ops += total_mul * num_elements
return total_ops
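# A minimal end-to-end sketch. `OneShotChannelPruner` is the concrete pruner
# shipped with tinynn; the import path and config keys below are illustrative
# and may differ between versions:
#
#     from tinynn.prune.oneshot_pruner import OneShotChannelPruner
#
#     model = MyModel()
#     dummy_input = torch.randn(1, 3, 224, 224)
#     pruner = OneShotChannelPruner(model, dummy_input, {'sparsity': 0.5, 'metrics': 'l2_norm'})
#     flops_before = pruner.calc_flops()
#     pruner.prune()  # computes and applies the masks
#     flops_after = pruner.calc_flops()
#     print(f'FLOPs: {flops_before} -> {flops_after}')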