in captum/attr/_utils/summarizer.py [0:0]
def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
    """
    Expand ``stats`` with their dependencies and order them for updating.

    Duplicate Stats are removed, every Stat required by a requested Stat is
    added, and the result is sorted so each Stat comes after the Stats it
    depends on. The dependency graph handled here is::

        StdDev(x) -> Var(x) -> MSE -> Mean -> Count, for all valid x

    Min, Max and Sum (and any other Stat class outside that graph) get no
    dependencies added and are placed after all Stats of the graph.

    Args:
        stats (List[Stat]): The Stats requested by the caller. Duplicates
            are removed.

    Returns:
        A tuple ``(ordered_stats, summary_stat_indices)``:

        - ``ordered_stats``: the requested Stats plus their dependencies,
          in an order where updating front-to-back satisfies every
          dependency.
        - ``summary_stat_indices``: ascending positions in
          ``ordered_stats`` of the originally requested Stats.
    """
    from collections import defaultdict

    # Known dependency chain, from most-derived to most-basic Stat.
    dep_order = [StdDev, Var, MSE, Mean, Count]

    # Remove duplicate stats, and remember which ones were explicitly
    # requested so their positions can be reported back to the caller.
    stats = set(stats)
    summary_stats = set(stats)

    stats_by_module: Dict[Type, List[Stat]] = defaultdict(list)
    for stat in stats:
        stats_by_module[stat.__class__].append(stat)

    # Step 1: ensure every dependency is present.
    #
    # StdDev is parameterized by ``order``. Each StdDev(order) therefore
    # needs the matching Var(order), not a default-constructed Var.
    for std_dev in stats_by_module.get(StdDev, []):
        stat_to_add = Var(order=std_dev.order)  # type: ignore
        stats.add(stat_to_add)
        stats_by_module[stat_to_add.__class__].append(stat_to_add)

    # For the rest of the chain, find the first class that is present.
    # Every class strictly after it in ``dep_order`` is a dependency.
    # The class itself is already present, and re-adding it
    # default-constructed would inject an unrequested ``Var()`` when only
    # a parameterized Var (e.g. Var(order=1)) is needed.
    for idx in range(1, len(dep_order)):
        if dep_order[idx] in stats_by_module:
            stats.update(mod() for mod in dep_order[idx + 1 :])
            break

    # Step 2: order the stats.
    # Sorting by a fixed topological rank (highest rank first) puts Count
    # before Mean, Mean before MSE, MSE before Var, and Var before StdDev.
    # Min/Max/Sum, and any Stat class not in the dependency graph, get
    # rank -1. They are updated last instead of raising KeyError.
    rank = {mod: i for i, mod in enumerate(dep_order)}
    rank[Min] = -1
    rank[Max] = -1
    rank[Sum] = -1

    ordered_stats = sorted(
        stats, key=lambda s: rank.get(s.__class__, -1), reverse=True
    )

    # Positions of the explicitly requested stats within the update order.
    summary_stat_indices = [
        i for i, stat in enumerate(ordered_stats) if stat in summary_stats
    ]
    return ordered_stats, summary_stat_indices