# tinynn/graph/quantization/algorithm/cross_layer_equalization.py
"""Cross Layer Equalization
Cross-Layer-Equalization can scale weights equivalently to reduce weight outliers in per_tensor mode.
You can use CLE to adjust the model weight, then do next ptq/qat(rep-model qat need to restore BN).
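
A minimal usage sketch (``my_model`` and ``dummy_input`` are placeholders for your own
model and a viable example input):

    model = cross_layer_equalize(my_model, dummy_input, torch.device('cpu'))
    # ... then run PTQ/QAT on the equalized model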
"""
import copy
import os
import functools
from typing import Tuple
import torch
import torch.nn as nn
import torch.quantization as torch_q
from tinynn.graph.tracer import model_tracer, trace, TraceNode
from tinynn.graph.quantization.quantizer import PostQuantizer, load_processed_ptq_rules
from tinynn.util.bn_restore import ConvBnTrain, model_restore_bn
from tinynn.util.util import get_logger, import_from_path
log = get_logger(__name__)

# Layer types whose weights CLE can rescale
cls_support_type = (torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.Linear)
# Positive-homogeneous activations through which a per-channel scale propagates unchanged
cls_scalable_type = (torch.nn.ReLU, torch.nn.LeakyReLU, torch.nn.PReLU, torch.nn.Identity)


def is_group_supported(current_group):
    """Check whether a layer group is supported by cross-layer scaling (CLS).

    Currently supported layer combinations:
    1. [conv, conv]
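
    A doctest-style check (hypothetical layer names):

        >>> convs = [('conv1', nn.Conv2d(3, 8, 3)), ('conv2', nn.Conv2d(8, 8, 3))]
        >>> is_group_supported(convs)
        True
        >>> is_group_supported(convs[:1])
        False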
"""
    current_group_ = [mod for n, mod in current_group]
    # TODO: support more general CLE patterns than plain conv-conv pairs.
    return (
        len(current_group_) == 2
        and isinstance(current_group_[0], cls_support_type)
        and isinstance(current_group_[1], cls_support_type)
    )


def graph_traverse(node: TraceNode, layer_groups, current_group=None, visited_nodes=None):
    """Recursively traverse the computational graph and collect all conv groups whose
    weights can be equalized.

    Scaling propagates through single-consumer, positive-homogeneous activations
    (`cls_scalable_type`), so a chain conv1 -> relu -> conv2 -> relu -> conv3 yields the
    chained pairs [conv1, conv2] and [conv2, conv3]; a branch point or a non-scalable
    layer (e.g. max-pool) closes the current group.
    """
if visited_nodes is None:
visited_nodes = []
if node in visited_nodes:
return
if current_group is None:
current_group = []
    # flush a completed pair (conv-conv / conv-depthwise-conv) into layer_groups,
    # keeping the last conv so that consecutive pairs can chain
if is_group_supported(current_group) and current_group not in layer_groups:
layer_groups.append(current_group)
current_group = [current_group[-1]]
visited_nodes.append(node)
if isinstance(node.module, cls_support_type):
current_group.append((node.unique_name, node.module))
if len(node.next_nodes) > 1 or not isinstance(node.module, (cls_scalable_type, cls_support_type)):
if is_group_supported(current_group) and current_group not in layer_groups:
layer_groups.append(current_group)
current_group = []
    for n in node.next_nodes:
        graph_traverse(n, layer_groups, current_group, visited_nodes)
        # start from a fresh group for each sibling branch
        current_group = []


def get_cls_set(cur_graph):
    """Collect all equalizable layer groups (pairs of convs) from a traced graph."""
    layer_groups = []
visited_nodes = []
for node in cur_graph.forward_nodes:
graph_traverse(node, layer_groups, visited_nodes=visited_nodes)
return layer_groups


def equalize(weight_1, weight_2, group=1, threshold=0.5, s_min=1e-6, s_max=1e6):
    """Compute per-channel equalization scales for a pair of layers from their weights.

    Following Nagel et al. (2019), with r1 the per-output-channel range of weight_1 and
    r2 the per-input-channel range of weight_2, the scale is
    s = r1 / sqrt(r1 * r2) = sqrt(r1 / r2), so that weight_1 / s and weight_2 * s have
    matched per-channel dynamic ranges.
shape_2 = weight_2.shape
# for group conv
weight_1_re = torch.reshape(weight_1, (weight_1.shape[0], -1))
weight_2_re = torch.reshape(
weight_2,
(
group,
shape_2[0] // group,
)
+ shape_2[1:],
)
num_dims = weight_2_re.dim()
assert num_dims >= 3, f"weight_2_re shape dim={num_dims}, <3"
new_order = [2, 0, 1] + list(range(3, num_dims))
weight_2_re = weight_2_re.permute(new_order)
weight_2_re = torch.reshape(weight_2_re, (weight_2_re.shape[0] * weight_2_re.shape[1], -1))
r1 = weight_1_re.abs().max(1).values.double()
r2 = weight_2_re.abs().max(1).values.double()
s = r1 / torch.sqrt(r1 * r2)
    # clamp the scale into a numerically safe range
    s = torch.clamp(s, s_min, s_max)
    # leave channels untouched when the combined range is already small
    s = torch.where((r1 + r2) < threshold, torch.ones_like(s), s)
return s


def _weight_equal_helper(cls, threshold=0.5):
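    """Equalize one layer pair in place: scale down W1/b1 per output channel and scale
    up W2 per input channel by the same factors.

    A doctest-style equivalence check (hypothetical 1x1 convs); CLE must not change the
    composed function through a ReLU:

        >>> conv_a, conv_b = nn.Conv2d(3, 8, 1), nn.Conv2d(8, 4, 1)
        >>> x = torch.randn(1, 3, 4, 4)
        >>> y0 = conv_b(torch.relu(conv_a(x)))
        >>> with torch.no_grad():
        ...     _weight_equal_helper([('a', conv_a), ('b', conv_b)])
        >>> y1 = conv_b(torch.relu(conv_a(x)))
        >>> bool(torch.allclose(y0, y1, atol=1e-5))
        True
    """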
    layer_pair = [m for n, m in cls]
    if len(layer_pair) == 2:
        conv_0, conv_1 = layer_pair
        weight1, bias1, weight2, groups = (
            conv_0.weight,
            conv_0.bias,
            conv_1.weight,
            conv_1.groups if hasattr(conv_1, 'groups') else 1,
        )
        # cast the scale to the weight dtype so the in-place update below is valid
        s = equalize(weight1, weight2, group=groups, threshold=threshold).to(weight1.dtype)
        # scale down the first layer's output channels
        weight_1 = weight1 / s.reshape([-1] + ([1] * (weight1.ndim - 1)))
        # scale up the matching input channels of the second layer (group-aware)
        weight_2 = torch.reshape(weight2, (groups, weight2.shape[0] // groups) + weight2.shape[1:])
        weight_2 *= torch.reshape(s, [groups, 1, -1] + [1] * (weight_2.ndim - 3))
        weight_2 = torch.reshape(weight_2, (weight_2.shape[1] * groups,) + weight_2.shape[2:])
        conv_0.weight.data.copy_(weight_1)
        if conv_0.bias is not None:
            conv_0.bias.data.copy_(bias1 / s)
        conv_1.weight.data.copy_(weight_2)
    else:
        log.warning(f'Unsupported layer pair (size != 2), skipping: {cls}.')


def equalize_model(model: nn.Module, dummy_input, threshold=0.5, iters=2) -> Tuple[list, nn.Module]:
    """Perform Cross-Layer Equalization (CLE) on the given model for `iters` iterations.

    Args:
        model: The model to equalize; its BN layers should already be fused into convs.
        dummy_input (torch.Tensor): A viable input for the model.
        threshold: Defaults to 0.5. Channel pairs whose combined weight range is below
            this value are left unscaled, which avoids blowing up the ranges of
            intermediate conv outputs.
        iters: Number of equalization passes over all layer groups.

    Returns:
        typing.Tuple[List, nn.Module]: The layer groups and the model after CLE.
    """
with torch.no_grad():
with model_tracer():
cur_graph = trace(model, dummy_input)
param = {}
for k, v in model.state_dict().items():
p, _ = cur_graph.get_submodule_with_parent_from_name(k)
if k.endswith('.weight'):
param[k] = p.abs().max()
elif k.endswith('.bias'):
param[k] = p.max()
layer_groups = get_cls_set(cur_graph)
for i in range(iters):
for cls in layer_groups:
_weight_equal_helper(cls, threshold)
stat_we = model.state_dict()
for k, v in stat_we.items():
p, mod = cur_graph.get_submodule_with_parent_from_name(k)
if isinstance(mod, cls_support_type):
if k.endswith('.weight'):
after_max = p.abs().max()
elif k.endswith('.bias'):
after_max = p.max()
if after_max.data.item() != param[k].data.item():
# Print the weight and bias change when applying CLE
log.info(f'{k}: {param[k].data.item():.5f} -> {after_max.data.item():.5f}')
return layer_groups, model
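

# Note: `equalize_model` expects BN to be fused already; a standalone usage sketch
# (names are placeholders):
#     fused = model_fuse_bn(model_rewrite(model, dummy_input), dummy_input)
#     layer_groups, fused = equalize_model(fused, dummy_input, threshold=0.5, iters=2)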


def cross_layer_equalize(
    model: nn.Module, dummy_input, device, threshold=0.5, work_dir="out", cle_iters=2, hba_flag=False
) -> nn.Module:
    """Higher-level API: perform Cross-Layer Equalization (CLE) and optionally
    High Bias Absorption (HBA) on the given model.

    Args:
        model: The model to process; its BN layers are fused into convs internally.
        dummy_input (torch.Tensor): A viable input for the model.
        device (torch.device): Specifies the device of the model.
        threshold: Defaults to 0.5. Channel pairs whose combined weight range is below
            this value are left unscaled, which avoids blowing up the ranges of
            intermediate conv outputs.
        work_dir (typing.Optional[str], optional): The working directory in which the
            intermediate files will be generated. Defaults to "out".
        cle_iters: The number of CLE iterations.
        hba_flag: Whether to apply HBA after CLE. Defaults to False.

    Returns:
        The model after CLE (and HBA if enabled).
    """
    model = model_rewrite(model, dummy_input, work_dir=work_dir)
    model = model_fuse_bn(model, dummy_input)
    log.info("Starting Cross-Layer Equalization; weight/bias range changes after CLE:")
    layers_groups, model = equalize_model(model, dummy_input, threshold, iters=cle_iters)
    if hba_flag:
        log.info("Starting High Bias Absorption; bias range changes after HBA:")
        model = high_bias_absorb(model, device, layers_groups)
    clear_model_fused_bn(model)
    return model
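

# High Bias Absorption (Nagel et al., 2019): CLE can leave large biases in the first
# layer of a pair. For y = W2 @ relu(W1 @ x + b1) + b2, a per-channel amount c with
# 0 <= c <= h (almost surely, h being the pre-activation) can be moved downstream
# without changing y, because relu(h) = relu(h - c) + c whenever h >= c >= 0:
#     b1 -> b1 - c,    b2 -> b2 + W2_reduced @ c
# where W2_reduced sums W2 over its spatial axes. c is estimated from the folded BN
# statistics as max(0, beta - 3 * |gamma|), i.e. the 3-sigma lower bound of h.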
def bias_absorb_helper_(layer1, layer2, model, origin_model):
    """Absorb the positive high-bias portion of `layer1` into `layer2`'s bias in place."""
    if not hasattr(layer1[1], 'bias') or not hasattr(layer2[1], 'bias'):
        return
pre_layer = getattr(model, layer1[0])
cur_layer = getattr(model, layer2[0])
if isinstance(pre_layer, ConvBnTrain) and isinstance(cur_layer, ConvBnTrain):
# when use bn_restore to do HBA after CLE
pre_bn = pre_layer.bn
cur_conv = cur_layer.conv
    elif isinstance(pre_layer, nn.Conv2d) and isinstance(cur_layer, nn.Conv2d):
if hasattr(pre_layer, 'fused_bn_') and hasattr(cur_layer, 'fused_bn_'):
pre_bn = pre_layer.fused_bn_
cur_conv = cur_layer
else:
log.info("High Bias Absorbing is not supported for conv without BatchNorm.")
return
    # Following AIMET, use the folded BN weight (gamma) and bias (beta) to estimate the
    # 3-sigma lower bound of the pre-activation; only its positive part can be absorbed.
    c = pre_bn.bias - 3 * torch.abs(pre_bn.weight)
    zero = torch.zeros_like(c)
    c = torch.where(c < 0, zero, c).to(torch.float)
cur_weight = cur_conv.weight.data
    # sum the weight over its spatial axes (dims 2 and 3)
reduced_weight = cur_weight.sum(dim=[2, 3])
if reduced_weight.shape[1] == 1:
# for dw conv
reduced_weight = reduced_weight.reshape(-1)
bias_correct = reduced_weight * c
else:
bias_correct = torch.matmul(reduced_weight, c)
cur_bias = cur_conv.bias + bias_correct
    origin_pre_conv = getattr(origin_model, layer1[0])
    origin_cur_conv = getattr(origin_model, layer2[0])
    max_before = origin_pre_conv.bias.data.max()
    origin_pre_conv.bias.data = origin_pre_conv.bias.data - c
    origin_cur_conv.bias.data = cur_bias
    if max_before != origin_pre_conv.bias.data.max():
        # log the bias range change caused by HBA
        log.info(f'{layer1[0]} bias: {max_before.item():.5f} -> {origin_pre_conv.bias.data.max().item():.5f}')


def bias_absorb_(model, layers_groups, origin_model):
    """Apply high-bias absorption to every supported two-layer group."""
with torch.no_grad():
for layer_group in layers_groups:
if len(layer_group) == 2:
bias_absorb_helper_(layer_group[0], layer_group[1], model, origin_model)
else:
log.warning('Unsupported layer group')


def high_bias_absorb_empirical(
    cle_model, device, layer_groups, cali_func=None, *cali_func_args, use_origin_bn=True, layers_fused_bn=None
):
    """Absorb bias values greater than 3 * sigma into the next layer's bias, using real
    calibration data to estimate the pre-bias distribution when the original BN is
    unavailable."""
cle_model.to(device)
origin_model = copy.deepcopy(cle_model)
if use_origin_bn:
bias_absorb_(cle_model, layer_groups, origin_model)
else:
if cali_func is None:
            log.warning(
                "High Bias Absorption cannot run; set the arguments as follows:\n"
                "1. If your original model has BN, set `bn_fuse=True` in `cross_layer_equalize`.\n"
                "2. If your original model has no BN (e.g. a deployed RepVGG), set proper "
                "`cali_func`, `cali_func_args` and `layers_fused_bn`."
            )
else:
if layers_fused_bn is None:
layers_fused_bn = [name for name, mod in cle_model.named_modules() if isinstance(mod, torch.nn.Conv2d)]
cle_bn_model = model_restore_bn(
cle_model, device, cali_func, *cali_func_args, layers_fused_bn=layers_fused_bn
)
bias_absorb_(cle_bn_model, layer_groups, origin_model)
clear_model_fused_bn(origin_model)
return origin_model


def high_bias_absorb(cle_model, device, layer_groups):
    """Absorb bias values greater than 3 * sigma into the next layer's bias, using the
    original BN statistics to estimate the pre-bias distribution.

    Args:
        cle_model: The model on which CLE has been performed.
        device: The appropriate device, e.g. torch.device("cuda").
        layer_groups: The layer groups returned by CLE.

    Returns:
        The model after HBA.
    """
cle_model.to(device)
origin_model = copy.deepcopy(cle_model)
bias_absorb_(cle_model, layer_groups, origin_model)
return origin_model
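

# Sketch: HBA is normally invoked via `cross_layer_equalize(..., hba_flag=True)`; applied
# manually, it needs the `fused_bn_` attributes attached by `model_fuse_bn`:
#     model = model_fuse_bn(model_rewrite(model, dummy_input), dummy_input)
#     layer_groups, model = equalize_model(model, dummy_input)
#     model = high_bias_absorb(model, torch.device('cpu'), layer_groups)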


def model_fuse_bn(model: nn.Module, dummy_input):
    """Fuse BN into conv in place, and attach the original BN module to the fused conv
    as the attribute `fused_bn_` (used later by HBA)."""
with model_tracer():
with torch.no_grad():
model.eval()
quantizer = PostQuantizer(
model,
dummy_input,
work_dir='out',
config={'rewrite_graph': False, 'force_overwrite': False, 'fuse_only': True},
)
graph = trace(quantizer.model, quantizer.dummy_input)
graph.quantized = True
for node in graph.forward_nodes:
node.quantized = True
custom_data = ([], set())
processed_rules = load_processed_ptq_rules()
processed_rules = {nn.BatchNorm2d: processed_rules[nn.BatchNorm2d]}
is_fusable = functools.partial(quantizer.is_fusable, current_rules=processed_rules, graph=graph)
graph.filter_forward_nodes(is_fusable, custom_data, reverse=True)
quant_list = custom_data[0]
for quant_nodes in quant_list:
if isinstance(getattr(graph.module, quant_nodes[1]), nn.BatchNorm2d):
bn_cur = getattr(graph.module, quant_nodes[1])
torch_q.fuse_modules(graph.module, quant_nodes, inplace=True)
                    if hasattr(getattr(graph.module, quant_nodes[0]), 'fused_bn_'):
                        log.warning('conv already has attr `fused_bn_`; HBA cannot be applied to this conv')
                    else:
                        setattr(getattr(graph.module, quant_nodes[0]), 'fused_bn_', bn_cur)
fused_model = graph.module
return fused_model


def model_rewrite(model, dummy_input, work_dir='out') -> nn.Module:
    """Rewrite the model into a flat, non-block style in which every layer is a uniquely
    named top-level submodule, so CLE/HBA can address layers by name."""
with model_tracer():
graph = trace(model, dummy_input)
        model_name = type(model).__name__
        model_name_rewrite = f'{model_name}_cle_Rewrite'
        model_name_rewrite_lower = model_name_rewrite.lower()
        model_ns = f'out.{model_name_rewrite_lower}'
        model_code_path = os.path.join(work_dir, f'{model_name_rewrite_lower}.py')
        model_weights_path = os.path.join(work_dir, f'{model_name_rewrite_lower}.pth')
        graph.eliminate_dead_graph_pass()
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        graph.generate_code(model_code_path, model_weights_path, model_name_rewrite)
        # Import the rewritten model class and load the generated weights
        rewritten_model = import_from_path(model_ns, model_code_path, model_name_rewrite)()
        rewritten_model.load_state_dict(torch.load(model_weights_path))
        os.unlink(model_weights_path)
    return rewritten_model


def clear_model_fused_bn(model: nn.Module):
    """Remove the attached BN modules (`fused_bn_`) from fused convs."""
for mod in model.modules():
if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)) and hasattr(mod, 'fused_bn_'):
delattr(mod, 'fused_bn_')