def rebuild()

in tinynn/graph/modifier.py [0:0]


    def rebuild(self):
        """Rebuild the dimension change records from the current tensor choices.

        A recorded change (one key in ``self.tensor_keys[tensor_id]``) is kept only
        when every dim in ``self.tensor_choices[tensor_id]`` also appears in that
        change. All change bookkeeping (``dim_changes_i``, ``dim_changes_o``,
        ``centers``, ``tensor_changes``, ``tensor_keys``) is then reset, and the kept
        changes are replayed through ``update_i`` / ``update_o`` with
        ``update_constraint=False``.

        Constraints are rebuilt alongside the changes. For each replayed tensor, the
        entries of the old ``constraints_i`` / ``constraints_o`` whose dim is in
        ``get_tensor_choices(tensor)`` are carried over. Then, for input-side
        constraints only, a center that has several constraint lists gets them merged
        position-wise into a single list of set unions, ignoring entries equal to
        ``{-1}``.

        Raises:
            IndexError: if a kept change refers to a tensor id that matches none of
                ``self.modifier.pre_tensors() + self.modifier.next_tensors()``.
        """

        # Build the id -> tensor lookup once instead of re-concatenating and
        # re-scanning the tensor lists for every kept change. setdefault keeps the
        # first occurrence, matching a first-match scan over the concatenated list.
        tensors_by_id = {}
        for t in self.modifier.pre_tensors() + self.modifier.next_tensors():
            tensors_by_id.setdefault(id(t), t)

        valid_changes = []
        for tensor_id, choice in self.tensor_choices.items():
            chosen_dims = set(choice)
            for key in self.tensor_keys[tensor_id]:
                # Keys have the form "<center name>:<rest>". Only the part after the
                # center name is inspected, so a center whose own name contains
                # 'input' does not route its output-side keys to dim_changes_i.
                center_name, _, io_tag = key.partition(':')
                if 'input' in io_tag:
                    tensor_change = self.dim_changes_i[key]
                else:
                    tensor_change = self.dim_changes_o[key]

                # Keep the change only if it covers every chosen dim of the tensor.
                if not chosen_dims.issubset(tensor_change):
                    continue

                center = self.centers[center_name]
                if tensor_id not in tensors_by_id:
                    raise IndexError(f'no pre/next tensor of the modifier has id {tensor_id} (key {key!r})')
                valid_changes.append((center, tensors_by_id[tensor_id], tensor_change))

        # NOTE(review): tensor_changes is reset here and not written in this method;
        # presumably update_i/update_o repopulate it. Confirm against those methods.
        self.dim_changes_i = OrderedDict()
        self.dim_changes_o = OrderedDict()
        self.centers = OrderedDict()
        self.tensor_changes = OrderedDict()
        self.tensor_keys = OrderedDict()
        constraint_i_new = OrderedDict()
        constraint_o_new = OrderedDict()

        for center, tensor, tensor_change in valid_changes:
            if self.modifier.is_pre_tensor(tensor):
                self.update_i(center, tensor, tensor_change, update_constraint=False)
                constraint_old, constraint_new = self.constraints_i, constraint_i_new
            else:
                self.update_o(center, tensor, tensor_change, update_constraint=False)
                constraint_old, constraint_new = self.constraints_o, constraint_o_new

            # Carry over the old per-dim constraints for the dims this tensor chose.
            # The per-dim dicts are shared with the old mapping, not copied.
            choice = self.get_tensor_choices(tensor)
            for dim, constraints in constraint_old.items():
                if dim in choice:
                    constraint_new[dim] = constraints

        # Only input-side constraints are merged; output-side ones are stored as-is.
        for dim_constraints in constraint_i_new.values():
            for constraints in dim_constraints.values():
                # Zero or one constraint list has nothing to merge (and an empty list
                # has no first entry to size the merged result from).
                if len(constraints) <= 1:
                    continue

                merged = [set() for _ in constraints[0]]
                for constraint in constraints:
                    for i, values in enumerate(constraint):
                        # {-1} entries are skipped. NOTE(review): a position whose
                        # entries are all {-1} therefore merges to an empty set;
                        # confirm that is intended.
                        if values != {-1}:
                            merged[i].update(values)
                # Replace in place so every holder of this list sees the merge.
                constraints[:] = [merged]

        self.constraints_i = constraint_i_new
        self.constraints_o = constraint_o_new