def _reorder_stats()

in captum/attr/_utils/summarizer.py [0:0]


def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
    """Add missing dependency stats and return all stats in update order.

    The built-in stats form a single dependency chain::

        StdDev(x) -> Var(x) -> MSE -> Mean -> Count,  for all valid x

    Two things are produced:

    1. Every stat needed to update the requested ones, ordered so that a
       stat always comes after the stats it depends on (``Count`` first,
       ``StdDev`` last among the chain).
    2. The positions, within that ordered list, of the stats the caller
       actually asked for, so dependency-only stats can be told apart.

    Args:
        stats: The requested stats. Duplicates are dropped; the objects
            must be hashable because they are stored in sets.

    Returns:
        A tuple ``(ordered_stats, summary_stat_indexs)``.
        ``ordered_stats`` holds the requested stats plus any dependencies
        that had to be added. ``summary_stat_indexs`` lists, in increasing
        order, the indices in ``ordered_stats`` of the requested stats.
        The relative order of stats with equal priority follows set
        iteration order and is therefore not specified.
    """
    from collections import defaultdict

    # Most-dependent first; each class depends on everything to its right.
    dep_order = [StdDev, Var, MSE, Mean, Count]

    # De-duplicated view of what the caller requested. Kept unchanged so
    # the requested stats can be picked out of the final list.
    summary_stats = set(stats)
    # Working set: requested stats plus the dependencies added below.
    all_stats = set(summary_stats)

    stats_by_module: Dict[Type, List[Stat]] = defaultdict(list)
    for stat in summary_stats:
        stats_by_module[stat.__class__].append(stat)

    # Step 1: ensure all dependencies are present.
    #
    # StdDev is parameterized, so each StdDev(order) needs the Var with the
    # *same* order rather than a default-constructed one.
    for std_dev in stats_by_module[StdDev]:
        matching_var = Var(order=std_dev.order)  # type: ignore
        all_stats.add(matching_var)
        stats_by_module[matching_var.__class__].append(matching_var)

    # The remaining classes take no parameters: once the first (i.e. most
    # dependent) class present is found, everything after it in the chain
    # is required. ``start=1`` keeps ``position`` aligned with ``dep_order``
    # so the slice begins *after* the matched class; starting the slice at
    # the matched class itself would add a needless default ``Var()``
    # whenever any Var is present.
    for position, dep in enumerate(dep_order[1:], start=1):
        if dep in stats_by_module:
            all_stats.update(mod() for mod in dep_order[position + 1:])
            break

    # Step 2: order the stats.
    # NOTE: this sorts by the fixed topological order above. Classes outside
    # the chain (e.g. Min, Max, Sum) get priority -1 and therefore come
    # after every chain stat once the sort is reversed.
    sort_order = {mod: i for i, mod in enumerate(dep_order)}
    ordered_stats = sorted(
        all_stats,
        key=lambda s: sort_order.get(s.__class__, -1),
        reverse=True,
    )

    # Indices of the originally requested stats within the ordered list.
    summary_stat_indexs = [
        i for i, stat in enumerate(ordered_stats) if stat in summary_stats
    ]
    return ordered_stats, summary_stat_indexs