in tinynn/graph/modifier.py [0:0]
def calc_prune_idx(self, importance, sparsity, multiple=None):
    """
    Calculate the dependence of index in the process of pruning. For convolutional pruning, it is the dependence
    between channels. For more complex unstructured/semi-structured pruning, it may have a finer granularity.

    The method returns early, without updating any state, when ``self.center`` has no entry in ``sparsity`` or its
    sparsity is ``0.0``, and also when any leaf does not have exactly one entry in
    ``dim_changes_info.constraints_i``. Otherwise it rebuilds ``self.center_constraint``, updates
    ``self.center_group`` and ``self.leaf_group``, and finally stores the sorted pruned indexes in
    ``dim_changes_info.pruned_idx_o`` of the centers and ``dim_changes_info.pruned_idx_i`` of the leaves.

    Args:
        importance: Forwarded to the index-selection helper. When it is not ``None``,
            ``calc_prune_idx_by_center_importance`` is used, otherwise ``calc_prune_idx_by_bn_variance``.
        sparsity: Mapping that is looked up with ``self.center``; also forwarded to the selection helper.
        multiple: Forwarded only to ``calc_prune_idx_by_bn_variance``.

    Returns:
        None. All results are written to attributes of ``self`` and of the modifiers in ``self.modifiers_dict``.
    """
    if self.center not in sparsity or sparsity[self.center] == 0.0:
        return

    for m in self.modifiers:
        log.debug(f"modifier {m.unique_name()} prune dim = {dict(m.dim_changes_info.dim_choices)}")

    center_constraint = {}
    leaf_constraint = {}

    for leaf in self.leaf:
        constraints_i = leaf.dim_changes_info.constraints_i
        # After all subgraph dependencies are resolved, each operator will only be pruned in a single dimension
        if len(constraints_i) != 1:
            log.warning(f"[{leaf.unique_name()}] Pruning in two dimensions at the same time is not supported")
            return

        leaf_name = leaf.unique_name()
        # Exactly one entry exists here, so take its value (the per-center constraints of this leaf)
        leaf_constraint[leaf_name] = next(iter(constraints_i.values()))

        for center_name, constraints in leaf_constraint[leaf_name].items():
            if center_name not in self.modifiers_dict:
                continue
            for constraint in constraints:
                center_constraint.setdefault(center_name, []).extend(constraint)

    # merge_constraint is called for its in-place effect on each list (its return value is not used)
    for constraints in center_constraint.values():
        merge_constraint(constraints)

    self.center_constraint = center_constraint

    for center in self.dependent_centers:
        if len(center.dim_changes_info.groups_o) > 0:
            self.center_group[center.unique_name()] = center.dim_changes_info.groups_o

    # Build constraint mapping between center and leaf
    center_to_leaf_all = {}
    leaf_to_center_all = {}
    for leaf in self.leaf:
        leaf_name = leaf.unique_name()
        center_to_leaf = {}
        leaf_to_center = {}
        center_to_leaf_all[leaf_name] = center_to_leaf
        leaf_to_center_all[leaf_name] = leaf_to_center
        # center_constraint loses the constraint mapping information between center
        # and leaf, so use the original leaf_constraint
        for center_name, constraints in leaf_constraint[leaf_name].items():
            if center_name not in self.modifiers_dict:
                continue
            center_idx_to_leaf = center_to_leaf.setdefault(center_name, {})
            leaf_idx_to_center = leaf_to_center.setdefault(center_name, {})
            for constraint in constraints:
                # The position inside a constraint is the leaf index; the value holds the center indexes
                for leaf_idx, center_idxes in enumerate(constraint):
                    leaf_idx_to_center.setdefault(leaf_idx, set()).update(center_idxes)
                    for center_idx in center_idxes:
                        center_idx_to_leaf.setdefault(center_idx, set()).add(leaf_idx)
            # NOTE(review): -1 looks like a placeholder for "no center index" -- confirm. It is never kept as a key.
            if -1.0 in center_idx_to_leaf:
                del center_idx_to_leaf[-1.0]

    # Aggregate all leaf constraints into a global center constraint
    for leaf in self.leaf:
        leaf_name = leaf.unique_name()
        center_to_leaf = center_to_leaf_all[leaf_name]
        leaf_to_center = leaf_to_center_all[leaf_name]

        # Obtain the constraint of leaf through the constraint of center
        leaf_constraint_all = []
        for center_name, constraints in self.center_constraint.items():
            if center_name not in center_to_leaf:
                continue
            for constraint in constraints:
                leaf_idx_constraint = set()
                for center_idx in constraint:
                    if center_idx in center_to_leaf[center_name]:
                        leaf_idx_constraint.update(center_to_leaf[center_name][center_idx])
                if leaf_idx_constraint not in leaf_constraint_all:
                    leaf_constraint_all.append(leaf_idx_constraint)
        merge_constraint(leaf_constraint_all)

        # Pass the leaf constraint back to the center, so that the center nodes
        # can get dependencies between each other
        leaf_center_constraints = {}
        for center_name in leaf_to_center:
            if center_name not in self.modifiers_dict:
                continue
            leaf_center_constraints[center_name] = []
            for leaf_idx_constraint in leaf_constraint_all:
                index_constraint = set()
                for leaf_idx in leaf_idx_constraint:
                    index_constraint.update(leaf_to_center[center_name][leaf_idx])
                index_constraint.discard(-1.0)
                if index_constraint not in leaf_center_constraints[center_name]:
                    leaf_center_constraints[center_name].append(index_constraint)

        for center_name, new_constraints in leaf_center_constraints.items():
            self.center_constraint[center_name] += new_constraints
            merge_constraint(self.center_constraint[center_name])

        log.debug(f"leaf {leaf.unique_name()} constraint merge over")

    # Aggregate all leaf group into a global center group
    for leaf in self.leaf:
        leaf_name = leaf.unique_name()
        center_to_leaf = center_to_leaf_all[leaf_name]
        leaf_to_center = leaf_to_center_all[leaf_name]

        leaf_group_all = []
        # TODO: Is it possible to skip when center_group has only one element?
        for center_name, center_group in self.center_group.items():
            if center_name not in center_to_leaf:
                continue
            for group in center_group:
                leaf_idx_group = set()
                for center_idx in group:
                    # Nodes such as split may cause the number of idx in leaf and center to be inconsistent
                    if center_idx in center_to_leaf[center_name]:
                        leaf_idx_group.update(center_to_leaf[center_name][center_idx])
                leaf_group_all.append(leaf_idx_group)
        if len(leaf.dim_changes_info.groups_i) > 0:
            leaf_group_all += leaf.dim_changes_info.groups_i
        merge_group(leaf_group_all)

        leaf_center_groups = {}
        for center_name in leaf_to_center:
            leaf_center_groups[center_name] = []
            for leaf_idx_group in leaf_group_all:
                index_group = set()
                # Nodes such as split may cause the number of idx in leaf and center to be inconsistent
                for leaf_idx in leaf_idx_group:
                    if leaf_idx in leaf_to_center[center_name]:
                        index_group.update(leaf_to_center[center_name][leaf_idx])
                index_group.discard(-1.0)
                if len(index_group) > 0:
                    leaf_center_groups[center_name].append(index_group)

        for center_name, new_groups in leaf_center_groups.items():
            self.center_group[center_name] = self.center_group.get(center_name, [])
            self.center_group[center_name] += new_groups
            merge_group(self.center_group[center_name])

    # Project the merged center groups back onto every leaf and record them in self.leaf_group
    for leaf in self.leaf:
        leaf_name = leaf.unique_name()
        center_to_leaf = center_to_leaf_all[leaf_name]
        leaf_group = []
        self.leaf_group[leaf_name] = leaf_group
        for center_name, center_group in self.center_group.items():
            if center_name not in center_to_leaf:
                continue
            for group in center_group:
                leaf_idx_group = set()
                for center_idx in group:
                    # Nodes such as split may cause the number of idx in leaf and center to be
                    # inconsistent, so check that the index actually exists before using it
                    if center_idx in center_to_leaf[center_name]:
                        leaf_idx_group.update(center_to_leaf[center_name][center_idx])
                leaf_group.append(leaf_idx_group)
        if len(leaf.dim_changes_info.groups_i) > 0:
            leaf_group += leaf.dim_changes_info.groups_i
        merge_group(leaf_group)

    log.debug(f"subgraph {self.center} group merge over")

    # 1) Select a center
    # 2) Map the center to all leaves, and then complete leaf pruning
    # 3) Update the global center_pruned_constraint after leaf pruning
    # 4) Prune the next center, and then exclude the pruned idx in center_pruned_constraint
    # 5) Repeat the above steps until all centers are pruned
    sized_centers = []
    for center_name, constraint in self.center_constraint.items():
        constraint_all = set()
        for idxes in constraint:
            constraint_all.update(idxes)
        sized_centers.append((len(constraint_all), self.modifiers_dict[center_name]))

    # Prioritize the center with the shortest constraint. If the center with the longest constraint
    # is processed first, the short one may have an incorrect sparsity rate
    # (sorted() is stable, so centers with equal size keep their insertion order)
    center_list = [center for _, center in sorted(sized_centers, key=lambda item: item[0])]

    # For every index of every center, collect the indexes of the other centers reachable through a shared leaf index
    center_to_center_all = {}
    for center in center_list:
        center_name = center.unique_name()
        center_to_center = {}
        center_to_center_all[center_name] = center_to_center
        for constraint in self.center_constraint[center_name]:
            for center_idx in constraint:
                depend_centers = center_to_center.setdefault(center_idx, {})
                for leaf in self.leaf:
                    leaf_name = leaf.unique_name()
                    center_to_leaf = center_to_leaf_all[leaf_name]
                    leaf_to_center = leaf_to_center_all[leaf_name]
                    if center_name not in center_to_leaf:
                        continue
                    if center_idx not in center_to_leaf[center_name]:
                        continue
                    for leaf_idx in center_to_leaf[center_name][center_idx]:
                        for depend_center_name in leaf_to_center:
                            if depend_center_name == center_name:
                                continue
                            depend_center_idxes = leaf_to_center[depend_center_name][leaf_idx]
                            # Skip leaf indexes whose only mapping for that center is the -1 entry
                            if depend_center_idxes == {-1}:
                                continue
                            depend_centers.setdefault(depend_center_name, set()).update(depend_center_idxes)

    # NOTE(review): `multiple` is only forwarded on the `importance is None` path, and that path is still handed
    # `importance` (i.e. None) -- confirm this dispatch is intended.
    if importance is not None:
        pruned_center_constraint_all, pruned_leaf_constraint_all = self.calc_prune_idx_by_center_importance(
            center_list, center_to_leaf_all, leaf_to_center_all, importance, center_to_center_all, sparsity
        )
    else:
        pruned_center_constraint_all, pruned_leaf_constraint_all = self.calc_prune_idx_by_bn_variance(
            center_list,
            center_to_leaf_all,
            leaf_to_center_all,
            importance,
            center_to_center_all,
            sparsity,
            multiple,
        )

    for center_name, constraint in pruned_center_constraint_all.items():
        # The helper's container is edited in place here, exactly as before: -1 is dropped from it
        if -1 in constraint:
            constraint.remove(-1)
        self.modifiers_dict[center_name].dim_changes_info.pruned_idx_o = sorted(constraint)

    for leaf_name, constraint in pruned_leaf_constraint_all.items():
        pruned_idxes = set()
        for idxes in constraint:
            pruned_idxes.update(idxes)
        pruned_idxes.discard(-1)
        self.modifiers_dict[leaf_name].dim_changes_info.pruned_idx_i = sorted(pruned_idxes)

    log.debug(f"subgraph {self.center} prune idx compute over")