def calc_prune_idx()

in tinynn/graph/modifier.py [0:0]


    def calc_prune_idx(self, importance, sparsity, multiple=None):
        """
        Calculate the dependence of index in the process of pruning. For convolutional pruning, it is the dependence
        between channels. For more complex unstructured/semi-structured pruning, it may have a finer granularity.

        The computation proceeds in stages:
          1. Collect each leaf's input constraints (``dim_changes_info.constraints_i``) and merge them into a
             per-center constraint list stored in ``self.center_constraint``.
          2. For every leaf, build index mappings ``center idx -> leaf idxes`` and ``leaf idx -> center idxes``.
          3. Propagate constraints center -> leaf -> center, so centers that share a leaf inherit each other's
             constraints.
          4. Propagate groups (``groups_o`` of centers, ``groups_i`` of leaves) the same way into
             ``self.center_group`` and ``self.leaf_group``.
          5. Build center-to-center index dependencies, then delegate the actual index selection to
             ``calc_prune_idx_by_center_importance`` (when ``importance`` is not None) or to
             ``calc_prune_idx_by_bn_variance`` (otherwise).
          6. Store the sorted selected indices in ``dim_changes_info.pruned_idx_o`` of each center and
             ``dim_changes_info.pruned_idx_i`` of each leaf.

        Args:
            importance: Forwarded to the selection helper. When ``None``, the BN-variance path is used. Its
                expected structure is defined by those helpers, not here.
            sparsity: Mapping looked up with ``self.center``. The method does nothing when ``self.center`` is
                missing from it or maps to ``0.0``.
            multiple: Forwarded only to ``calc_prune_idx_by_bn_variance``. Presumably a rounding multiple for
                the kept index count; TODO confirm against that helper.

        Returns:
            None. Results are written into the modifiers' ``dim_changes_info`` and into
            ``self.center_constraint``, ``self.center_group`` and ``self.leaf_group``. If any leaf has more than one
            entry in ``constraints_i``, a warning is logged and the method returns before writing any of them.
        """
        # Nothing to do unless this subgraph's center has a non-zero sparsity target
        if self.center not in sparsity or sparsity[self.center] == 0.0:
            return

        center_constraint = {}
        leaf_constraint = {}

        for m in self.modifiers:
            log.debug(f"modifier {m.unique_name()} prune dim = {dict(m.dim_changes_info.dim_choices)}")

        for leaf in self.leaf:
            # After all subgraph dependencies are resolved, each operator will only be pruned in a single dimension
            if len(leaf.dim_changes_info.constraints_i) != 1:
                log.warning(f"[{leaf.unique_name()}] Pruning in two dimensions at the same time is not supported")
                return

            # The single entry's value maps center name -> list of constraints for that center
            leaf_constraint[leaf.unique_name()] = list(leaf.dim_changes_info.constraints_i.values())[0]

            for center_name, constraints in leaf_constraint[leaf.unique_name()].items():
                # Only centers that belong to this subgraph are considered
                if center_name not in self.modifiers_dict.keys():
                    continue

                for constraint in constraints:
                    # Each constraint is indexed by leaf idx; its elements are iterables of center idxes,
                    # which are flattened into this center's constraint list
                    center_constraint[center_name] = center_constraint.get(center_name, [])
                    center_constraint[center_name] += constraint

        # merge_constraint's return value is ignored, so it is relied upon to merge each list in place
        for center_name, constraints in center_constraint.items():
            merge_constraint(constraints)

        # NOTE: self.center_constraint aliases center_constraint; later in-place updates affect both
        self.center_constraint = center_constraint

        for center in self.dependent_centers:
            if len(center.dim_changes_info.groups_o) > 0:
                self.center_group[center.unique_name()] = center.dim_changes_info.groups_o

        # Build constraint mapping between center and leaf
        center_to_leaf_all = {}
        leaf_to_center_all = {}
        for leaf in self.leaf:
            center_to_leaf = {}
            leaf_to_center = {}

            center_to_leaf_all[leaf.unique_name()] = center_to_leaf
            leaf_to_center_all[leaf.unique_name()] = leaf_to_center

            # center_constraint loses the constraint mapping information between center
            # and leaf, so use the original leaf_constraint
            for center_name, constraints in leaf_constraint[leaf.unique_name()].items():
                if center_name not in self.modifiers_dict:
                    continue

                if center_name not in center_to_leaf:
                    center_to_leaf[center_name] = {}
                    leaf_to_center[center_name] = {}

                for constraint in constraints:
                    for leaf_idxes in range(len(constraint)):
                        # leaf idx -> set of center idxes it depends on
                        leaf_to_center[center_name][leaf_idxes] = leaf_to_center[center_name].get(leaf_idxes, set())
                        leaf_to_center[center_name][leaf_idxes].update(constraint[leaf_idxes])

                        # Inverse direction: center idx -> set of leaf idxes that depend on it
                        for center_idxes in constraint[leaf_idxes]:
                            center_to_leaf[center_name][center_idxes] = center_to_leaf[center_name].get(
                                center_idxes, set()
                            )
                            center_to_leaf[center_name][center_idxes].add(leaf_idxes)

                # -1 is treated as a placeholder rather than a real center idx, so drop it as a key
                if -1.0 in center_to_leaf[center_name]:
                    del center_to_leaf[center_name][-1.0]

        # Aggregate all leaf constraints into a global center constraint
        for leaf in self.leaf:
            center_to_leaf = center_to_leaf_all[leaf.unique_name()]
            leaf_to_center = leaf_to_center_all[leaf.unique_name()]

            # Obtain the constraint of leaf through the constraint of center
            leaf_constraint_all = []
            for center_name, constraints in self.center_constraint.items():
                if center_name not in center_to_leaf.keys():
                    continue
                for constraint in constraints:
                    leaf_idx_constraint = set()
                    for center_idxes in constraint:
                        if center_idxes in center_to_leaf[center_name]:
                            leaf_idx_constraint.update(center_to_leaf[center_name][center_idxes])

                    if leaf_idx_constraint not in leaf_constraint_all:
                        leaf_constraint_all.append(leaf_idx_constraint)
            merge_constraint(leaf_constraint_all)

            # Pass the leaf constraint back to the center, so that the center nodes
            # can get dependencies between each other
            leaf_center_constraints = {}
            for center_name in leaf_to_center.keys():
                if center_name not in self.modifiers_dict:
                    continue
                leaf_center_constraints[center_name] = []
                for leaf_idx_constraint in leaf_constraint_all:
                    index_constraint = set()
                    # NOTE(review): unlike the group pass below, this lookup is unguarded and would raise
                    # KeyError if a leaf idx had no mapping for this center; confirm that cannot happen
                    for leaf_idxes in leaf_idx_constraint:
                        index_constraint.update(leaf_to_center[center_name][leaf_idxes])
                    if -1.0 in index_constraint:
                        index_constraint.remove(-1.0)
                    if index_constraint not in leaf_center_constraints[center_name]:
                        leaf_center_constraints[center_name].append(index_constraint)
            for center_name in leaf_center_constraints.keys():
                self.center_constraint[center_name] += leaf_center_constraints[center_name]
                merge_constraint(self.center_constraint[center_name])

            log.debug(f"leaf {leaf.unique_name()} constraint merge over")

        # Aggregate all leaf group into a global center group
        for leaf in self.leaf:
            center_to_leaf = center_to_leaf_all[leaf.unique_name()]
            leaf_to_center = leaf_to_center_all[leaf.unique_name()]

            leaf_group_all = []
            # TODO: Is it possible to skip when center_group has only one element?
            for center_name, center_group in self.center_group.items():
                if center_name not in center_to_leaf.keys():
                    continue

                for group in center_group:
                    leaf_idx_group = set()
                    for center_idxes in group:
                        # Nodes such as split may cause the number of idx in leaf and center to be inconsistent
                        if center_idxes in center_to_leaf[center_name]:
                            leaf_idx_group.update(center_to_leaf[center_name][center_idxes])

                    leaf_group_all.append(leaf_idx_group)

            if len(leaf.dim_changes_info.groups_i) > 0:
                leaf_group_all += leaf.dim_changes_info.groups_i

            merge_group(leaf_group_all)

            leaf_center_groups = {}
            for center_name in leaf_to_center.keys():
                leaf_center_groups[center_name] = []
                for leaf_idx_group in leaf_group_all:
                    index_group = set()
                    # Nodes such as split may cause the number of idx in leaf and center to be inconsistent
                    for leaf_idxes in leaf_idx_group:
                        if leaf_idxes in leaf_to_center[center_name]:
                            index_group.update(leaf_to_center[center_name][leaf_idxes])
                    if -1.0 in index_group:
                        index_group.remove(-1.0)
                    if len(index_group) > 0:
                        leaf_center_groups[center_name].append(index_group)

            for center_name in leaf_center_groups.keys():
                self.center_group[center_name] = self.center_group.get(center_name, [])
                self.center_group[center_name] += leaf_center_groups[center_name]
                merge_group(self.center_group[center_name])

        # Project the final center groups back onto every leaf
        for leaf in self.leaf:
            center_to_leaf = center_to_leaf_all[leaf.unique_name()]

            leaf_group = []
            self.leaf_group[leaf.unique_name()] = leaf_group
            for center_name, center_group in self.center_group.items():
                if center_name not in center_to_leaf.keys():
                    continue

                for group in center_group:
                    leaf_idx_group = set()
                    for center_idxes in group:
                        # Nodes such as split may cause the number of idx in leaf and center to be
                        # inconsistent, so check that the idx exists before using it
                        if center_idxes in center_to_leaf[center_name]:
                            leaf_idx_group.update(center_to_leaf[center_name][center_idxes])

                    leaf_group.append(leaf_idx_group)

            if len(leaf.dim_changes_info.groups_i) > 0:
                leaf_group += leaf.dim_changes_info.groups_i

            merge_group(leaf_group)

        log.debug(f"subgraph {self.center} group merge over")

        # 1) Select a center
        # 2) Map the center to all leaves, and then complete leaf pruning
        # 3) Update the global center_pruned_constraint after leaf pruning
        # 4) Prune the next center, and then exclude the pruned idx in center_pruned_constraint
        # 5) Repeat the above steps until all centers are pruned
        center_list = []
        for center_name, constraint in self.center_constraint.items():
            constraint_all = set()
            for i in constraint:
                constraint_all.update(i)
            center_list.append((len(constraint_all), self.modifiers_dict[center_name]))

        # Prioritize the center with the shortest constraint. If the center with the longest constraint
        # is processed first, the short one may have an incorrect sparsity rate
        center_list = sorted(center_list, key=lambda x: x[0])
        center_list = [i[1] for i in center_list]

        # center name -> {center idx -> {dependent center name -> set of its idxes}}
        center_to_center_all = {}

        for center in center_list:
            center_name = center.unique_name()
            center_to_center = {}
            center_to_center_all[center_name] = center_to_center

            for constraint_set in self.center_constraint[center_name]:
                for center_idxes in constraint_set:
                    if center_idxes not in center_to_center:
                        center_to_center[center_idxes] = {}

                    for leaf in self.leaf:
                        leaf_name = leaf.unique_name()
                        center_to_leaf = center_to_leaf_all[leaf_name]
                        leaf_to_center = leaf_to_center_all[leaf_name]

                        if center_name not in center_to_leaf.keys():
                            continue

                        if center_idxes not in center_to_leaf[center_name]:
                            continue

                        # Follow center idx -> leaf idxes -> idxes of every other center feeding that leaf
                        leaf_idxes = center_to_leaf[center_name][center_idxes]
                        for leaf_idx in leaf_idxes:
                            for depend_center_name in leaf_to_center.keys():
                                if depend_center_name == center_name:
                                    continue

                                depend_center_idxes = leaf_to_center[depend_center_name][leaf_idx]
                                # {-1} == {-1.0} in Python, so this also skips the -1.0 placeholder
                                if depend_center_idxes == {-1}:
                                    continue
                                if depend_center_name not in center_to_center[center_idxes]:
                                    center_to_center[center_idxes][depend_center_name] = set()
                                center_to_center[center_idxes][depend_center_name].update(depend_center_idxes)

        if importance is not None:
            pruned_center_constraint_all, pruned_leaf_constraint_all = self.calc_prune_idx_by_center_importance(
                center_list, center_to_leaf_all, leaf_to_center_all, importance, center_to_center_all, sparsity
            )
        else:
            pruned_center_constraint_all, pruned_leaf_constraint_all = self.calc_prune_idx_by_bn_variance(
                center_list,
                center_to_leaf_all,
                leaf_to_center_all,
                importance,
                center_to_center_all,
                sparsity,
                multiple,
            )

        # Store sorted output-side pruned indices on each center, without the -1 placeholder
        for center_name, constraint in pruned_center_constraint_all.items():
            calculated_constraint = constraint

            if -1 in calculated_constraint:
                calculated_constraint.remove(-1)
            calculated_constraint = list(calculated_constraint)
            calculated_constraint.sort()

            self.modifiers_dict[center_name].dim_changes_info.pruned_idx_o = calculated_constraint

        # Flatten, de-duplicate and sort input-side pruned indices for each leaf, without the -1 placeholder
        for leaf_name, constraint in pruned_leaf_constraint_all.items():
            calculated_constraint = set()
            for i in constraint:
                calculated_constraint.update(i)

            if -1 in calculated_constraint:
                calculated_constraint.remove(-1)
            calculated_constraint = list(calculated_constraint)
            calculated_constraint.sort()

            self.modifiers_dict[leaf_name].dim_changes_info.pruned_idx_i = calculated_constraint

        log.debug(f"subgraph {self.center} prune idx compute over")