tinynn/graph/modifier.py (2,585 lines of code) (raw):
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
import typing
from tinynn.graph.tracer import TraceNode, TraceGraph
from tinynn.util.util import get_logger
from tinynn.graph import masker
import numpy as np
log = get_logger(__name__)
def complementary_list(a, b):
return list(set(a).difference(set(b)))
def get_smallest_k(lst, k, offset=0):
idx_lst = [(i, float(lst[i])) for i in range(len(lst))]
sorted_lst = sorted(idx_lst, key=lambda x: x[1])
sorted_lst_k = sorted_lst[:k]
idx = [sorted_lst_k[i][0] + offset for i in range(len(sorted_lst_k))]
return sorted(idx)
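# Illustrative usage of the two helpers above (a minimal sketch with toy values, not part of the
# original module): complementary_list returns the elements of `a` absent from `b` (order is
# set-dependent), and get_smallest_k returns the sorted indices of the k smallest entries,
# optionally shifted by `offset`.
#     >>> sorted(complementary_list([0, 1, 2, 3, 4], [1, 3]))
#     [0, 2, 4]
#     >>> get_smallest_k([0.5, 0.1, 0.9, 0.3], 2)
#     [1, 3]
#     >>> get_smallest_k([0.5, 0.1, 0.9, 0.3], 2, offset=4)
#     [5, 7]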
def rnn_gate_size(module: nn.Module) -> int:
"""the gate size of the recurrent modules"""
if isinstance(module, nn.RNN):
return 1
elif isinstance(module, nn.GRU):
return 3
elif isinstance(module, nn.LSTM):
return 4
else:
raise AttributeError(f'gate size of {type(module)} is unknown')
def update_weight_metric(importance, metric_func, module, name):
if type(module) in [nn.Linear, nn.Conv2d, nn.Conv1d, nn.ConvTranspose2d, nn.ConvTranspose1d]:
importance[name] = metric_func(module.weight, module)
elif type(module) in [nn.GRU, nn.LSTM, nn.RNN]:
num_directions = 2 if module.bidirectional else 1
has_proj = hasattr(module, 'proj_size') and module.proj_size > 0
gs = rnn_gate_size(module)
weights = []
if has_proj:
for i in range(module.num_layers):
weight_hrs = []
for j in range(num_directions):
suffix = '_reverse' if j > 0 else ''
weight_hr = getattr(module, f'weight_hr_l{i}{suffix}')
weight_hrs.append(weight_hr)
weights.append(torch.cat(weight_hrs, dim=0))
importance[name] = metric_func(weights, module)
weights.clear()
name = f'{name}:h'
for i in range(module.num_layers):
weight_ihs = []
weight_hhs = []
for j in range(num_directions):
suffix = '_reverse' if j > 0 else ''
weight_ih = getattr(module, f'weight_ih_l{i}{suffix}')
weight_hh = getattr(module, f'weight_hh_l{i}{suffix}')
weight_ihs.append(weight_ih)
weight_hhs.append(weight_hh)
if gs == 1:
weights.append(torch.cat(weight_ihs, dim=0))
weights.append(torch.cat(weight_hhs, dim=0))
else:
w_ih_splits = zip(*[torch.unbind(x.view(gs, module.hidden_size, -1)) for x in weight_ihs])
w_hh_splits = zip(*[torch.unbind(x.view(gs, module.hidden_size, -1)) for x in weight_hhs])
ih_gate_weights = [torch.cat(x) for x in w_ih_splits]
hh_gate_weights = [torch.cat(x) for x in w_hh_splits]
weights.extend(ih_gate_weights)
weights.extend(hh_gate_weights)
importance[name] = metric_func(weights, module)
else:
raise AttributeError(f'{type(module).__name__}({name}) is not supported for importance calculation')
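# Illustrative usage (a minimal sketch, not part of the original module): for a plain Conv2d the
# metric function is applied directly to the weight, yielding one importance value per output
# channel. The names below (imp, conv) are only for demonstration; l1_norm is defined later in
# this file.
#     >>> imp = {}
#     >>> conv = nn.Conv2d(2, 4, kernel_size=3, bias=False)
#     >>> update_weight_metric(imp, l1_norm, conv, 'conv0')
#     >>> imp['conv0'].shape
#     torch.Size([4])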
def random(tensor, module):
if type(module) in [nn.Linear, nn.Conv2d, nn.Conv1d]:
return torch.randperm(tensor.shape[0])
if type(module) in [nn.ConvTranspose2d, nn.ConvTranspose1d]:
return torch.randperm(tensor.shape[1])
if type(module) in [nn.GRU, nn.LSTM, nn.RNN]:
assert isinstance(tensor, (tuple, list))
return torch.randperm(tensor[0].shape[1])
def l1_norm(tensor, module):
"""Calculate the L1-normalization of each channel"""
if type(module) in [nn.Conv2d]:
return torch.norm(tensor, p=1, dim=[1, 2, 3])
if type(module) in [nn.Conv1d]:
return torch.norm(tensor, p=1, dim=[1, 2])
if type(module) in [nn.Linear]:
return torch.norm(tensor, p=1, dim=[1])
if type(module) in [nn.ConvTranspose2d]:
return torch.norm(tensor, p=1, dim=[0, 2, 3])
if type(module) in [nn.ConvTranspose1d]:
return torch.norm(tensor, p=1, dim=[0, 2])
if type(module) in [nn.GRU, nn.LSTM, nn.RNN]:
assert isinstance(tensor, (tuple, list))
return torch.sum(torch.stack([torch.norm(t, p=1, dim=[1]) for t in tensor]), dim=0)
def l2_norm(tensor, module):
"""Calculate the L2-normalization of each channel"""
if type(module) in [nn.Conv2d]:
return torch.norm(tensor, p=2, dim=[1, 2, 3])
if type(module) in [nn.Conv1d]:
return torch.norm(tensor, p=2, dim=[1, 2])
if type(module) in [nn.Linear]:
return torch.norm(tensor, p=2, dim=[1])
if type(module) in [nn.ConvTranspose2d]:
return torch.norm(tensor, p=2, dim=[0, 2, 3])
if type(module) in [nn.ConvTranspose1d]:
return torch.norm(tensor, p=2, dim=[0, 2])
if type(module) in [nn.GRU, nn.LSTM, nn.RNN]:
assert isinstance(tensor, (tuple, list))
return torch.sum(torch.stack([torch.norm(t, p=2, dim=[1]) for t in tensor]), dim=0)
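# Illustrative example for the channel metrics above (a minimal sketch with a toy Conv2d, not part
# of the original module): both reduce every axis except the output-channel axis, so the result has
# one score per output channel.
#     >>> conv = nn.Conv2d(2, 3, kernel_size=1, bias=False)
#     >>> l1_norm(conv.weight, conv).shape
#     torch.Size([3])
#     >>> l2_norm(conv.weight, conv).shape
#     torch.Size([3])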
def fpgm(tensor, module):
"""Calculate the geometric median (Filter Pruning via Geometric Median for Deep Convolutional Neural
Networks Acceleration, https://arxiv.org/abs/1811.00250)"""
assert type(module) in [nn.Linear, nn.Conv2d]
num_channels = tensor.shape[0]
batched_weight = tensor.view(num_channels, -1)
return torch.cdist(batched_weight, batched_weight, p=2).abs().sum(0)
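# Note on interpretation plus a toy example (illustrative only, not part of the original module):
# fpgm returns, for each filter, the sum of its distances to all other filters; filters with the
# smallest score lie closest to the geometric median and are therefore considered the most redundant.
#     >>> w = torch.randn(8, 4, 3, 3)
#     >>> fpgm(w, nn.Conv2d(4, 8, kernel_size=3)).shape
#     torch.Size([8])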
def is_dw_conv(module):
"""Check whether the model is depth-wise convolution"""
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d, nn.ConvTranspose1d)):
if module.in_channels == module.groups == module.out_channels:
return True
return False
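# Illustrative example (toy modules, not part of the original file): a convolution is depth-wise
# only when in_channels == groups == out_channels.
#     >>> is_dw_conv(nn.Conv2d(8, 8, 3, groups=8))
#     True
#     >>> is_dw_conv(nn.Conv2d(8, 8, 3, groups=1))
#     False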
def merge_group(group: list):
new_group = []
while len(group) > 0:
loop = False
a = group[0]
for b in group[1:]:
if len(a & b) > 0:
k = a & b
m = a - k
n = b - k
group.remove(a)
group.remove(b)
group += [i for i in [m, n, k] if len(i) > 0]
loop = True
break
if loop:
continue
new_group.append(a)
group.remove(a)
group[:] = new_group[:]
for i in range(len(group)):
gi = group[i]
for gj in group[i + 1 :]:
if gi == gj and gi is not gj:
new_group.remove(gi)
break
group[:] = new_group[:]
group.sort()
return group
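# Illustrative example (toy input, not part of the original file): overlapping index groups are
# split into disjoint pieces, e.g. merge_group([{0, 1, 2}, {1, 2, 3}]) yields the three disjoint
# groups {0}, {3} and {1, 2} (the intersection and the two remainders).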
def merge_constraint(constraint: typing.List[typing.Set]):
"""Merge all constraints with intersection"""
if {-1.0} in constraint:
constraint.remove({-1.0})
value_mapping = dict()
# remove empty constraint
for idx in reversed(range(len(constraint))):
if len(constraint[idx]) == 0:
del constraint[idx]
continue
# build value_mapping, used to quickly find which sets a value belongs to
for idx in range(len(constraint)):
for value in constraint[idx]:
value_mapping[value] = value_mapping.get(value, [])
value_mapping[value].append(idx)
reindex_list = [i for i in range(len(constraint))]
# for quickly finding all sets that redirect to the same set,
homology_idx = {i: [i] for i in range(len(constraint))}
for value, idx_need_merge in value_mapping.items():
if len(idx_need_merge) <= 1:
continue
target_idx = reindex_list[idx_need_merge[0]]
for i in idx_need_merge:
if reindex_list[i] == target_idx:
continue
src_idx = reindex_list[i]
constraint[target_idx] = constraint[target_idx] | constraint[src_idx]
for j in homology_idx[src_idx]:
reindex_list[j] = target_idx
homology_idx[target_idx].append(j)
homology_idx[src_idx] = []
valid_idx = sorted(list(set(reindex_list)))
new_constraint = [constraint[i] for i in valid_idx]
constraint[:] = new_constraint[:]
return constraint
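# Illustrative example (toy input, not part of the original file): constraints that share any value
# are merged in place, so [{1, 2}, {2, 3}, {4}] collapses to [{1, 2, 3}, {4}].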
def calc_dim_constraint(tensor: torch.Tensor, dim_changes: typing.List):
"""Count all constraints under a dimension"""
constraints = {}
arr = tensor.detach().numpy()
for dim in dim_changes:
constraints[dim] = []
for i in range(tensor.shape[dim]):
arr_i = arr.take(i, axis=dim)
constraint = np.unique(arr_i)
constraint = set(constraint.tolist())
constraints[dim].append(constraint)
return constraints
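# Illustrative example (toy tensor, not part of the original file): for a tensor whose entries
# encode channel indices, each slice along a changed dimension is reduced to its set of unique
# values.
#     >>> t = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
#     >>> calc_dim_constraint(t, [0])
#     {0: [{0.0}, {1.0}]}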
def calc_dim_changes(node, tensors_i, tensors_o=None) -> typing.List[typing.Tuple[typing.List, torch.Tensor]]:
"""Calculate which dimensions of tensor have changed"""
def vector_wrapper(t):
if type(t) not in [tuple, list]:
return (t,)
else:
return t
# operator inference
if isinstance(node, nn.Module):
tensors = vector_wrapper(node(*tensors_i))
else:
with torch.no_grad():
tensors = vector_wrapper(node(vector_wrapper(tensors_i)))
if tensors_o is not None:
for i in range(len(tensors_o)):
tensors_o[i].data.copy_(tensors[i].clone().data)
else:
tensors_o = tensors
dim_changes = []
for tensor_o in tensors_o:
dim_change = []
for i in range(len(tensor_o.shape)):
reduce_dim = [j for j in range(len(tensor_o.shape)) if j != i]
if not reduce_dim:
value = set(tensor_o.detach().tolist())
else:
value = set(torch.sum(tensor_o, dim=reduce_dim).detach().tolist())
if len(value) > 1:
if i not in dim_change:
dim_change.append(i)
dim_changes.append((dim_change, tensor_o))
return dim_changes
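# Illustrative example (toy module and tensor, not part of the original file): the helper probes the
# operator with an index-encoded input and reports, per output tensor, which dimensions actually
# vary. With nn.Identity() and an input whose values differ only along dim 0, the reported change is
# [0]:
#     >>> changes = calc_dim_changes(nn.Identity(), [torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])])
#     >>> changes[0][0]
#     [0]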
class DimensionChangeInfo(object):
def __init__(self, modifier: 'Modifier'):
self.modifier = modifier
# Which dimensions of the input/output tensor will change
self.dim_changes_i = OrderedDict()
self.dim_changes_o = OrderedDict()
# All center nodes (operators that change actively during pruning, such as conv2d and linear)
self.centers = OrderedDict()
self.tensor_changes = OrderedDict()
self.tensor_keys = OrderedDict()
# The dimension finally chosen for pruning (when multiple dimensions may change, one dimension is
# selected through dependency analysis and conflict elimination)
self.dim_choices = OrderedDict()
self.tensor_choices = OrderedDict()
self.dim_transform = None
# The mapping relationship between the current node and the central node
self.constraints_i = OrderedDict()
self.constraints_o = OrderedDict()
# In operators such as grouped convolution and bidirectional LSTM, each tensor needs to be
# modified uniformly according to the group.
self.groups_i = []
self.groups_o = []
# Pruning indices for input and output (for structured pruning this may be a channel;
# for unstructured pruning it may be every point)
self.pruned_idx_i = []
self.pruned_idx_o = []
def build_key(self, center, tensor):
"""Generate human-readable keys to reduce the debugging cost of complex computational graphs"""
pre_tensor_idx = [id(t) for t in self.modifier.pre_tensors()]
nxt_tensor_idx = [id(t) for t in self.modifier.next_tensors()]
if id(tensor) in pre_tensor_idx:
tensor_idx = pre_tensor_idx.index(id(tensor))
return f'{center.unique_name()}:input_{tensor_idx}'
elif id(tensor) in nxt_tensor_idx:
tensor_idx = nxt_tensor_idx.index(id(tensor))
return f'{center.unique_name()}:output_{tensor_idx}'
else:
assert False
def build_choice_key(self, tensor):
"""Generate human-readable keys to reduce the debugging cost of complex computational graphs"""
pre_tensor_idx = [id(t) for t in self.modifier.pre_tensors()]
nxt_tensor_idx = [id(t) for t in self.modifier.next_tensors()]
if id(tensor) in pre_tensor_idx:
tensor_idx = pre_tensor_idx.index(id(tensor))
return f'input_{tensor_idx}'
elif id(tensor) in nxt_tensor_idx:
tensor_idx = nxt_tensor_idx.index(id(tensor))
return f'output_{tensor_idx}'
else:
assert False
def is_multi_dim_changed(self):
dim_change_i = self.merge_i()
dim_change_o = self.merge_o()
if len(dim_change_i) > 1 or len(dim_change_o) > 1:
return True
return False
def is_changes_conflict(self, tensor_changes):
if len(tensor_changes) == 0 and len(self.tensor_changes) > 0:
return True
for tensor_id, dim_choose in tensor_changes.items():
dim_changes_flat = list()
for dim_changes in self.tensor_changes[tensor_id]:
dim_changes_flat += dim_changes
if not set(dim_choose).issubset(set(dim_changes_flat)):
return True
return False
def merge_t(self, tensor):
"""Merge the dimension change information of all tensors"""
dim_change = set()
for change in self.tensor_changes[id(tensor)]:
dim_change.update(change)
return sorted(list(dim_change))
def merge_i(self) -> typing.List:
"""Merge the dimension change information of input tensors"""
dim_change = set()
for t in self.modifier.pre_tensors():
if id(t) not in self.tensor_changes.keys():
continue
for change in self.tensor_changes[id(t)]:
dim_change.update(change)
return sorted(list(dim_change))
def merge_o(self) -> typing.List:
"""Merge the dimension change information of output tensors"""
dim_change = set()
for t in self.modifier.next_tensors():
# Tensors used internally, such as the hidden state of an RNN, are not included in the dependency analysis
if id(t) in self.tensor_changes:
for change in self.tensor_changes[id(t)]:
dim_change.update(change)
return sorted(list(dim_change))
def update_i(
self,
center: 'Modifier',
tensor: torch.Tensor,
dim_changes: typing.List,
dim_transform=None,
update_constraint=True,
tensor_constraint=None,
):
"""Update the dimension change information of the input tensor"""
if dim_transform is not None:
self.dim_transform = dim_transform
constraint_i = None
if update_constraint:
if tensor_constraint is not None:
constraint_i = tensor_constraint
else:
constraint_i = calc_dim_constraint(tensor, dim_changes)
# Redirect pruning constraints to central node
if dim_transform:
for dim, constraint in constraint_i.items():
for i in range(len(constraint)):
new_constraint = set()
for c in constraint[i]:
if c in dim_transform.keys():
transformed_idx = dim_transform[c]
new_constraint.update(transformed_idx)
if len(new_constraint) > 0:
constraint[i] = new_constraint
for dim, constraint in constraint_i.items():
if dim not in self.constraints_i:
self.constraints_i[dim] = {}
self.constraints_i[dim][center.unique_name()] = self.constraints_i[dim].get(center.unique_name(), [])
self.constraints_i[dim][center.unique_name()].append(constraint)
self.update_(self.dim_changes_i, center, tensor, dim_changes)
return constraint_i
def update_o(
self,
center: 'Modifier',
tensor: torch.Tensor,
dim_changes: typing.List,
update_constraint=False,
default_constraint=None,
):
"""Update the dimension change information of the output tensor"""
constraint_o = None
if update_constraint:
if default_constraint is not None:
constraint_o = default_constraint
else:
constraint_o = calc_dim_constraint(tensor, dim_changes)
for dim, constraint in constraint_o.items():
if dim not in self.constraints_o:
self.constraints_o[dim] = {}
self.constraints_o[dim][center.unique_name()] = self.constraints_o[dim].get(center.unique_name(), [])
self.constraints_o[dim][center.unique_name()].append(constraint)
self.update_(self.dim_changes_o, center, tensor, dim_changes)
return constraint_o
def update_(self, dim_changes_dict: typing.Dict, center: 'Modifier', tensor, dim_changes: typing.List):
key = self.build_key(center, tensor)
if key not in dim_changes_dict.keys():
dim_changes_dict[key] = dim_changes
if id(tensor) not in self.tensor_changes:
self.tensor_changes[id(tensor)] = []
self.tensor_keys[id(tensor)] = []
self.tensor_changes[id(tensor)].append(dim_changes)
self.tensor_keys[id(tensor)].append(key)
self.tensor_keys[id(tensor)] = list(set(self.tensor_keys[id(tensor)]))
self.centers[center.unique_name()] = center
return self
def update_choice(self, tensor, choice):
key = self.build_choice_key(tensor)
self.dim_choices[key] = choice
self.tensor_choices[id(tensor)] = choice
def get_neighbor_changes(self, center, neighbor):
if isinstance(center, TraceNode) and isinstance(neighbor, TraceNode):
center = center.modifier
neighbor = neighbor.modifier
changes = []
if neighbor in self.modifier.pre_modifiers():
for t in self.modifier.pre_tensors(neighbor):
changes.append(self.dim_changes_i.get(self.build_key(center, t), None))
else:
for t in self.modifier.next_tensors(neighbor):
changes.append(self.dim_changes_o.get(self.build_key(center, t), None))
if changes == [None]:
return None
return changes
def get_neighbor_choices(self, neighbor):
if isinstance(neighbor, TraceNode):
neighbor = neighbor.modifier
choices = []
if neighbor in self.modifier.pre_modifiers():
for t in self.modifier.pre_tensors(neighbor):
choices.append(self.get_tensor_choices(t))
else:
for t in self.modifier.next_tensors(neighbor):
choices.append(self.get_tensor_choices(t))
if choices == [None]:
return None
return choices
def get_tensor_choices(self, tensor) -> typing.List:
return self.dim_choices.get(self.build_choice_key(tensor), None)
def get_tensor_changes(self, tensor) -> typing.List:
return self.tensor_changes[id(tensor)]
def get_input_centers(self):
center_names = []
for key, value in self.dim_changes_i.items():
center_names.append(key.split(":")[0])
return center_names
def rebuild(self):
"""Reconstruct the entire dimension change information according to dim_choice"""
valid_changes = []
for tensor_id, choice in self.tensor_choices.items():
for key in self.tensor_keys[tensor_id]:
if 'input' in key:
tensor_change = self.dim_changes_i[key]
else:
tensor_change = self.dim_changes_o[key]
if set(choice).issubset(set(tensor_change)):
center_name = key.split(":")[0]
center = self.centers[center_name]
all_tensors = self.modifier.pre_tensors() + self.modifier.next_tensors()
tensor = [t for t in all_tensors if id(t) == tensor_id][0]
valid_changes.append((center, tensor, tensor_change))
self.dim_changes_i = OrderedDict()
self.dim_changes_o = OrderedDict()
self.centers = OrderedDict()
self.tensor_changes = OrderedDict()
self.tensor_keys = OrderedDict()
constraint_i_new = OrderedDict()
constraint_o_new = OrderedDict()
for changes in valid_changes:
center, tensor, tensor_change = changes
if self.modifier.is_pre_tensor(tensor):
self.update_i(center, tensor, tensor_change, update_constraint=False)
constraint_old = self.constraints_i
constraint_new = constraint_i_new
else:
self.update_o(center, tensor, tensor_change, update_constraint=False)
constraint_old = self.constraints_o
constraint_new = constraint_o_new
choice = self.get_tensor_choices(tensor)
for dim, constraints in constraint_old.items():
if dim in choice:
if dim not in constraint_new:
constraint_new[dim] = {}
constraint_new[dim] = constraints
for dim, dim_constraints in constraint_i_new.items():
for center_name, constraints in dim_constraints.items():
if len(constraints) == 1:
continue
merge = [set() for i in constraints[0]]
for constraint in constraints:
for i in range(len(constraint)):
if constraint[i] != {-1}:
merge[i].update(constraint[i])
constraints[:] = [merge]
self.constraints_i = constraint_i_new
self.constraints_o = constraint_o_new
def __str__(self):
return (
f"dim_changes_i:{str(self.dim_changes_i)}, dim_changes_o:{str(self.dim_changes_o)},"
f" dim_choices:{str(self.dim_choices)}"
)
class Modifier(object):
graph_modifier: "GraphChannelModifier"
node: TraceNode
dim_changes_info: DimensionChangeInfo
forward_dim_mapping: typing.Dict[int, typing.Dict[int, typing.Dict[int, typing.Set]]]
backward_dim_mapping: typing.Dict[int, typing.Dict[int, typing.Dict[int, typing.Set]]]
prunable: bool
weight_mask: typing.Dict[str, torch.Tensor]
bias_mask: typing.Dict[str, torch.Tensor]
def __init__(self, node: TraceNode):
self.graph_modifier = None
self.node = node
# Tensor change dependencies between this operator and other operators
self.dim_changes_info = DimensionChangeInfo(self)
# Maps a changed dimension of an input/output tensor to the affected dimensions of the output/input tensors
self.forward_dim_mapping = OrderedDict()
self.backward_dim_mapping = OrderedDict()
# Whether the operator allows pruning (for example, conv2d, linear, rnn, etc. are prunable,
# while add, mul, etc. are not)
self.prunable = False
self.weight_mask = OrderedDict()
self.bias_mask = OrderedDict()
self.mask_applied = False
self.tensor_id_to_str = {}
for i in range(len(self.pre_tensors())):
self.tensor_id_to_str[id(self.pre_tensors()[i])] = f"input_{i}"
for i in range(len(self.next_tensors())):
self.tensor_id_to_str[id(self.next_tensors()[i])] = f"output_{i}"
self.constant_node = False
if len(self.pre_tensors()) == 1:
self.constant_node = True
for t in self.next_tensors():
if isinstance(t, torch.Tensor) and len(t.shape) > 0:
self.constant_node = False
break
def __hash__(self):
return hash(self.unique_name())
def __eq__(self, other: 'Modifier'):
return self.unique_name() == other.unique_name()
def args_parsed(self):
return self.node.module.args_parsed
def masker(self) -> masker.ChannelMasker:
return getattr(self.node.module, "masker", None)
def enable_mask(self):
if self.masker() is not None:
self.masker().enable()
def disable_mask(self):
if self.masker() is not None:
self.masker().disable()
def reset_mask(self):
self.weight_mask.clear()
self.bias_mask.clear()
if hasattr(self.module(), "weight"):
self.weight_mask["weight"] = torch.ones_like(self.module().weight)
if hasattr(self.module(), "bias"):
self.bias_mask["bias"] = (
torch.ones_like(self.module().bias) if type(self.module().bias) is torch.nn.Parameter else None
)
def register_mask(self, modifiers, importance, sparsity):
if self.masker() is not None:
self.masker().set_in_remove_idx(self.dim_changes_info.pruned_idx_i)
self.masker().set_ot_remove_idx(self.dim_changes_info.pruned_idx_o)
def apply_mask(self, modifiers):
"""Use mask to modify the channel of the operator"""
if self.masker() is not None and self.masker().in_remove_idx is not None:
self.modify_input(self.masker().in_remove_idx)
if self.masker() is not None and self.masker().ot_remove_idx is not None:
self.modify_output(self.masker().ot_remove_idx)
self.mask_applied = True
def modify_input(self, remove_idx):
"""Modify the input tensor of the operator"""
pass
def modify_output(self, remove_idx):
"""Modify the output tensor of the operator"""
pass
def module(self):
return self.node.module
def unique_name(self):
return self.node.unique_name
def is_pre_tensor(self, tensor):
return id(tensor) in [id(x) for x in self.pre_tensors()]
def is_nxt_tensor(self, tensor):
return id(tensor) in [id(x) for x in self.next_tensors()]
def pre_tensors(self, parent: 'Modifier' = None, non_constant=False):
if parent is None:
if non_constant:
return [t for t in self.node.prev_tensors if len(t.shape) > 0]
return self.node.prev_tensors
tensors = []
for t in parent.next_tensors():
if self.is_pre_tensor(t):
if non_constant and len(t.shape) == 0:
continue
tensors.append(t)
return tensors
def next_tensors(self, child: 'Modifier' = None, non_constant=False):
if child is None:
if non_constant:
return [t for t in self.node.next_tensors if len(t.shape) > 0]
return self.node.next_tensors
tensors = []
for t in child.pre_tensors():
if self.is_nxt_tensor(t):
if non_constant and len(t.shape) == 0:
continue
tensors.append(t)
return tensors
def pre_modifiers(self, edge: torch.Tensor = None) -> typing.List['Modifier']:
if edge is None:
return [n.modifier for n in self.node.prev_nodes]
modifiers = []
for m in self.pre_modifiers():
for t in m.next_tensors():
if t is edge:
modifiers.append(m)
return modifiers
def next_modifiers(self, edge: torch.Tensor = None) -> typing.List['Modifier']:
if edge is None:
return [n.modifier for n in self.node.next_nodes]
modifiers = []
for m in self.next_modifiers():
for t in m.pre_tensors():
if t is edge:
modifiers.append(m)
return modifiers
def get_pruned_idx(self, modifiers):
"""Obtain the input tensor pruning information from the pruning information of other operators"""
pruned_idx = set()
input_modify_dim = self.dim_changes_info.get_tensor_choices(self.pre_tensors()[0])[0]
for center_name, _ in self.dim_changes_info.centers.items():
center = modifiers[center_name]
center_pruned_idx = set(center.dim_changes_info.pruned_idx_o)
constraints_i = self.dim_changes_info.constraints_i[input_modify_dim][center_name]
for constraint_i in constraints_i:
for leaf_idx in range(len(constraint_i)):
center_idx = constraint_i[leaf_idx]
if len(center_idx & center_pruned_idx) > 0:
pruned_idx.add(leaf_idx)
pruned_idx = list(pruned_idx)
pruned_idx.sort()
sparsity = len(pruned_idx) / self.pre_tensors()[0].shape[input_modify_dim]
return pruned_idx, sparsity
def calc_idx_group(self):
return None
def calc_dim_changes(self) -> typing.List[typing.Tuple[typing.List, torch.Tensor]]:
return calc_dim_changes(self.module(), self.pre_tensors(), self.next_tensors())
def change_dimension(self) -> bool:
return False
def dim_change_forward(self, center, tensor, dim_changes_i: typing.List, dim_transform, tensor_constraint):
# Leaf nodes require no additional computation
if len(self.next_tensors()) == 0:
if len(self.pre_tensors()) > 1:
for dim_change_i in dim_changes_i:
dims = [t.shape[dim_change_i] for t in self.pre_tensors()]
if len(set(dims)) > 1:
log.warning(f"Skip the {self.unique_name()} because the input shape is inconsistent")
return True
self.dim_changes_info.update_i(center, tensor, dim_changes_i, dim_transform)
return True
if self.node.kind() == "data":
return True
# Skip constant node
if self.constant_node:
return True
# The default implementation is treated as Identity(): it directly inherits the dim constraint
# of the previous layer, reducing the amount of computation
tensor_constraint = self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
for tensor_o in self.next_tensors():
if id(tensor) == id(tensor_o):
continue
try:
tensor_o.copy_(tensor.clone())
except Exception as e:
log.error(
f"error modifier = {self.unique_name()}, type = {type(self.module())}, kind = {self.node.kind()}"
)
raise e
for tensor_o in self.next_tensors():
# Case [1, c0, c1] + [c0, c1] (center node) -> [1, c0, c1]: shift dims to keep dim_change_o consistent.
if len(tensor_o.shape) > len(tensor.shape) and tensor_o.shape[0] == 1:
old_dim_change_i = dim_changes_i
omitted_dim_len = 1
dim_changes_i = [i + omitted_dim_len for i in dim_changes_i]
for dim_ in old_dim_change_i:
tensor_constraint[dim_ + omitted_dim_len] = tensor_constraint[dim_]
tensor_constraint.pop(dim_)
self.dim_changes_info.update_o(center, tensor_o, dim_changes_i)
for m in self.next_modifiers(tensor_o):
# The identity() operator does not change the constraint, so it can directly pass its own
# constraints to reduce the calculation of the next layer
m.dim_change_forward(center, tensor_o, dim_changes_i, dim_transform, tensor_constraint)
def calc_dim_mapping(self) -> bool:
"""Calculate the dimension change map between input and output tensor"""
pre_tensors = self.pre_tensors(non_constant=True)
# The inputs must all have the same number of dimensions; otherwise a dedicated subclass is required
input_dim_num = len(pre_tensors[0].shape)
for dim_change_i in range(input_dim_num):
for tensor_i in self.pre_tensors(non_constant=True):
fill_tensor_by_dim_changes(tensor_i, [dim_change_i])
for dim_changes in self.calc_dim_changes():
dim_changes_o, tensor_o = dim_changes
for dim_change_o in dim_changes_o:
id_o = id(tensor_o)
for tensor_i in self.pre_tensors(non_constant=True):
id_i = id(tensor_i)
self.forward_dim_mapping[id_i][dim_change_i][id_o].add(dim_change_o)
if len(tensor_o.shape) > 0:
self.backward_dim_mapping[id_o][dim_change_o][id_i].add(dim_change_i)
return True
def init_dim_mapping(self) -> bool:
# Init the dimension change map between input and output tensor
# TODO:use cache to speed up
if len(self.forward_dim_mapping) > 0:
return True
# this is a constant or I/O node
if len(self.pre_tensors()) == 0 or len(self.next_tensors()) == 0:
return False
pre_tensors = self.pre_tensors(non_constant=True)
nxt_tensors = self.next_tensors(non_constant=True)
for tensor_i in pre_tensors:
self.forward_dim_mapping[id(tensor_i)] = OrderedDict()
for i in range(len(tensor_i.shape)):
self.forward_dim_mapping[id(tensor_i)][i] = OrderedDict()
for tensor_o in nxt_tensors:
self.forward_dim_mapping[id(tensor_i)][i][id(tensor_o)] = set()
for tensor_o in nxt_tensors:
self.backward_dim_mapping[id(tensor_o)] = OrderedDict()
for i in range(len(tensor_o.shape)):
self.backward_dim_mapping[id(tensor_o)][i] = OrderedDict()
for tensor_i in pre_tensors:
self.backward_dim_mapping[id(tensor_o)][i][id(tensor_i)] = set()
if not self.calc_dim_mapping():
return False
return True
def print_dim_mapping(self):
mapping_str = []
mappings = OrderedDict({**self.forward_dim_mapping, **self.backward_dim_mapping})
for id_1, v1 in mappings.items():
name_1 = self.tensor_id_to_str[id_1]
for dim_1, v2 in v1.items():
for id_2, dim_2 in v2.items():
name_2 = self.tensor_id_to_str[id_2]
format_str = f"{name_1}:{dim_1}->{name_2}:{dim_2}"
mapping_str.append(format_str)
log.debug("\n".join(mapping_str))
return mapping_str
def dim_choose(self, tensor_changes: typing.Dict[int, typing.List]):
"""When multiple dimensions are variable, select one of the dimensions.
The selection method of each operator and pruning algorithm may be different"""
return True
def dim_choose_traversal(
self,
modifiers: typing.List,
tensor_choices: typing.Dict[int, typing.List],
tensor: torch.Tensor,
):
"""Propagate the selected dimension to all relevant operators"""
if self in modifiers:
return True
if self.constant_node:
return True
dim_choice = tensor_choices[id(tensor)]
changed = False
dim_changes = self.dim_changes_info.get_tensor_changes(tensor)
for dim_change in dim_changes:
if dim_choice != dim_change:
changed = True
break
# If dim_choice is identical to dim_change, this tensor is unchanged and no propagation is needed
if not changed:
return True
# Recompute the dim changes of the inputs/outputs according to the chosen dimension
tensor_choices_cur = self.calc_dim_choices(tensor, dim_choice)
if len(tensor_choices_cur) == 0 and len(dim_choice) > 0:
return True
# The dim changes derived from the chosen dimension conflict (e.g. input_0 only supports pruning dim=0 while input_1 only supports dim=1)
if self.dim_changes_info.is_changes_conflict(tensor_choices_cur):
log.warning(
f'[{self.unique_name()}][{self.tensor_id_to_str[id(tensor)]}][{dim_choice}] dim choose conflict'
)
return False
modifiers.append(self)
tensor_choices.update(tensor_choices_cur)
valid = True
for m in self.pre_modifiers():
if not valid:
break
for t in self.pre_tensors(m):
if id(t) not in tensor_choices:
continue
if not m.dim_choose_traversal(modifiers, tensor_choices, t):
valid = False
break
if valid:
for m in self.next_modifiers():
if not valid:
break
for t in self.next_tensors(m):
if id(t) not in tensor_choices:
continue
if not m.dim_choose_traversal(modifiers, tensor_choices, t):
valid = False
break
if not valid:
modifiers.remove(self)
for tensor_id in tensor_choices_cur.keys():
if tensor_id in tensor_choices.keys():
del tensor_choices[tensor_id]
return False
return True
def calc_dim_choices(self, tensor, dim_choose: typing.List) -> typing.Dict[int, typing.List]:
"""Select the dimension of the current operator according to the dimension selection of other operators"""
# For the identity operator, the dim_choose of all tensors is the same
tensor_choices = {}
if not self.init_dim_mapping():
return tensor_choices
if self.is_pre_tensor(tensor):
if len(dim_choose) > 0:
for t in self.pre_tensors(non_constant=True):
tensor_choices[id(t)] = dim_choose
dim_choices_o = OrderedDict()
for i in dim_choose:
choices = self.forward_dim_mapping[id(tensor)][i]
for tid, choice in choices.items():
if tid not in dim_choices_o:
dim_choices_o[tid] = set()
dim_choices_o[tid].update(choice)
for tid, choice in dim_choices_o.items():
# Even the empty set must be passed forward, otherwise the selection between nodes may be inconsistent
tensor_choices[tid] = list(choice)
elif self.is_nxt_tensor(tensor):
if len(dim_choose) > 0:
for t in self.next_tensors(non_constant=True):
tensor_choices[id(t)] = dim_choose
dim_choices_i = OrderedDict()
for i in dim_choose:
choices = self.backward_dim_mapping[id(tensor)][i]
for tid, choice in choices.items():
if tid not in dim_choices_i:
dim_choices_i[tid] = set()
dim_choices_i[tid].update(choice)
for tid, choice in dim_choices_i.items():
tensor_choices[tid] = list(choice)
else:
assert False
return tensor_choices
class PaddingModifier(Modifier):
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
changes = self.calc_dim_changes()
tensor_o = changes[0][1]
self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
self.dim_changes_info.update_o(center, tensor_o, dim_changes_i)
# Padding changes the index info, so the output tensor must be re-filled
fill_tensor_by_dim_changes(self.next_tensors()[0], dim_changes_i)
constraint_i = self.dim_changes_info.constraints_i[dim_changes_i[0]][center.unique_name()][0]
transform = OrderedDict()
for i in range(len(constraint_i)):
transform[i] = constraint_i[i]
for m in self.next_modifiers(tensor_o):
m.dim_change_forward(center, tensor_o, dim_changes_i, transform, None)
class PoolingModifier(Modifier):
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
assert dim_changes_i == [1], "Pooling2D only supports changing the channel dimension."
if isinstance(self.node.module, nn.Module):
output = self.node.module(self.pre_tensors()[0])
tensor_o = self.next_tensors()[0]
tensor_o.data.copy_(output.data)
else:
changes = self.calc_dim_changes()
tensor_o = changes[0][1]
self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
self.dim_changes_info.update_o(center, tensor_o, dim_changes_i)
# Pooling changes the index info, so the output tensor must be re-filled
fill_tensor_by_dim_changes(self.next_tensors()[0], dim_changes_i)
constraint_i = self.dim_changes_info.constraints_i[dim_changes_i[0]][center.unique_name()][0]
transform = OrderedDict()
for i in range(len(constraint_i)):
transform[i] = constraint_i[i]
for m in self.next_modifiers(tensor_o):
m.dim_change_forward(center, tensor_o, dim_changes_i, transform, None)
class PReLUChannelModifier(Modifier):
def register_mask(self, modifiers, importance, sparsity):
pruned_idx, sparsity = self.get_pruned_idx(modifiers)
self.dim_changes_info.pruned_idx_i = pruned_idx
self.dim_changes_info.pruned_idx_o = pruned_idx
if len(pruned_idx) > 0:
remove_idx = self.dim_changes_info.pruned_idx_i
self.masker().set_in_remove_idx(remove_idx)
self.masker().set_ot_remove_idx(remove_idx)
self.weight_mask["weight"][remove_idx] = 0
self.masker().register_mask("weight", self.weight_mask["weight"])
def modify_input(self, remove_idx):
bn = self.node.module
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[0])], remove_idx)
if bn.weight.shape[0] != len(preserve_idx):
log.info(f'[PRELU] {self.unique_name()}: channel {bn.num_parameters} -> {len(preserve_idx)}')
bn.weight = torch.nn.Parameter(bn.weight[preserve_idx])
bn.num_parameters = len(preserve_idx)
class NormChannelModifier(Modifier):
def __init__(self, node: TraceNode):
super(NormChannelModifier, self).__init__(node)
self.prunable = True
def register_mask(self, modifiers, importance, sparsity):
pruned_idx, sparsity = self.get_pruned_idx(modifiers)
self.dim_changes_info.pruned_idx_i = pruned_idx
self.dim_changes_info.pruned_idx_o = pruned_idx
if len(pruned_idx) > 0:
remove_idx = self.dim_changes_info.pruned_idx_i
self.masker().set_in_remove_idx(remove_idx)
self.masker().set_ot_remove_idx(remove_idx)
self.weight_mask["weight"][remove_idx] = 0
self.bias_mask["bias"] = self.weight_mask["weight"]
self.masker().register_mask("weight", self.weight_mask["weight"])
self.masker().register_mask("bias", self.bias_mask["bias"])
def modify_input(self, remove_idx):
bn = self.node.module
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[0])], remove_idx)
if bn.weight.shape[0] != len(preserve_idx):
while self.graph_modifier.bn_compensation:
if len(self.next_modifiers()) != 1:
break
if len(self.next_modifiers()[0].next_modifiers()) != 1:
break
act = self.next_modifiers()[0].module()
conv = self.next_modifiers()[0].next_modifiers()[0].module()
if isinstance(act, nn.Module):
if type(act) not in [
nn.ReLU,
nn.ReLU6,
nn.LeakyReLU,
nn.Sigmoid,
nn.Tanh,
nn.Hardsigmoid,
nn.Hardtanh,
nn.Hardswish,
nn.LogSigmoid,
]:
log.debug(f"unsupported activation for bn compensation: {type(act)}")
break
if isinstance(act, TraceNode):
if self.next_modifiers()[0].node.kind() not in [
'relu',
'relu6',
'leaky_relu',
'sigmoid',
'tanh',
'hardsigmoid',
'hardtanh',
'hardswish',
'logsigmoid',
]:
log.debug(f"unsupported activation for bn compensation: {self.next_modifiers()[0].node.kind()}")
break
if type(conv) is not torch.nn.Conv2d:
break
if conv.groups == 1:
with torch.no_grad():
bias = torch.tensor(bn.bias)
activation_bias = act(bias)
fuse_weight = torch.sum(conv.weight, dim=[2, 3])
bn_bias = fuse_weight * activation_bias
bn_bias = bn_bias[:, [True if i in remove_idx else False for i in range(bn_bias.shape[1])]]
bn_bias = torch.sum(bn_bias, dim=[1])
if conv.bias is None:
conv.bias = torch.nn.Parameter(bn_bias)
else:
conv.bias = torch.nn.Parameter(conv.bias + bn_bias)
break
if isinstance(bn, nn.BatchNorm1d) or isinstance(bn, nn.BatchNorm2d):
log.info(f'[BN] {self.unique_name()}: channel {bn.num_features} -> {len(preserve_idx)}')
bn.register_buffer('running_mean', bn.running_mean[preserve_idx])
bn.register_buffer('running_var', bn.running_var[preserve_idx])
bn.num_batches_tracked = bn.num_batches_tracked.zero_()
bn.num_features = len(preserve_idx)
elif isinstance(bn, nn.LayerNorm):
if len(bn.normalized_shape) == 1:
log.info(f'[LN] {self.unique_name()}: channel {bn.normalized_shape} -> ({len(preserve_idx)},)')
bn.normalized_shape = (len(preserve_idx),)
else:
log.error("The Layer Normalization (LN) Modifier supports only one-dimensional normalized_shape.")
else:
log.error("Unsupported Norm Type")
bn.weight = torch.nn.Parameter(bn.weight[preserve_idx])
bn.bias = torch.nn.Parameter(bn.bias[preserve_idx])
class ReIndexModifier(Modifier):
"""Only change the shape/layout of the tensor without changing the value in it,such as Reshape, Transpose,
Permute, View, Expand, Flatten, Split and other operators"""
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
changes = self.calc_dim_changes()
self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
for change in changes:
dim_change_o, tensor_o = change
if dim_change_o:
self.dim_changes_info.update_o(center, tensor_o, dim_change_o)
for m in self.next_modifiers(tensor_o):
m.dim_change_forward(center, tensor_o, dim_change_o, dim_transform, None)
class SplitModifier(ReIndexModifier):
def __init__(self, node: TraceNode):
super(SplitModifier, self).__init__(node)
# TODO:Get pruning information from center node
self.prunable = True
self.split_dict = OrderedDict()
self.split_dim = self.get_split_dim(self.node)
start = end = 0
for t in self.node.next_tensors:
end += t.shape[self.split_dim]
for n in self.node.next_nodes:
for t_ in n.prev_tensors:
if torch.equal(t, t_):
self.split_dict[n.unique_name] = (start, end)
start = end
self.ot_channel = [t.shape[self.split_dim] for t in self.node.next_tensors]
def get_split_dim(self, node):
args_parsed = node.module.args_parsed
if len(args_parsed) > 2:
split_dim = args_parsed[-1]
split_dim = split_dim[split_dim.find('=') + 1 :]
split_dim = int(split_dim)
return split_dim
return 0
def apply_mask(self, modifiers):
# Input parameters also need to be regenerated after pruning
remove_idx = self.dim_changes_info.pruned_idx_i
args_parsed = self.node.module.args_parsed_origin
if len(args_parsed) > 1:
if type(args_parsed[1]) is list:
ch = [int(i) for i in args_parsed[1]]
ch_new = []
for k, v in self.split_dict.items():
origin_ch = [i for i in range(v[0], v[1])]
for idx in remove_idx:
if idx in origin_ch:
origin_ch.remove(idx)
ch_new.append(len(origin_ch))
for i in range(len(ch)):
# Each subgraph only deletes its corresponding channel
ch[i] = str(min(ch[i], ch_new[i]))
self.args_parsed()[1] = ch
elif args_parsed[1].isdigit():
ch = self.ot_channel[0] - len(remove_idx) // len(self.ot_channel)
self.args_parsed()[1] = str(ch)
elif args_parsed[1] == '{}':
return True
else:
assert False
self.node.module.args_to_string(deepcopy(self.args_parsed()))
class ReshapeModifier(ReIndexModifier):
def __init__(self, node: TraceNode):
super(ReshapeModifier, self).__init__(node)
self.input_tensor = self.pre_tensors()[0]
self.output_tensor = self.next_tensors()[0]
self.original_shape = list(self.pre_tensors()[0].shape)
self.changed_shape = list(self.next_tensors()[0].shape)
self.input_modify_dim = -1
self.output_modify_dim = -1
# Multiple reshapes may be eliminated, so the constraints of reshapes are uncertain and cannot be used as leaves
# self.prunable = True
# Unify the input parameters of reshape and view into the same format
if self.node.kind() in ['reshape', 'view']:
args_parsed = self.args_parsed()
args = args_parsed[1:] if type(args_parsed[1]) is not list else args_parsed[1]
self.node.module.args_parsed = [args_parsed[0], args]
args_parsed_origin = self.node.module.args_parsed_origin
args = args_parsed_origin[1:] if type(args_parsed_origin[1]) is not list else args_parsed_origin[1]
self.node.module.args_parsed_origin = [args_parsed_origin[0], args]
def dim_choose(self, tensor_changes: typing.Dict[int, typing.List]) -> bool:
dim_changes_i = tensor_changes.get(id(self.pre_tensors()[0]), None)
if dim_changes_i is None:
dim_changes_i = self.dim_changes_info.merge_i()
if len(dim_changes_i) > 1:
# Prefer the rearmost axis to reduce the scope of dependency propagation
for dim_choose_i in reversed(sorted(dim_changes_i)):
modifiers = [self]
self.init_dim_mapping()
if self.original_shape[dim_choose_i] <= 2:
continue
tensor_changes[id(self.input_tensor)] = [dim_choose_i]
tensor_changes[id(self.output_tensor)] = list(
self.forward_dim_mapping[id(self.input_tensor)][dim_choose_i][id(self.output_tensor)]
)
valid = True
for m in self.pre_modifiers(self.pre_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.pre_tensors()[0])
if not valid:
break
if valid:
for m in self.next_modifiers(self.next_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.next_tensors()[0])
if not valid:
break
if valid:
return True
else:
if id(self.input_tensor) in tensor_changes:
del tensor_changes[id(self.input_tensor)]
if id(self.output_tensor) in tensor_changes:
del tensor_changes[id(self.output_tensor)]
return False
else:
return True
def apply_mask(self, modifiers):
# Reshape does not need to register a mask, but if its shape arguments are constants they must be
# updated, otherwise inference after pruning will fail
choice = self.dim_changes_info.get_tensor_choices(self.next_tensors()[0])
if choice is None:
return
output_modify_dim = choice[0]
args_origin = self.node.module.args_parsed_origin[1]
if self.node.type() == 'reshape':
if len(args_origin) > output_modify_dim and args_origin[output_modify_dim].isdigit():
if int(args_origin[output_modify_dim]) == -1:
return
pruned_idx, sparsity = self.get_pruned_idx(modifiers)
output_shape = deepcopy(self.changed_shape)
output_shape[output_modify_dim] = int(output_shape[output_modify_dim] * sparsity)
self.args_parsed()[1] = [str(i) for i in output_shape]
self.node.module.args_to_string(deepcopy(self.args_parsed()))
self.mask_applied = True
else:
return
elif self.node.type() == 'view':
if args_origin[output_modify_dim].isdigit():
if int(args_origin[output_modify_dim]) == -1:
return
pruned_idx, sparsity = self.get_pruned_idx(modifiers)
output_shape = deepcopy(self.changed_shape)
output_shape[output_modify_dim] = output_shape[output_modify_dim] - int(
output_shape[output_modify_dim] * sparsity
)
if self.args_parsed()[1][output_modify_dim].isdigit():
self.args_parsed()[1][output_modify_dim] = str(output_shape[output_modify_dim])
self.node.module.args_to_string(deepcopy(self.args_parsed()))
self.mask_applied = True
else:
return
class CatModifier(Modifier):
def __init__(self, node: TraceNode):
super(CatModifier, self).__init__(node)
if self.node.module.kwargs.get('dim', None) is not None:
self.dim = self.node.module.kwargs['dim']
elif self.node.module.kwargs.get('axis', None) is not None:
self.dim = self.node.module.kwargs['axis']
else:
if len(self.args_parsed()) > 1:
self.dim = int(self.args_parsed()[1])
else:
self.dim = 0
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
offset = 0
for t in self.pre_tensors():
if id(t) != id(tensor):
offset += t.shape[dim_changes_i[0]]
else:
break
# In the case of batch cat, there will be dependencies between multiple input tensors
if self.dim not in dim_changes_i:
for t in self.pre_tensors():
if id(t) != id(tensor):
if list(t.shape) == list(tensor.shape):
t.data.copy_(tensor.clone().data)
else:
# Different shapes between tensors
if tensor.shape[self.dim] >= t.shape[self.dim]:
idx_tensor = torch.tensor([idx for idx in range(t.shape[self.dim])])
select_tensor = torch.index_select(tensor, self.dim, idx_tensor)
t.data.copy_(select_tensor.data)
else:
idx_tensor = torch.tensor([idx for idx in range(tensor.shape[self.dim])])
t.data.index_copy_(self.dim, idx_tensor, tensor)
dim_changes_o, tensor_o = self.calc_dim_changes()[0]
self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
self.dim_changes_info.update_o(center, tensor_o, dim_changes_o)
if offset == 0 or dim_transform is None:
new_dim_transform = dim_transform
else:
new_dim_transform = {}
for key, value in dim_transform.items():
new_dim_transform[key + offset] = value
for m in self.next_modifiers(tensor_o):
m.dim_change_forward(center, tensor_o, dim_changes_o, new_dim_transform, None)
class MatMulModifier(Modifier):
def __init__(self, node: TraceNode):
super(MatMulModifier, self).__init__(node)
self.prunable = True
def dim_choose(self, tensor_changes: typing.Dict[int, typing.List]) -> bool:
for i in [0, 1]:
input_tensor = self.pre_tensors()[i]
output_tensor = self.next_tensors()[0]
dim_changes_i = tensor_changes.get(id(input_tensor), None)
if dim_changes_i is None:
dim_changes_i = self.dim_changes_info.merge_t(input_tensor)
if len(dim_changes_i) > 1:
valid = True
for dim_choose_i in sorted(dim_changes_i):
modifiers = [self]
tensor_changes[id(input_tensor)] = [dim_choose_i]
for m in self.pre_modifiers(self.pre_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.pre_tensors()[0])
if not valid:
break
if len(self.dim_changes_info.merge_o()) > 1:
self.init_dim_mapping()
tensor_changes[id(output_tensor)] = list(
self.forward_dim_mapping[id(input_tensor)][dim_choose_i][id(output_tensor)]
)
if valid:
for m in self.next_modifiers(self.next_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.next_tensors()[0])
if not valid:
break
if valid:
break
else:
del tensor_changes[id(input_tensor)]
del tensor_changes[id(output_tensor)]
if not valid:
return False
else:
continue
return True
def calc_dim_mapping(self) -> bool:
(input_0, input_1) = self.pre_tensors()
output = self.next_tensors()[0]
id_i0 = id(input_0)
id_i1 = id(input_1)
id_o = id(output)
input_dim0 = len(input_0.shape)
input_dim1 = len(input_1.shape)
for i in range(input_dim0 - 1):
self.forward_dim_mapping[id_i0][i][id_o].add(i)
self.backward_dim_mapping[id_o][i][id_i0].add(i)
if input_dim1 >= 2:
dim_mapping = [i for i in range(input_dim0)][-input_dim1:]
for i in range(input_dim1):
if i != input_dim1 - 2:
dim = dim_mapping[i]
self.forward_dim_mapping[id_i1][i][id_o].add(dim)
self.backward_dim_mapping[id_o][dim][id_i1].add(i)
return True
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
if id(tensor) == id(self.pre_tensors()[0]):
other_tensor = self.pre_tensors()[1]
other_tensor[:] = 0
# The same index pattern applies regardless of the rank of other_tensor
idx = [
0 if i == len(other_tensor.shape) - 2 else slice(None, None, None)
for i in range(len(other_tensor.shape))
]
torch.Tensor.__setitem__(other_tensor, tuple(idx), 1)
else:
other_tensor = self.pre_tensors()[0]
other_tensor[:] = 0
idx = [
0 if i == len(other_tensor.shape) - 1 else slice(None, None, None)
for i in range(len(other_tensor.shape))
]
torch.Tensor.__setitem__(other_tensor, idx, 1)
changes = self.calc_dim_changes()
for change in changes:
dim_change_o, tensor_o = change
if dim_change_o:
self.dim_changes_info.update_o(center, tensor_o, dim_change_o)
for m in self.next_modifiers(tensor_o):
m.dim_change_forward(center, tensor_o, dim_change_o, dim_transform, None)
class LinearChannelModifier(Modifier):
def __init__(self, node: TraceNode):
super().__init__(node)
self.input_tensor = self.pre_tensors()[0]
self.output_tensor = self.next_tensors()[0]
self.dim_c = len(self.input_tensor.shape) - 1
self.prunable = True
def calc_idx_group(self):
dim_choice_i = self.dim_changes_info.get_tensor_choices(self.input_tensor)
if dim_choice_i is None:
dim = self.dim_c
elif len(dim_choice_i) == 1:
dim = dim_choice_i[0]
else:
assert False
self.dim_changes_info.groups_i = [set([i for i in range(self.input_tensor.shape[dim])])]
self.dim_changes_info.groups_o = [set([i for i in range(self.output_tensor.shape[dim])])]
def calc_dim_mapping(self) -> bool:
for i in range(len(list(self.input_tensor.shape))):
self.forward_dim_mapping[id(self.input_tensor)][i][id(self.output_tensor)] = {i}
self.backward_dim_mapping[id(self.output_tensor)][i][id(self.input_tensor)] = {i}
return True
def register_mask(self, modifiers, importance, sparsity):
if self.dim_changes_info.pruned_idx_i:
remove_idx = self.dim_changes_info.pruned_idx_i
self.weight_mask["weight"][:, remove_idx] = 0
self.masker().set_in_remove_idx(remove_idx)
if self.dim_changes_info.pruned_idx_o:
remove_idx = self.dim_changes_info.pruned_idx_o
self.weight_mask["weight"][remove_idx, :] = 0
self.masker().set_ot_remove_idx(remove_idx)
bias_mask = self.bias_mask.get("bias", None)
if bias_mask is not None:
bias_mask[remove_idx] = 0
self.masker().register_mask("bias", bias_mask)
self.masker().register_mask("weight", self.weight_mask["weight"])
def modify_input(self, remove_idx):
if self.dim_changes_info.get_tensor_choices(self.input_tensor) != [self.dim_c]:
return
linear = self.node.module
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[1])], remove_idx)
if linear.weight.shape[1] != len(preserve_idx):
log.info(f'[FC] {self.unique_name()}: input {linear.in_features} -> {len(preserve_idx)}')
linear.weight = torch.nn.Parameter(linear.weight[:, preserve_idx])
linear.in_features = len(preserve_idx)
def modify_output(self, remove_idx):
if self.dim_changes_info.get_tensor_choices(self.output_tensor) != [self.dim_c]:
return
log.debug(f'[FC] {self.unique_name()}: remove_idx = {remove_idx}')
linear = self.node.module
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[0])], remove_idx)
if linear.weight.shape[0] != len(preserve_idx):
log.info(f'[FC] {self.unique_name()}: output {linear.out_features} -> {len(preserve_idx)}')
linear.weight = torch.nn.Parameter(linear.weight[preserve_idx, :])
linear.out_features = len(preserve_idx)
if linear.bias is not None:
linear.bias = torch.nn.Parameter(linear.bias[preserve_idx])
def dim_choose(self, tensor_changes: typing.Dict[int, typing.List]) -> bool:
id_i = id(self.input_tensor)
id_o = id(self.output_tensor)
dim_changes_i = tensor_changes.get(id_i, None)
dim_choice_o = tensor_changes.get(id_o, None)
if dim_changes_i is None:
dim_changes_i = self.dim_changes_info.merge_i()
if len(dim_changes_i) > 1:
for dim_choose_i in reversed(sorted(dim_changes_i)):
modifiers = [self]
self.init_dim_mapping()
tensor_changes[id_i] = [dim_choose_i]
if dim_choice_o is not None:
dim_choose_i_mapping = list(self.backward_dim_mapping[id_o][dim_choice_o[0]][id_i])
if dim_choose_i_mapping and dim_choose_i_mapping != tensor_changes[id_i]:
continue
tensor_changes[id_o] = [dim_choose_i]
valid = True
for m in self.pre_modifiers(self.pre_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.pre_tensors()[0])
if not valid:
break
if valid:
for m in self.next_modifiers(self.next_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.next_tensors()[0])
if not valid:
break
if valid:
return True
else:
if id_i in tensor_changes.keys():
del tensor_changes[id_i]
if id_o in tensor_changes.keys():
del tensor_changes[id_o]
return False
else:
return True
def change_dimension(self) -> bool:
dim_changes_o = [self.dim_c]
fill_tensor_by_dim_changes(self.output_tensor, dim_changes_o)
tensor_constraint = self.dim_changes_info.update_o(
self, self.next_tensors()[0], dim_changes_o, update_constraint=True
)
for m in self.next_modifiers():
m.dim_change_forward(self, self.next_tensors()[0], dim_changes_o, None, tensor_constraint)
return True
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
# A fully connected layer isolates changes along the dim_c dimension
dim_changes_o = deepcopy(dim_changes_i)
if self.dim_c in dim_changes_o:
dim_changes_o.remove(self.dim_c)
# Dimension changes other than dim_c need to be passed to downstream nodes
if len(dim_changes_o) > 0:
self.dim_changes_info.update_o(center, self.next_tensors()[0], dim_changes_o)
fill_tensor_by_dim_changes(self.output_tensor, dim_changes_o)
constraint_i = self.dim_changes_info.constraints_i[dim_changes_o[0]][center.unique_name()][0]
transform = OrderedDict()
for i in range(len(constraint_i)):
transform[i] = constraint_i[i]
for m in self.next_modifiers():
m.dim_change_forward(center, self.next_tensors()[0], dim_changes_o, transform, None)
class RNNChannelModifier(Modifier):
def __init__(self, node: TraceNode):
super().__init__(node)
self.input_tensor = self.pre_tensors()[0]
self.output_tensor = self.next_tensors()[0]
self.dim_c = len(self.input_tensor.shape) - 1
self.bidirectional = self.module().bidirectional
self.prunable = True
def reset_mask(self):
self.weight_mask.clear()
self.bias_mask.clear()
for n, p in self.module().named_parameters():
if n.startswith('weight'):
self.weight_mask[n] = torch.ones_like(p)
elif n.startswith('bias'):
self.bias_mask[n] = torch.ones_like(p)
def calc_idx_group(self):
dim_choice_i = self.dim_changes_info.get_tensor_choices(self.input_tensor)
if dim_choice_i is None:
dim = self.dim_c
elif len(dim_choice_i) == 1:
dim = dim_choice_i[0]
else:
assert False
self.dim_changes_info.groups_i = [set([i for i in range(self.input_tensor.shape[dim])])]
if self.bidirectional:
output_idx = [i for i in range(self.output_tensor.shape[self.dim_c])]
output_chunk = len(output_idx) // 2
self.dim_changes_info.groups_o = [
set(output_idx[i : i + output_chunk]) for i in range(0, len(output_idx), output_chunk)
]
else:
self.dim_changes_info.groups_o = [set([i for i in range(self.output_tensor.shape[dim])])]
def calc_dim_mapping(self) -> bool:
for i in range(len(list(self.input_tensor.shape))):
self.forward_dim_mapping[id(self.input_tensor)][i][id(self.output_tensor)] = {i}
self.backward_dim_mapping[id(self.output_tensor)][i][id(self.input_tensor)] = {i}
return True
def tile_indices_with_gate_size(self, indices, gate_size, offset):
broadcasted = [indices] * gate_size
return [offset * idx + i for idx, x in enumerate(broadcasted) for i in x]
def split_indices_with_directions(self, indices, offset, num_directions):
split_pos = len(indices) // num_directions
idx_bwd = [i - offset for i in indices[split_pos:]]
idx_fwd = indices[:split_pos]
return idx_fwd, idx_bwd
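# Illustrative example for the two index helpers above (toy values, not part of the original file,
# assuming hidden_size=4): tiling maps per-unit indices to the stacked gate rows of an LSTM
# (gate_size=4), while splitting separates concatenated forward/backward indices of a bidirectional
# RNN back into per-direction local indices.
#     >>> m = RNNChannelModifier.__new__(RNNChannelModifier)  # hypothetical: bypass graph construction
#     >>> m.tile_indices_with_gate_size([0, 2], 4, 4)
#     [0, 2, 4, 6, 8, 10, 12, 14]
#     >>> m.split_indices_with_directions([0, 2, 4, 5], 4, 2)
#     ([0, 2], [0, 1])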
def register_mask(self, modifiers, importance, sparsity):
gs = rnn_gate_size(self.module())
num_directions = 2 if self.module().bidirectional else 1
has_proj = hasattr(self.module(), 'proj_size') and self.module().proj_size > 0
if self.dim_changes_info.pruned_idx_i:
remove_idx = self.dim_changes_info.pruned_idx_i
self.weight_mask['weight_ih_l0'][:, remove_idx] = 0
self.masker().set_in_remove_idx(remove_idx)
if self.dim_changes_info.pruned_idx_o:
if has_proj:
u_name = self.unique_name()
hu_name = f'{u_name}:h'
remove_idx = []
idx_num = len(importance[hu_name])
remove_num = int(sparsity[u_name] * len(importance[hu_name]))
if self.bidirectional:
idx_num //= 2
remove_num //= 2
remove_idx += get_smallest_k(importance[hu_name][:idx_num], remove_num)
remove_idx += get_smallest_k(importance[hu_name][idx_num:], remove_num, offset=idx_num)
else:
remove_idx += get_smallest_k(importance[hu_name], remove_num)
remove_idx_proj = self.dim_changes_info.pruned_idx_o
else:
remove_idx = self.dim_changes_info.pruned_idx_o
remove_idx_proj = None
remove_idx_bwd = None
remove_idx_fwd = None
remove_idx_proj_bwd = None
remove_idx_proj_fwd = None
if num_directions > 1:
offset = self.module().hidden_size
remove_idx_fwd, remove_idx_bwd = self.split_indices_with_directions(remove_idx, offset, num_directions)
if remove_idx_proj is not None:
offset = self.module().proj_size
remove_idx_proj_fwd, remove_idx_proj_bwd = self.split_indices_with_directions(
remove_idx_proj, offset, num_directions
)
assert len(remove_idx_proj_fwd) == len(remove_idx_proj_bwd)
if gs > 1:
offset = self.module().hidden_size
if num_directions > 1:
remove_idx_bwd_gs = self.tile_indices_with_gate_size(remove_idx_bwd, gs, offset)
remove_idx_fwd_gs = self.tile_indices_with_gate_size(remove_idx_fwd, gs, offset)
else:
remove_idx_gs = self.tile_indices_with_gate_size(remove_idx, gs, offset)
for n in self.weight_mask:
remove_idx_r = remove_idx
remove_idx_c = remove_idx
remove_idx_pc = None
if num_directions > 1:
if n.endswith('_reverse'):
if gs > 1:
remove_idx_r = remove_idx_bwd_gs
else:
remove_idx_r = remove_idx_bwd
remove_idx_c = remove_idx_bwd
if has_proj:
remove_idx_pc = remove_idx_proj_bwd
else:
if gs > 1:
remove_idx_r = remove_idx_fwd_gs
else:
remove_idx_r = remove_idx_fwd
remove_idx_c = remove_idx_fwd
if has_proj:
remove_idx_pc = remove_idx_proj_fwd
elif gs > 1:
remove_idx_r = remove_idx_gs
remove_idx_pc = remove_idx_proj
if n.startswith('weight_ih_l0'):
self.weight_mask[n][remove_idx_r, :] = 0
elif n.startswith('weight_ih'):
self.weight_mask[n][remove_idx_r, :] = 0
if remove_idx_proj is None:
self.weight_mask[n][:, remove_idx] = 0
else:
self.weight_mask[n][:, remove_idx_proj] = 0
self.masker().register_mask(n, self.weight_mask[n])
elif n.startswith('weight_hh'):
self.weight_mask[n][remove_idx_r, :] = 0
if remove_idx_pc is None:
self.weight_mask[n][:, remove_idx_c] = 0
else:
self.weight_mask[n][:, remove_idx_pc] = 0
self.masker().register_mask(n, self.weight_mask[n])
elif n.startswith('weight_hr'):
if remove_idx_pc is not None:
self.weight_mask[n][remove_idx_pc, :] = 0
self.weight_mask[n][:, remove_idx_c] = 0
self.masker().register_mask(n, self.weight_mask[n])
for n in self.bias_mask:
if self.bias_mask[n] is None:
continue
remove_idx_ = remove_idx
if num_directions > 1:
if n.endswith('_reverse'):
if gs > 1:
remove_idx_ = remove_idx_bwd_gs
else:
remove_idx_ = remove_idx_bwd
else:
if gs > 1:
remove_idx_ = remove_idx_fwd_gs
else:
remove_idx_ = remove_idx_fwd
elif gs > 1:
remove_idx_ = remove_idx_gs
self.bias_mask[n][remove_idx_] = 0
self.masker().register_mask(n, self.bias_mask[n])
self.masker().set_ot_remove_idx(remove_idx)
if remove_idx_proj is not None:
self.masker().set_custom_remove_idx(remove_idx_proj)
self.masker().register_mask('weight_ih_l0', self.weight_mask['weight_ih_l0'])
def modify_input(self, remove_idx):
rnn = self.node.module
assert len(self.node.prev_tensors) == 1, 'RNNs with hidden state inputs are not supported'
preserve_idx = complementary_list([i for i in range(self.weight_mask['weight_ih_l0'].shape[1])], remove_idx)
if rnn.weight_ih_l0.shape[1] != len(preserve_idx):
log.info(f'[RNN] {self.unique_name()}: input {rnn.input_size} -> {len(preserve_idx)}')
rnn.weight_ih_l0 = torch.nn.Parameter(rnn.weight_ih_l0[:, preserve_idx])
if rnn.bidirectional:
rnn.weight_ih_l0_reverse = torch.nn.Parameter(rnn.weight_ih_l0_reverse[:, preserve_idx])
rnn.input_size = len(preserve_idx)
def modify_output(self, remove_idx):
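# Physically rebuild the RNN parameters layer by layer: keep the preserved (gate-tiled) rows of
# weight_ih/weight_hh and their biases, the preserved hidden or projection columns, then shrink
# hidden_size (and proj_size when weight_hr exists) to match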
rnn = self.node.module
log.debug(f'[RNN] {self.unique_name()}: remove_idx = {remove_idx}')
num_directions = 2 if rnn.bidirectional else 1
has_proj = hasattr(self.module(), 'proj_size') and self.module().proj_size > 0
gs = rnn_gate_size(rnn)
if num_directions > 1:
offset = rnn.hidden_size
remove_idx_fwd, remove_idx_bwd = self.split_indices_with_directions(remove_idx, offset, num_directions)
if gs > 1:
offset = rnn.hidden_size
if num_directions > 1:
remove_idx_bwd_gs = self.tile_indices_with_gate_size(remove_idx_bwd, gs, offset)
remove_idx_fwd_gs = self.tile_indices_with_gate_size(remove_idx_fwd, gs, offset)
else:
remove_idx_gs = self.tile_indices_with_gate_size(remove_idx, gs, offset)
remove_idx_proj = None
if has_proj:
remove_idx_proj = self.masker().custom_remove_idx
if remove_idx_proj is not None:
offset = rnn.proj_size
remove_idx_proj_fwd, remove_idx_proj_bwd = self.split_indices_with_directions(
remove_idx_proj, offset, num_directions
)
for i in range(rnn.num_layers):
for j in range(num_directions):
suffix = '_reverse' if j > 0 else ''
desc = f'layer{suffix} hidden #{i}'
weight_ih = getattr(rnn, f'weight_ih_l{i}{suffix}')
weight_hh = getattr(rnn, f'weight_hh_l{i}{suffix}')
weight_hr = getattr(rnn, f'weight_hr_l{i}{suffix}', None)
bias_ih = getattr(rnn, f'bias_ih_l{i}{suffix}', None)
bias_hh = getattr(rnn, f'bias_hh_l{i}{suffix}', None)
remove_idx_r = remove_idx
remove_idx_c = remove_idx
remove_idx_pc = None
if num_directions > 1:
if j > 0:
if gs > 1:
remove_idx_r = remove_idx_bwd_gs
else:
remove_idx_r = remove_idx_bwd
remove_idx_c = remove_idx_bwd
if has_proj:
remove_idx_pc = remove_idx_proj_bwd
else:
if gs > 1:
remove_idx_r = remove_idx_fwd_gs
else:
remove_idx_r = remove_idx_fwd
remove_idx_c = remove_idx_fwd
if has_proj:
remove_idx_pc = remove_idx_proj_fwd
elif gs > 1:
remove_idx_r = remove_idx_gs
remove_idx_pc = remove_idx_proj
preserve_idx_ih_r = complementary_list(
[j for j in range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[0])], remove_idx_r
)
preserve_idx_hh_r = complementary_list(
[j for j in range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[0])], remove_idx_r
)
if weight_hr is None:
preserve_idx_hh_c = complementary_list(
[j for j in range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[1])], remove_idx_c
)
else:
preserve_idx_hh_c = complementary_list(
[j for j in range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[1])], remove_idx_pc
)
preserve_idx_hr_c = complementary_list(
[j for j in range(self.weight_mask[f'weight_hr_l{i}{suffix}'].shape[1])], remove_idx_c
)
preserve_idx_ih_c = None
if i != 0 and preserve_idx_ih_c is None:
if weight_hr is not None:
preserve_idx_ih_c = complementary_list(
[j for j in range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[1])], remove_idx_proj
)
else:
preserve_idx_ih_c = preserve_idx_ih_r
if num_directions > 1 or gs > 1:
preserve_idx_ih_c = complementary_list(
[j for j in range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[1])], remove_idx
)
if weight_ih.shape[0] != len(preserve_idx_ih_r):
if i != 0 and weight_ih.shape[1] != len(preserve_idx_ih_c):
desc_i = f'layer{suffix} input #{i}'
log.info(
f'[RNN] {self.unique_name()}: {desc_i} {weight_ih.shape[1]} -> {len(preserve_idx_ih_c)}'
)
log.info(f'[RNN] {self.unique_name()}: {desc} {rnn.hidden_size * gs} -> {len(preserve_idx_ih_r)}')
if i != 0:
new_w = weight_ih[preserve_idx_ih_r, :][:, preserve_idx_ih_c]
setattr(rnn, f'weight_ih_l{i}{suffix}', torch.nn.Parameter(new_w))
else:
setattr(rnn, f'weight_ih_l{i}{suffix}', torch.nn.Parameter(weight_ih[preserve_idx_ih_r, :]))
if bias_ih is not None:
setattr(rnn, f'bias_ih_l{i}{suffix}', torch.nn.Parameter(bias_ih[preserve_idx_ih_r]))
desc = f'layer{suffix} output #{i}'
if weight_hh.shape[0] != len(preserve_idx_hh_r) or weight_hh.shape[1] != len(preserve_idx_hh_c):
log.info(f'[RNN] {self.unique_name()}: {desc} {rnn.hidden_size * gs} -> {len(preserve_idx_hh_r)}')
setattr(
rnn,
f'weight_hh_l{i}{suffix}',
torch.nn.Parameter(weight_hh[preserve_idx_hh_r, :][:, preserve_idx_hh_c]),
)
if weight_hr is not None:
setattr(
rnn,
f'weight_hr_l{i}{suffix}',
torch.nn.Parameter(weight_hr[preserve_idx_hh_c, :][:, preserve_idx_hr_c]),
)
if bias_hh is not None:
setattr(rnn, f'bias_hh_l{i}{suffix}', torch.nn.Parameter(bias_hh[preserve_idx_hh_r]))
if weight_hr is None:
rnn.hidden_size = len(preserve_idx_hh_c)
else:
rnn.proj_size = len(preserve_idx_hh_c)
rnn.hidden_size = len(preserve_idx_hr_c)
def dim_choose(self, tensor_changes: typing.Dict[int, typing.List]) -> bool:
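# Restrict this RNN to pruning along self.dim_c on both its input and output tensors, propagate the choice
# to neighbouring modifiers via dim_choose_traversal, and roll back the recorded choices on conflict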
id_i = id(self.input_tensor)
id_o = id(self.output_tensor)
dim_changes_i = tensor_changes.get(id_i, None)
dim_choice_o = tensor_changes.get(id_o, None)
if dim_changes_i is None:
dim_changes_i = self.dim_changes_info.merge_i()
if len(dim_changes_i) > 1:
if self.dim_c not in dim_changes_i:
return False
dim_choose_i = self.dim_c
modifiers = [self]
self.init_dim_mapping()
tensor_changes[id_i] = [dim_choose_i]
if dim_choice_o is not None:
dim_choose_i_mapping = list(self.backward_dim_mapping[id_o][dim_choice_o[0]][id_i])
if dim_choose_i_mapping and dim_choose_i_mapping != tensor_changes[id_i]:
return False
tensor_changes[id_o] = [dim_choose_i]
valid = True
for m in self.pre_modifiers(self.pre_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.pre_tensors()[0])
if not valid:
break
if valid:
for m in self.next_modifiers(self.next_tensors()[0]):
valid = m.dim_choose_traversal(modifiers, tensor_changes, self.next_tensors()[0])
if not valid:
break
if valid:
return True
else:
if id_i in tensor_changes.keys():
del tensor_changes[id_i]
if id_o in tensor_changes.keys():
del tensor_changes[id_o]
return False
else:
return True
def change_dimension(self) -> bool:
dim_changes_o = [self.dim_c]
fill_tensor_by_dim_changes(self.output_tensor, dim_changes_o)
tensor_constraint = self.dim_changes_info.update_o(
self, self.next_tensors()[0], dim_changes_o, update_constraint=True
)
for m in self.next_modifiers():
m.dim_change_forward(self, self.next_tensors()[0], dim_changes_o, None, tensor_constraint)
return True
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
dim_changes_o = deepcopy(dim_changes_i)
if self.dim_c in dim_changes_o:
dim_changes_o.remove(self.dim_c)
if len(dim_changes_o) > 0:
log.error(f"[{self.unique_name()}] Modifying dimensions other than dim_c is temporarily not supported")
assert False
class ConvChannelModifier(Modifier):
def __init__(self, node: TraceNode):
super().__init__(node)
self.dim_n = 0
self.dim_c = 1
self.dim_h = 2
self.dim_w = 3
self.input_tensor = self.pre_tensors()[0]
self.output_tensor = self.next_tensors()[0]
self.prunable = True
self.group = self.module().groups
def calc_idx_group(self):
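# For grouped (non-depthwise) convolutions, split the input/output channel indices into `groups` consecutive
# chunks so that channels are only pruned together within their own group,
# e.g. groups=2 with 8 input channels -> groups_i = [{0, 1, 2, 3}, {4, 5, 6, 7}]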
if not is_dw_conv(self.module()):
if self.group > 1:
input_idx = [i for i in range(self.input_tensor.shape[self.dim_c])]
output_idx = [i for i in range(self.output_tensor.shape[self.dim_c])]
input_chunk = len(input_idx) // self.group
output_chunk = len(output_idx) // self.group
self.dim_changes_info.groups_i = [
set(input_idx[i : i + input_chunk]) for i in range(0, len(input_idx), input_chunk)
]
self.dim_changes_info.groups_o = [
set(output_idx[i : i + output_chunk]) for i in range(0, len(output_idx), output_chunk)
]
else:
self.dim_changes_info.groups_i = [set([i for i in range(self.input_tensor.shape[self.dim_c])])]
self.dim_changes_info.groups_o = [set([i for i in range(self.output_tensor.shape[self.dim_c])])]
def init_dim_mapping(self) -> bool:
if len(self.forward_dim_mapping) > 0:
return True
# Convolutional pruning is not allowed to modify dim_w, dim_h
self.forward_dim_mapping[id(self.input_tensor)][self.dim_n][id(self.output_tensor)] = {self.dim_n}
self.backward_dim_mapping[id(self.output_tensor)][self.dim_n][id(self.input_tensor)] = {self.dim_n}
self.forward_dim_mapping[id(self.input_tensor)][self.dim_c][id(self.output_tensor)] = set()
self.backward_dim_mapping[id(self.output_tensor)][self.dim_c][id(self.input_tensor)] = set()
return True
def register_mask(self, modifiers, importance, sparsity):
if is_dw_conv(self.module()):
remove_idx = self.dim_changes_info.pruned_idx_i
self.weight_mask["weight"][remove_idx, :] = 0
self.masker().set_in_remove_idx(remove_idx)
self.masker().set_ot_remove_idx(remove_idx)
bias_mask = self.bias_mask.get("bias", None)
if bias_mask is not None:
bias_mask[remove_idx] = 0
self.masker().register_mask("bias", bias_mask)
else:
if self.dim_changes_info.pruned_idx_i:
remove_idx = self.dim_changes_info.pruned_idx_i
group = self.group
remove_idx.sort()
if group != 1:
num_g_out = self.weight_mask["weight"].shape[0] // group
weight_2 = self.weight_mask["weight"].shape[1]
start_in = end_in = 0
for i in range(group):
end_in += weight_2
g_remove_idx = []
for idx in remove_idx:
if start_in <= idx < end_in:
g_remove_idx.append(idx)
g_remove_idx = [(idx - weight_2 * i) for idx in g_remove_idx]
self.weight_mask["weight"][num_g_out * i : num_g_out * (i + 1), g_remove_idx] = 0
start_in = end_in
else:
self.weight_mask["weight"][:, remove_idx] = 0
self.masker().set_in_remove_idx(remove_idx)
if self.dim_changes_info.pruned_idx_o:
remove_idx = self.dim_changes_info.pruned_idx_o
self.register_out_mask(remove_idx)
self.masker().register_mask("weight", self.weight_mask["weight"])
def register_out_mask(self, remove_idx):
self.weight_mask["weight"][remove_idx, :] = 0
self.masker().set_ot_remove_idx(remove_idx)
bias_mask = self.bias_mask.get("bias", None)
if bias_mask is not None:
bias_mask[remove_idx] = 0
self.masker().register_mask("bias", bias_mask)
def modify_input(self, remove_idx):
conv = self.node.module
log.debug(f'[CONV] {self.unique_name()}: remove_idx = {remove_idx}')
if is_dw_conv(self.module()):
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[0])], remove_idx)
if conv.groups != len(preserve_idx):
log.info(f'[DW_CONV] {self.unique_name()}: input {conv.in_channels} -> {len(preserve_idx)}')
conv.groups = len(preserve_idx)
conv.in_channels = len(preserve_idx)
conv.out_channels = len(preserve_idx)
conv.weight = torch.nn.Parameter(conv.weight[preserve_idx, :])
if conv.bias is not None:
log.info(f'[DW_CONV] {self.unique_name()}: bias {conv.bias.shape[0]} -> {len(preserve_idx)}')
conv.bias = torch.nn.Parameter(conv.bias[preserve_idx])
else:
group = self.group
if group != 1:
if conv.in_channels == (self.weight_mask["weight"].shape[1]) * group - len(remove_idx):
return
num_g_remove_idx = len(remove_idx) // group
num_g_out = self.weight_mask["weight"].shape[0] // group
weight_2 = self.weight_mask["weight"].shape[1]
conv_weight = None
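# Rebase the removed indices into per-group local indices before slicing each group's weight block,
# e.g. group=2, per-group in-channels weight_2=4, remove_idx=[2, 6] -> group 0 drops 2, group 1 drops 6 - 4 = 2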
for i in range(group):
g_remove_idx = remove_idx[num_g_remove_idx * i : num_g_remove_idx * (i + 1)]
g_remove_idx = [idx - weight_2 * i for idx in g_remove_idx]
preserve_idx = complementary_list(
[j for j in range(self.weight_mask["weight"].shape[1])], g_remove_idx
)
weight = conv.weight[num_g_out * i : num_g_out * (i + 1), preserve_idx]
if conv_weight is None:
conv_weight = weight
else:
conv_weight = torch.cat([conv_weight, weight], dim=0)
remain_channels = conv.in_channels - len(remove_idx)
log.info(f'[CONV-group] {self.unique_name()}: input {conv.in_channels} -> {remain_channels}')
conv.weight = torch.nn.Parameter(conv_weight)
conv.in_channels = remain_channels
else:
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[1])], remove_idx)
if conv.in_channels != len(preserve_idx):
log.info(f'[CONV] {self.unique_name()}: input {conv.in_channels} -> {len(preserve_idx)}')
conv.weight = torch.nn.Parameter(
conv.weight[
:,
preserve_idx,
]
)
conv.in_channels = len(preserve_idx)
def modify_output(self, remove_idx):
conv = self.node.module
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[0])], remove_idx)
log.debug(f'[CONV] {self.unique_name()}: remove_idx = {remove_idx}')
if is_dw_conv(self.module()):
if conv.groups != len(preserve_idx):
log.info(f'[DW_CONV] {self.unique_name()}: input {conv.in_channels} -> {len(preserve_idx)}')
conv.groups = len(preserve_idx)
conv.in_channels = len(preserve_idx)
conv.out_channels = len(preserve_idx)
conv.weight = torch.nn.Parameter(conv.weight[preserve_idx, :])
if conv.bias is not None:
log.info(f'[DW_CONV] {self.unique_name()}: bias {conv.bias.shape[0]} -> {len(preserve_idx)}')
conv.bias = torch.nn.Parameter(conv.bias[preserve_idx])
else:
if conv.out_channels != len(preserve_idx):
log.info(f'[CONV] {self.unique_name()}: output {conv.out_channels} -> {len(preserve_idx)}')
conv.weight = torch.nn.Parameter(conv.weight[preserve_idx, :])
conv.out_channels = len(preserve_idx)
if conv.bias is not None:
log.info(f'[CONV] {self.unique_name()}: bias {conv.bias.shape[0]} -> {len(preserve_idx)}')
conv.bias = torch.nn.Parameter(conv.bias[preserve_idx])
def change_dimension(self) -> bool:
if is_dw_conv(self.module()):
return True
dim_changes_o = [self.dim_c]
fill_tensor_by_dim_changes(self.output_tensor, dim_changes_o)
tensor_constraint = self.dim_changes_info.update_o(
self, self.next_tensors()[0], dim_changes_o, update_constraint=True
)
for m in self.next_modifiers():
m.dim_change_forward(self, self.next_tensors()[0], dim_changes_o, None, tensor_constraint)
return True
def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tensor_constraint):
# "Conv2d don't support change wh dimensions."
if self.dim_h in dim_changes_i:
dim_changes_i.remove(self.dim_h)
if self.dim_w in dim_changes_i:
dim_changes_i.remove(self.dim_w)
tensor_constraint = self.dim_changes_info.update_i(
center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
)
dw_conv = is_dw_conv(self.module())
if dw_conv:
dim_changes_o = dim_changes_i
elif self.dim_n in dim_changes_i:
dim_changes_o = [self.dim_n]
else:
dim_changes_o = None
if dim_changes_o:
self.dim_changes_info.update_o(center, self.next_tensors()[0], dim_changes_o)
fill_tensor_by_dim_changes(self.output_tensor, dim_changes_o)
constraint_i = self.dim_changes_info.constraints_i[dim_changes_o[0]][center.unique_name()][0]
transform = OrderedDict()
for i in range(len(constraint_i)):
transform[i] = constraint_i[i]
for m in self.next_modifiers():
if dw_conv:
m.dim_change_forward(center, self.next_tensors()[0], dim_changes_o, transform, tensor_constraint)
else:
m.dim_change_forward(center, self.next_tensors()[0], dim_changes_o, transform, None)
class TransConvChannelModifier(ConvChannelModifier):
def register_mask(self, modifiers, importance, sparsity):
if self.dim_changes_info.pruned_idx_i:
remove_idx = self.dim_changes_info.pruned_idx_i
self.weight_mask["weight"][
remove_idx,
:,
] = 0
self.masker().set_in_remove_idx(remove_idx)
if self.dim_changes_info.pruned_idx_o:
remove_idx = self.dim_changes_info.pruned_idx_o
self.weight_mask["weight"][
:,
remove_idx,
] = 0
self.masker().set_ot_remove_idx(remove_idx)
# As with a regular conv, the bias only changes when the output changes
bias_mask = self.bias_mask.get("bias", None)
if bias_mask is not None:
bias_mask[remove_idx] = 0
self.masker().register_mask("bias", bias_mask)
self.masker().register_mask("weight", self.weight_mask["weight"])
def modify_input(self, remove_idx):
conv = self.node.module
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[0])], remove_idx)
if conv.in_channels != len(preserve_idx):
log.info(f'[TRANS_CONV2D] {self.unique_name()}: input {conv.in_channels} -> {len(preserve_idx)}')
conv.weight = torch.nn.Parameter(conv.weight[preserve_idx, :])
conv.in_channels = len(preserve_idx)
def modify_output(self, remove_idx):
conv = self.node.module
preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[1])], remove_idx)
if conv.out_channels != len(preserve_idx):
log.info(f'[TRANS_CONV2D] {self.unique_name()}: output {conv.out_channels} -> {len(preserve_idx)}')
conv.weight = torch.nn.Parameter(conv.weight[:, preserve_idx])
conv.out_channels = len(preserve_idx)
if conv.bias is not None:
log.info(f'[TRANS_CONV2D] {self.unique_name()}: bias {conv.bias.shape[0]} -> {len(preserve_idx)}')
conv.bias = torch.nn.Parameter(conv.bias[preserve_idx])
class ConstantModifier(LinearChannelModifier):
def __init__(self, node: TraceNode):
Modifier.__init__(self, node)
self.output_tensor = self.next_tensors()[0]
self.input_tensor = self.output_tensor
# Pruning operation occurs along the second dimension
self.dim_c = 1
self.prunable = True
def change_dimension(self) -> bool:
dim_changes_o = [self.dim_c]
fill_tensor_by_dim_changes(self.output_tensor, dim_changes_o)
tensor_constraint = self.dim_changes_info.update_o(
self, self.next_tensors()[0], dim_changes_o, update_constraint=True
)
for m in self.next_modifiers():
m.dim_change_forward(self, self.next_tensors()[0], dim_changes_o, None, tensor_constraint)
return True
def register_mask(self, modifiers, importance, sparsity):
Modifier.reset_mask(self)
CHANNEL_MODIFIERS = {
nn.Conv1d: ConvChannelModifier,
nn.Conv2d: ConvChannelModifier,
nn.Linear: LinearChannelModifier,
nn.ConvTranspose2d: TransConvChannelModifier,
nn.ConvTranspose1d: TransConvChannelModifier,
nn.AvgPool1d: PoolingModifier,
nn.AvgPool2d: PoolingModifier,
nn.AdaptiveAvgPool2d: PoolingModifier,
nn.MaxPool2d: PoolingModifier,
'adaptive_avg_pool2d': PoolingModifier,
'max_pool2d': PoolingModifier,
'pad': PaddingModifier,
nn.Upsample: PoolingModifier,
nn.UpsamplingBilinear2d: PoolingModifier,
nn.UpsamplingNearest2d: PoolingModifier,
"interpolate": PoolingModifier,
nn.PReLU: PReLUChannelModifier,
nn.BatchNorm2d: NormChannelModifier,
nn.BatchNorm1d: NormChannelModifier,
nn.LayerNorm: NormChannelModifier,
'matmul': MatMulModifier,
'bmm': MatMulModifier,
'cat': CatModifier,
'view': ReshapeModifier,
"flatten": ReIndexModifier,
"squeeze": ReIndexModifier,
nn.Flatten: ReIndexModifier,
'reshape': ReshapeModifier,
'transpose': ReIndexModifier,
'permute': ReIndexModifier,
'split': SplitModifier,
'chunk': ReIndexModifier,
'mean': ReIndexModifier,
'sum': ReIndexModifier,
'getitem': ReIndexModifier,
nn.PixelShuffle: ReIndexModifier,
nn.RNN: RNNChannelModifier,
nn.GRU: RNNChannelModifier,
nn.LSTM: RNNChannelModifier,
'weight': ConstantModifier,
}
def create_channel_modifier(n):
for key in CHANNEL_MODIFIERS.keys():
if type(key) is str:
if n.kind() == key:
return CHANNEL_MODIFIERS[key](n)
elif isinstance(n.module, key):
return CHANNEL_MODIFIERS[key](n)
# Fall back to the base Modifier by default
return Modifier(n)
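# For illustration: a TraceNode whose module is an nn.Conv2d is wrapped in a ConvChannelModifier, whereas
# functional ops are matched by node kind, so a node with n.kind() == 'cat' is wrapped in a CatModifier;
# anything unmatched falls back to the base Modifier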
class SubGraph(object):
modifiers: typing.List[Modifier]
leaf: typing.List[Modifier]
def __init__(self, center: Modifier, modifiers=None):
self.center = center
self.modifiers = modifiers if modifiers is not None else []
self.modifiers_dict = {m.unique_name(): m for m in self.modifiers}
self.leaf = list()
self.dependent_centers = set()
self.center_constraint = OrderedDict()
self.center_group = OrderedDict()
self.leaf_group = OrderedDict()
self.skip = False
def add_modifier(self, modifier: Modifier):
self.modifiers.append(modifier)
self.modifiers_dict[modifier.unique_name()] = modifier
def constraint_mapping(self, constraint, mapping):
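# Map a set of indices through the given index mapping and union the results,
# e.g. constraint={0, 1}, mapping={0: {2}, 1: {3, 4}} -> {2, 3, 4}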
result = set()
for i in constraint:
result.update(mapping[i])
return result
def calc_prune_idx_by_bn_variance(
self,
center_list,
center_to_leaf_all,
leaf_to_center_all,
importance,
center_to_center_all,
sparsity,
multiple,
):
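# Derive pruning indices from BatchNorm statistics: a BN leaf channel whose running_var is below 1e-8 is
# treated as dead, its index is mapped back to the centers (and their dependent centers), an index is kept
# only when every observing leaf marks it dead, the result is optionally trimmed to a multiple of `multiple`,
# and the pruned center indices are finally mapped forward to each leaf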
pruned_leaf_constraint_all = {}
pruned_center_constraint_all = {}
invalid_bn_idxes = {}
invalid_center_idxes = {}
ignored_bn = set()
for leaf in self.leaf:
if type(leaf.module()) is not nn.BatchNorm2d:
continue
while True:
if len(leaf.pre_modifiers()) == 1:
leaf = leaf.pre_modifiers()[0]
else:
break
if leaf in self.leaf:
if type(leaf.module()) is not nn.BatchNorm2d:
continue
ignored_bn.add(leaf)
for leaf in self.leaf:
if leaf in ignored_bn:
continue
if type(leaf.module()) is not nn.BatchNorm2d:
continue
is_real_leaf = True
for leaf_center_name in leaf.dim_changes_info.centers.keys():
# All centers of a valid leaf must be in this subgraph
if leaf_center_name not in self.modifiers_dict.keys():
is_real_leaf = False
if not is_real_leaf:
continue
center_to_leaf = center_to_leaf_all[leaf.unique_name()]
leaf_to_center = leaf_to_center_all[leaf.unique_name()]
invalid_bn = (leaf.module().running_var < 1e-8).tolist()
invalid_bn_dict = {i: bool(v) for i, v in enumerate(invalid_bn)}
invalid_bn_idxes[leaf.unique_name()] = invalid_bn_dict
for idx, state in invalid_bn_dict.items():
for center_name, idx_mapping in leaf_to_center.items():
center_idxes = list(idx_mapping[idx])
if set(center_idxes) == {-1}:
continue
for center_idx in center_idxes:
if center_name not in invalid_center_idxes:
invalid_center_idxes[center_name] = {leaf.unique_name(): {}}
if leaf.unique_name() not in invalid_center_idxes[center_name]:
invalid_center_idxes[center_name][leaf.unique_name()] = {}
invalid_center_idxes[center_name][leaf.unique_name()][center_idx] = state
center_to_center = center_to_center_all[center_name]
for depend_center_name in center_to_center[center_idx].keys():
if depend_center_name not in invalid_center_idxes:
invalid_center_idxes[depend_center_name] = {leaf.unique_name(): {}}
if leaf.unique_name() not in invalid_center_idxes[depend_center_name]:
invalid_center_idxes[depend_center_name][leaf.unique_name()] = {}
depend_center_idxes = list(center_to_center[center_idx][depend_center_name])
for depend_center_idx in depend_center_idxes:
invalid_center_idxes[depend_center_name][leaf.unique_name()][depend_center_idx] = state
for center_name, leaf_info in invalid_center_idxes.items():
invalid_center_idx_dict = {}
for leaf_name, invalid_center_idx in leaf_info.items():
for idx, state in invalid_center_idx.items():
if idx not in invalid_center_idx_dict:
invalid_center_idx_dict[idx] = state
else:
invalid_center_idx_dict[idx] &= state
invalid_center_idx_set = set()
for idx, state in invalid_center_idx_dict.items():
if state:
invalid_center_idx_set.add(idx)
if multiple is not None:
invalid_center_idx_lst = list(invalid_center_idx_set)
remainder = len(invalid_center_idx_lst) % multiple
# Trim so the number of pruned indices is a multiple of `multiple`; when remainder is 0 the full list is kept
invalid_center_idx_lst = invalid_center_idx_lst[: len(invalid_center_idx_lst) - remainder]
invalid_center_idx_set = set(invalid_center_idx_lst)
pruned_center_constraint_all[center_name] = invalid_center_idx_set
for leaf in self.leaf:
center_to_leaf = center_to_leaf_all[leaf.unique_name()]
leaf_to_center = leaf_to_center_all[leaf.unique_name()]
if center_name not in center_to_leaf.keys():
continue
is_real_leaf = True
for leaf_center_name in leaf.dim_changes_info.centers.keys():
# All centers of a valid leaf must be in this subgraph
if leaf_center_name not in self.modifiers_dict.keys():
is_real_leaf = False
if not is_real_leaf:
continue
for idx in invalid_center_idx_set:
leaf_idx = self.constraint_mapping([idx], center_to_leaf[center_name])
if leaf_idx != {-1}:
if leaf.unique_name() not in pruned_leaf_constraint_all:
pruned_leaf_constraint_all[leaf.unique_name()] = []
pruned_leaf_constraint_all[leaf.unique_name()].append(leaf_idx)
return pruned_center_constraint_all, pruned_leaf_constraint_all
def calc_prune_idx_by_center_importance(
self, center_list, center_to_leaf_all, leaf_to_center_all, importance, center_to_center_all, sparsity
):
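# Importance-driven pruning: for each center, map its unpruned constraint groups onto every leaf, score each
# leaf group by the summed importance of the center indices it touches (including dependent centers), then
# greedily drop the lowest-scoring groups until the center's target sparsity is reached, and sync the pruned
# indices back to all related centers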
leaf_delta_idx = {}
pruned_leaf_constraint_all = {}
pruned_center_constraint_all = {}
calculated_center_constraint_all = {}
for center in center_list:
calculated_constraint = calculated_center_constraint_all.get(center.unique_name(), set())
constraint_need_prune = []
for i in self.center_constraint[center.unique_name()]:
if not i.issubset(calculated_constraint):
constraint_need_prune.append(i)
for leaf in self.leaf:
center_to_leaf = center_to_leaf_all[leaf.unique_name()]
leaf_to_center = leaf_to_center_all[leaf.unique_name()]
if center.unique_name() not in center_to_leaf.keys():
continue
is_real_leaf = True
for leaf_center_name in leaf.dim_changes_info.centers.keys():
# All centers of a valid leaf must be in this subgraph
if leaf_center_name not in self.modifiers_dict.keys():
is_real_leaf = False
if not is_real_leaf:
continue
log.debug(f"calc leaf prune idx: {center.unique_name()}, {leaf.unique_name()}")
leaf_constraint_need_prune = []
for constraint in constraint_need_prune:
constraint_set = set()
for center_idxes in constraint:
if center_idxes in center_to_leaf[center.unique_name()]:
constraint_set.update(center_to_leaf[center.unique_name()][center_idxes])
leaf_constraint_need_prune.append(constraint_set)
leaf_constraint_all = []
calculated_leaf_constraint_all = []
pruned_leaf_constraint = []
for center_name, constraints in self.center_constraint.items():
if center_name not in center_to_leaf.keys():
continue
calculated_constraint = calculated_center_constraint_all.get(center_name, set())
pruned_constraint = pruned_center_constraint_all.get(center_name, set())
for constraint in constraints:
leaf_idx_constraint = set()
for center_idxes in constraint:
if center_idxes in center_to_leaf[center_name]:
leaf_idx_constraint.update(center_to_leaf[center_name][center_idxes])
if constraint.issubset(calculated_constraint):
calculated_leaf_constraint_all.append(leaf_idx_constraint)
if len(leaf_idx_constraint) > 0 and constraint.issubset(pruned_constraint):
# When a center has been pruned, the corresponding leaf directly
# reuses the pruning result of the center
pruned_leaf_constraint.append(leaf_idx_constraint)
else:
if len(leaf_idx_constraint) > 0:
leaf_constraint_all.append(leaf_idx_constraint)
# All center idx corresponding to leaf have been calculated
if len(leaf_constraint_all) == 0:
pruned_leaf_constraint_all[leaf.unique_name()] = pruned_leaf_constraint
continue
merge_constraint(leaf_constraint_all)
merge_constraint(calculated_leaf_constraint_all)
constraint = []
for i in leaf_constraint_all:
if i in leaf_constraint_need_prune and i not in calculated_leaf_constraint_all:
constraint.append(i)
leaf_constraint_all = constraint
leaf_importance = []
for constraint in leaf_constraint_all:
importance_ = 0
for leaf_idxes in constraint:
for center_name, idx_mapping in leaf_to_center.items():
center_idxes = list(idx_mapping[leaf_idxes])
if set(center_idxes) == {-1}:
continue
if max(center_idxes) >= len(importance[center_name]):
assert False
importance_ += float(sum(importance[center_name][center_idxes]))
center_to_center = center_to_center_all[center_name]
for center_idx in center_idxes:
for depend_center_name in center_to_center[center_idx].keys():
depend_center_idxes = list(center_to_center[center_idx][depend_center_name])
importance_ += float(sum(importance[depend_center_name][depend_center_idxes]))
leaf_importance.append((constraint, importance_))
for group in self.leaf_group[leaf.unique_name()]:
valid_importance = []
for i in leaf_importance:
constraint = i[0]
if len(constraint & group) > 0:
assert constraint.issubset(group)
valid_importance.append(i)
if len(valid_importance) == 0:
continue
valid_importance = sorted(valid_importance, key=lambda x: x[1])
current_sparsity = 0
total_idx = sum([len(i[0]) for i in valid_importance])
target_idx = total_idx * sparsity[center.unique_name()]
pruned_leaf_idx = set()
while current_sparsity < sparsity[center.unique_name()]:
if len(valid_importance) == 0:
break
unimportance_idx = valid_importance.pop(0)
constraint = unimportance_idx[0]
if center.unique_name() in pruned_center_constraint_all:
center_constraint_len = len(self.center_constraint[center.unique_name()])
center_pruned_constraint_len = len(pruned_center_constraint_all[center.unique_name()])
center_constraint = self.constraint_mapping(
constraint, leaf_to_center[center.unique_name()]
)
global_center_sparsity = (
center_pruned_constraint_len + len(pruned_leaf_idx) + len(center_constraint)
) / center_constraint_len
if global_center_sparsity > sparsity[center.unique_name()]:
break
current_sparsity = (
len(pruned_leaf_idx) + len(constraint) + leaf_delta_idx.get(leaf.unique_name(), 0)
) / total_idx
pruned_leaf_idx.update(constraint)
pruned_leaf_constraint.append(constraint)
calculated_leaf_constraint_all.append(constraint)
delta_idx = len(pruned_leaf_idx) - target_idx
leaf_delta_idx[leaf.unique_name()] = leaf_delta_idx.get(leaf.unique_name(), 0)
leaf_delta_idx[leaf.unique_name()] = leaf_delta_idx[leaf.unique_name()] + delta_idx
pruned_leaf_constraint_all[leaf.unique_name()] = pruned_leaf_constraint
for center_name in leaf_to_center.keys():
calculated_center_constraint_all[center_name] = calculated_center_constraint_all.get(
center_name, set()
)
pruned_center_constraint_all[center_name] = pruned_center_constraint_all.get(center_name, set())
# Sync leaf's pruning idx to all centers
for constraint in calculated_leaf_constraint_all + leaf_constraint_all:
center_constraint = self.constraint_mapping(constraint, leaf_to_center[center_name])
if center_constraint != {-1}:
calculated_center_constraint_all[center_name].update(center_constraint)
for constraint in pruned_leaf_constraint:
center_constraint = self.constraint_mapping(constraint, leaf_to_center[center_name])
if center_constraint != {-1}:
pruned_center_constraint_all[center_name].update(center_constraint)
return pruned_center_constraint_all, pruned_leaf_constraint_all
def calc_prune_idx(self, importance, sparsity, multiple=None):
"""
Calculate the dependence of index in the process of pruning. For convolutional pruning, it is the dependence
between channels. For more complex unstructured/semi-structured pruning, it may have a finer granularity.
"""
if self.center not in sparsity or sparsity[self.center] == 0.0:
return
center_constraint = {}
leaf_prune_dim = {}
leaf_constraint = {}
leaf_constraint_len = {}
for m in self.modifiers:
log.debug(f"modifier {m.unique_name()} prune dim = {dict(m.dim_changes_info.dim_choices)}")
for leaf in self.leaf:
# After all subgraph dependencies are resolved, each operator will only be pruned in a single dimension
if len(leaf.dim_changes_info.constraints_i) != 1:
log.warning(f"[{leaf.unique_name()}] Pruning in two dimensions at the same time is not supported")
return
leaf_prune_dim[leaf.unique_name()] = list(leaf.dim_changes_info.constraints_i.keys())[0]
leaf_constraint[leaf.unique_name()] = list(leaf.dim_changes_info.constraints_i.values())[0]
for center_name, constraints in leaf_constraint[leaf.unique_name()].items():
if center_name not in self.modifiers_dict.keys():
continue
leaf_constraint_len[leaf.unique_name()] = len(constraints[0])
for constraint in constraints:
center_constraint[center_name] = center_constraint.get(center_name, [])
center_constraint[center_name] += constraint
for center_name, constraints in center_constraint.items():
merge_constraint(constraints)
self.center_constraint = center_constraint
for center in self.dependent_centers:
if len(center.dim_changes_info.groups_o) > 0:
self.center_group[center.unique_name()] = center.dim_changes_info.groups_o
# Build constraint mapping between center and leaf
center_to_leaf_all = {}
leaf_to_center_all = {}
for leaf in self.leaf:
center_to_leaf = {}
leaf_to_center = {}
center_to_leaf_all[leaf.unique_name()] = center_to_leaf
leaf_to_center_all[leaf.unique_name()] = leaf_to_center
# center_constraint loses the constraint mapping information between center
# and leaf, so use the original leaf_constraint
for center_name, constraints in leaf_constraint[leaf.unique_name()].items():
if center_name not in self.modifiers_dict:
continue
if center_name not in center_to_leaf:
center_to_leaf[center_name] = {}
leaf_to_center[center_name] = {}
for constraint in constraints:
for leaf_idxes in range(len(constraint)):
leaf_to_center[center_name][leaf_idxes] = leaf_to_center[center_name].get(leaf_idxes, set())
leaf_to_center[center_name][leaf_idxes].update(constraint[leaf_idxes])
for center_idxes in constraint[leaf_idxes]:
center_to_leaf[center_name][center_idxes] = center_to_leaf[center_name].get(
center_idxes, set()
)
center_to_leaf[center_name][center_idxes].add(leaf_idxes)
if -1.0 in center_to_leaf[center_name]:
del center_to_leaf[center_name][-1.0]
# Aggregate all leaf constraints into a global center constraint
for leaf in self.leaf:
center_to_leaf = center_to_leaf_all[leaf.unique_name()]
leaf_to_center = leaf_to_center_all[leaf.unique_name()]
# Obtain the constraint of leaf through the constraint of center
leaf_constraint_all = []
for center_name, constraints in self.center_constraint.items():
if center_name not in center_to_leaf.keys():
continue
for constraint in constraints:
leaf_idx_constraint = set()
for center_idxes in constraint:
if center_idxes in center_to_leaf[center_name]:
leaf_idx_constraint.update(center_to_leaf[center_name][center_idxes])
if leaf_idx_constraint not in leaf_constraint_all:
leaf_constraint_all.append(leaf_idx_constraint)
merge_constraint(leaf_constraint_all)
# Pass the leaf constraint back to the center, so that the center nodes
# can get dependencies between each other
leaf_center_constraints = {}
for center_name in leaf_to_center.keys():
if center_name not in self.modifiers_dict:
continue
leaf_center_constraints[center_name] = []
for leaf_idx_constraint in leaf_constraint_all:
index_constraint = set()
for leaf_idxes in leaf_idx_constraint:
index_constraint.update(leaf_to_center[center_name][leaf_idxes])
if -1.0 in index_constraint:
index_constraint.remove(-1.0)
if index_constraint not in leaf_center_constraints[center_name]:
leaf_center_constraints[center_name].append(index_constraint)
for center_name in leaf_center_constraints.keys():
self.center_constraint[center_name] += leaf_center_constraints[center_name]
merge_constraint(self.center_constraint[center_name])
log.debug(f"leaf {leaf.unique_name()} constraint merge over")
# Aggregate all leaf group into a global center group
for leaf in self.leaf:
center_to_leaf = center_to_leaf_all[leaf.unique_name()]
leaf_to_center = leaf_to_center_all[leaf.unique_name()]
leaf_group_all = []
# TODO: Is it possible to skip when center_group has only one element?
for center_name, center_group in self.center_group.items():
if center_name not in center_to_leaf.keys():
continue
for group in center_group:
leaf_idx_group = set()
for center_idxes in group:
# Nodes such as split may cause the number of idx in leaf and center to be inconsistent
if center_idxes in center_to_leaf[center_name]:
leaf_idx_group.update(center_to_leaf[center_name][center_idxes])
leaf_group_all.append(leaf_idx_group)
if len(leaf.dim_changes_info.groups_i) > 0:
leaf_group_all += leaf.dim_changes_info.groups_i
merge_group(leaf_group_all)
leaf_center_groups = {}
for center_name in leaf_to_center.keys():
leaf_center_groups[center_name] = []
for leaf_idx_group in leaf_group_all:
index_group = set()
# Nodes such as split may cause the number of idx in leaf and center to be inconsistent
for leaf_idxes in leaf_idx_group:
if leaf_idxes in leaf_to_center[center_name]:
index_group.update(leaf_to_center[center_name][leaf_idxes])
if -1.0 in index_group:
index_group.remove(-1.0)
if len(index_group) > 0:
leaf_center_groups[center_name].append(index_group)
for center_name in leaf_center_groups.keys():
self.center_group[center_name] = self.center_group.get(center_name, [])
self.center_group[center_name] += leaf_center_groups[center_name]
merge_group(self.center_group[center_name])
for leaf in self.leaf:
center_to_leaf = center_to_leaf_all[leaf.unique_name()]
leaf_group = []
self.leaf_group[leaf.unique_name()] = leaf_group
for center_name, center_group in self.center_group.items():
if center_name not in center_to_leaf.keys():
continue
for group in center_group:
leaf_idx_group = set()
for center_idxes in group:
# Nodes such as split may cause the number of idx in leaf and center to be inconsistent, so check that the idx exists
if center_idxes in center_to_leaf[center_name]:
leaf_idx_group.update(center_to_leaf[center_name][center_idxes])
leaf_group.append(leaf_idx_group)
if len(leaf.dim_changes_info.groups_i) > 0:
leaf_group += leaf.dim_changes_info.groups_i
merge_group(leaf_group)
log.debug(f"subgraph {self.center} group merge over")
# 1) Select a center
# 2) Map the center to all leaves, and then complete leaf pruning
# 3) Update the global center_pruned_constraint after leaf pruning
# 4) Prune the next center, and then exclude the pruned idx in center_pruned_constraint
# 5) Repeat the above steps until all centers are pruned
center_list = []
for center_name, constraint in self.center_constraint.items():
constraint_all = set()
for i in constraint:
constraint_all.update(i)
center_list.append((len(constraint_all), self.modifiers_dict[center_name]))
# Prioritize the center with the shortest constraint. If the center with the longest constraint
# is processed first, the short one may have an incorrect sparsity rate
center_list = sorted(center_list, key=lambda x: x[0])
center_list = [i[1] for i in center_list]
center_to_center_all = {}
for center in center_list:
center_name = center.unique_name()
center_to_center = {}
center_to_center_all[center.unique_name()] = center_to_center
for constraint in self.center_constraint[center.unique_name()]:
for center_idxes in constraint:
if center_idxes not in center_to_center:
center_to_center[center_idxes] = {}
for leaf in self.leaf:
leaf_name = leaf.unique_name()
center_to_leaf = center_to_leaf_all[leaf_name]
leaf_to_center = leaf_to_center_all[leaf_name]
if center.unique_name() not in center_to_leaf.keys():
continue
if center_idxes not in center_to_leaf[center_name]:
continue
leaf_idxes = center_to_leaf[center_name][center_idxes]
for leaf_idx in leaf_idxes:
for depend_center_name in leaf_to_center.keys():
if depend_center_name == center_name:
continue
depend_center_idxes = leaf_to_center[depend_center_name][leaf_idx]
if depend_center_idxes == {-1}:
continue
if depend_center_name not in center_to_center[center_idxes]:
center_to_center[center_idxes][depend_center_name] = set()
center_to_center[center_idxes][depend_center_name].update(depend_center_idxes)
if importance is not None:
pruned_center_constraint_all, pruned_leaf_constraint_all = self.calc_prune_idx_by_center_importance(
center_list, center_to_leaf_all, leaf_to_center_all, importance, center_to_center_all, sparsity
)
else:
pruned_center_constraint_all, pruned_leaf_constraint_all = self.calc_prune_idx_by_bn_variance(
center_list,
center_to_leaf_all,
leaf_to_center_all,
importance,
center_to_center_all,
sparsity,
multiple,
)
for center_name, constraint in pruned_center_constraint_all.items():
calculated_constraint = constraint
if -1 in calculated_constraint:
calculated_constraint.remove(-1)
calculated_constraint = list(calculated_constraint)
calculated_constraint.sort()
self.modifiers_dict[center_name].dim_changes_info.pruned_idx_o = calculated_constraint
for leaf_name, constraint in pruned_leaf_constraint_all.items():
calculated_constraint = set()
for i in constraint:
calculated_constraint.update(i)
if -1 in calculated_constraint:
calculated_constraint.remove(-1)
calculated_constraint = list(calculated_constraint)
calculated_constraint.sort()
self.modifiers_dict[leaf_name].dim_changes_info.pruned_idx_i = calculated_constraint
log.debug(f"subgraph {self.center} prune idx compute over")
def eliminate_conflict(self):
tensor_choices = OrderedDict()
# Make sure each tensor only prunes at most one dimension
for m in reversed(self.modifiers):
if m.dim_changes_info.is_multi_dim_changed():
log.debug(f"[{m.unique_name()}] multi dim changed")
if not m.dim_choose(tensor_choices):
log.error(f"[{m.unique_name()}] conflict can't be eliminated")
raise Exception("conflict can't be eliminated")
for m in self.modifiers:
for t in m.pre_tensors() + m.next_tensors():
if id(t) in tensor_choices:
m.dim_changes_info.update_choice(t, tensor_choices[id(t)])
elif id(t) in m.dim_changes_info.tensor_changes:
m.dim_changes_info.update_choice(t, m.dim_changes_info.merge_t(t))
else:
# The tensor has not changed
pass
for m in self.modifiers:
m.dim_changes_info.rebuild()
m.calc_idx_group()
def calc_importance(self):
pass
def build(self):
self.modifiers = list(set(self.modifiers))
self.modifiers = sorted(self.modifiers, key=lambda m: m.node.forward_order)
leaf = set()
for m in self.modifiers:
is_leaf = True
for next_m in m.next_modifiers():
if next_m in self.modifiers:
is_leaf = False
break
# In some ring subgraphs, an operator may be both center and leaf at the same time
for center in m.dim_changes_info.centers.values():
if m.prunable and center in self.modifiers and center is not m:
is_leaf = True
if not is_leaf:
continue
leaf.add(m)
for center in m.dim_changes_info.centers.values():
if center is not m:
self.dependent_centers.add(center)
self.leaf = list(leaf)
self.leaf = sorted(self.leaf, key=lambda m: m.node.forward_order)
return self
def __eq__(self, other):
if type(other) in [list, tuple]:
if self.modifiers == other:
return True
return False
class SubGraphDivider(object):
def __init__(self, graph: TraceGraph, modifiers: typing.Dict[str, Modifier]):
self.graph = graph
self.modifiers = OrderedDict(sorted(modifiers.items(), key=lambda x: x[1].node.forward_order))
self.tensors = graph.all_tensors()
self.sub_graphs = OrderedDict()
def reset_tensors(self):
for t in self.tensors:
# Do not modify the constants, otherwise the branches of operators such as shape need to be recalculated
if len(t.shape) >= 1:
# Cannot assign a value of 0, because 0 cannot distinguish between index 0 and an uninitialized index
t.data.fill_(-1)
def change_dimension(self):
self.reset_tensors()
dim_changed = False
for m in self.modifiers.values():
if dim_changed:
# Use the tensor generated in the trace process to save memory
self.reset_tensors()
dim_changed = m.change_dimension()
if dim_changed:
log.debug(f"operator [{m.unique_name()}] tracking over")
def divide_subgraph(self):
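# Group modifiers by the center that drives their dimension changes; when one modifier depends on several
# centers, merge those centers' groups into a single subgraph (center_mapping tracks merged names), then
# build a SubGraph per surviving center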
self.sub_graphs.clear()
sub_graphs = {}
for modifier in self.modifiers.values():
for center_name, center in modifier.dim_changes_info.centers.items():
if center_name not in sub_graphs:
sub_graphs[center_name] = set()
sub_graphs[center_name].add(modifier)
center_mapping = {}
for modifier in self.modifiers.values():
center_names = modifier.dim_changes_info.get_input_centers()
# TODO: It is more reasonable to arrange according to the forward order
center_names = sorted(list(set(center_names)))
if len(center_names) <= 1:
continue
main_center = center_names[0]
while main_center not in sub_graphs and main_center in center_mapping:
main_center = center_mapping[main_center]
for redundant_center in center_names[1:]:
if redundant_center == main_center:
continue
if redundant_center in sub_graphs:
sub_graphs[main_center].update(sub_graphs[redundant_center])
del sub_graphs[redundant_center]
center_mapping[redundant_center] = main_center
else:
redirect_center = redundant_center
while redirect_center not in sub_graphs:
redirect_center = center_mapping[redirect_center]
if redirect_center in sub_graphs and redirect_center != main_center:
sub_graphs[main_center].update(sub_graphs[redirect_center])
del sub_graphs[redirect_center]
center_mapping[redirect_center] = main_center
for center_name, modifiers in sub_graphs.items():
self.sub_graphs[center_name] = SubGraph(center_name, modifiers).build()
def divide(self) -> typing.Dict[str, SubGraph]:
log.info("Start tracking tensor dimension changes...")
self.change_dimension()
log.info("Start dividing subgraphs according to tensor dependencies...")
self.divide_subgraph()
log.info("Start to eliminate dimension change conflicts...")
for sub_graph in self.sub_graphs.values():
sub_graph.eliminate_conflict()
log.info("Start generating new subgraphs without conflicts...")
self.divide_subgraph()
return self.sub_graphs
class GraphChannelModifier(object):
graph: TraceGraph
center_nodes: typing.List[TraceNode]
sub_graphs: typing.Dict[str, SubGraph]
def __init__(self, graph: TraceGraph, center_nodes, bn_compensation=False):
"""Initialize a channel modifier for a calculation graph
Args:
graph: Compute graph generated by tracer
center_nodes: Operators that actively modify the channel
"""
self.graph = graph
self.center_nodes = center_nodes
self.bn_compensation = bn_compensation
if self.bn_compensation:
log.info("open bn compensation")
self.modifiers = self.register_modifier()
with torch.no_grad():
self.sub_graphs = SubGraphDivider(self.graph, self.modifiers).divide()
self.reset_masker()
def reset_masker(self):
self.unregister_masker()
for n in self.graph.forward_nodes:
masker.ChannelMasker(n.module, n.unique_name)
def unregister_masker(self):
mask_applied = False
for sub_graph in self.sub_graphs.values():
for m in sub_graph.modifiers:
m.reset_mask()
mask_applied = m.mask_applied or mask_applied
for n in self.graph.forward_nodes:
if hasattr(n.module, "masker"):
n.module.masker.unregister_all()
delattr(n.module, "masker")
if mask_applied:
self.graph.inited = False
def register_modifier(self) -> typing.Dict[str, Modifier]:
modifiers = OrderedDict()
for n in self.graph.all_nodes():
modifier = create_channel_modifier(n)
modifier.graph_modifier = self
modifiers[n.unique_name] = modifier
setattr(n, "modifier", modifier)
return modifiers
def unregister_modifier(self):
for n in self.graph.all_nodes():
delattr(n, "modifier")
self.unregister_masker()
def get_modifier(self, module=None, unique_name=None) -> Modifier:
if module is not None:
unique_name = self.graph.module_original_name_dict[id(module)]
return self.graph.nodes_map[unique_name].modifier
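# A minimal usage sketch (names such as `trace`, the dummy input and the importance/sparsity dicts are
# assumptions for illustration, not defined in this file):
#   graph = trace(model, torch.randn(1, 3, 224, 224))  # tracer entry point from tinynn.graph.tracer (assumed)
#   center_nodes = [n for n in graph.forward_nodes if isinstance(n.module, (nn.Conv2d, nn.Linear))]
#   graph_modifier = GraphChannelModifier(graph, center_nodes)
#   for sub_graph in graph_modifier.sub_graphs.values():
#       sub_graph.calc_prune_idx(importance, sparsity)  # both dicts keyed by center unique names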
def fill_tensor_by_dim_changes(tensor, dim_changes, values=None):
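# Fill the tensor so that each element records its index (or the supplied per-index value) along the changed
# dimensions; with several dims the per-dim indices are summed,
# e.g. a (2, 3) tensor with dim_changes=[1] becomes [[0, 1, 2], [0, 1, 2]]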
tensor[:] = 0
for dim in reversed(dim_changes):
shape = tensor.shape[dim]
if values is not None:
assert len(values) == shape
for i in range(shape):
value_shape = list(tensor.shape)
value_shape[dim] = 1
if values is not None:
value = torch.ones(value_shape, dtype=tensor.dtype) * values[i]
else:
value = torch.ones(value_shape, dtype=tensor.dtype) * i
tensor.index_add_(dim, torch.tensor(i), value)
return tensor