in tinynn/graph/modifier.py [0:0]
def rebuild(self):
    """Reconstruct the entire dimension change information according to dim_choice.

    The method works in four steps:

    1. Collect every recorded dimension change (from ``self.dim_changes_i`` /
       ``self.dim_changes_o``) whose dims are a superset of the choice stored
       for that tensor in ``self.tensor_choices``.
    2. Reset ``dim_changes_i``, ``dim_changes_o``, ``centers``,
       ``tensor_changes`` and ``tensor_keys`` to empty ``OrderedDict`` objects.
    3. Replay the surviving changes through ``update_i`` / ``update_o`` (with
       ``update_constraint=False``) and keep only the constraints whose dim is
       part of the tensor's choice.
    4. Merge multiple input-side constraints of one center into a single one.

    Raises:
        IndexError: if an id in ``self.tensor_choices`` matches none of the
            tensors returned by ``self.modifier.pre_tensors()`` and
            ``self.modifier.next_tensors()`` once one of its keys is selected.
    """
    # The tensor list is not modified while scanning, so index it by id() once
    # instead of rebuilding and linearly searching it for every matching key.
    tensors_by_id = {}
    for candidate in self.modifier.pre_tensors() + self.modifier.next_tensors():
        tensors_by_id.setdefault(id(candidate), candidate)

    valid_changes = []
    for tensor_id, choice in self.tensor_choices.items():
        chosen_dims = set(choice)
        for key in self.tensor_keys[tensor_id]:
            # Keys containing 'input' live in the input-side table, all
            # other keys in the output-side table.
            if 'input' in key:
                tensor_change = self.dim_changes_i[key]
            else:
                tensor_change = self.dim_changes_o[key]

            if not chosen_dims.issubset(tensor_change):
                continue

            # The part of the key before the first ':' names the center.
            center_name = key.split(":")[0]
            center = self.centers[center_name]
            try:
                tensor = tensors_by_id[tensor_id]
            except KeyError:
                # An unknown id is reported as IndexError, the exception a
                # failed positional lookup in the tensor list would raise.
                raise IndexError(
                    'no pre/next tensor with id {} found'.format(tensor_id)
                ) from None
            valid_changes.append((center, tensor, tensor_change))

    # Drop all bookkeeping; update_i / update_o below repopulate it.
    self.dim_changes_i = OrderedDict()
    self.dim_changes_o = OrderedDict()
    self.centers = OrderedDict()
    self.tensor_changes = OrderedDict()
    self.tensor_keys = OrderedDict()

    constraint_i_new = OrderedDict()
    constraint_o_new = OrderedDict()

    for center, tensor, tensor_change in valid_changes:
        if self.modifier.is_pre_tensor(tensor):
            self.update_i(center, tensor, tensor_change, update_constraint=False)
            constraint_old = self.constraints_i
            constraint_new = constraint_i_new
        else:
            self.update_o(center, tensor, tensor_change, update_constraint=False)
            constraint_old = self.constraints_o
            constraint_new = constraint_o_new

        choice = self.get_tensor_choices(tensor)
        for dim, constraints in constraint_old.items():
            if dim in choice:
                # The very same constraint object is carried over (not a
                # copy), so the in-place merge below is visible through it.
                constraint_new[dim] = constraints

    # Only the input-side constraints are merged; the output-side ones are
    # stored exactly as filtered above.
    for dim_constraints in constraint_i_new.values():
        for constraints in dim_constraints.values():
            if len(constraints) == 1:
                continue
            merged = [set() for _ in constraints[0]]
            for constraint in constraints:
                for position, group in enumerate(constraint):
                    # Entries equal to {-1} are skipped -- presumably a
                    # "no constraint" placeholder; TODO confirm.
                    if group != {-1}:
                        merged[position].update(group)
            # Slice assignment keeps the list object itself alive for any
            # other holder of a reference to it.
            constraints[:] = [merged]

    self.constraints_i = constraint_i_new
    self.constraints_o = constraint_o_new