in fvcore/nn/jit_analysis.py [0:0]
def _analyze(self) -> "Statistics":
    """Trace the model (once) and aggregate operator statistics per module.

    The result is memoized in ``self._stats``: if it is already set, it is
    returned immediately and no tracing happens.

    Tracing runs inside ``warnings.catch_warnings()``, so the suppression
    selected by ``self._warn_trace`` does not leak into the caller's filters.
    ``"none"`` ignores every warning, ``"no_tracer_warning"`` ignores only
    ``TracerWarning``, and any other value leaves the filters unchanged.

    Each graph node is attributed to a group of module names. In ``"caller"``
    ancestor mode, these are all names in the node's scope path. In any other
    mode, they come from ``self._get_all_ancestors`` applied to the innermost
    scope name.
    - If ``self._op_handles`` has a handle for the node's kind, the handle's
      result is added to ``counts`` for each of those modules.
    - Otherwise the node is tallied in ``unsupported_ops``, unless
      ``self._should_ignore_node`` says to skip it.

    Returns:
        Statistics: built from ``counts`` and ``unsupported_ops`` (both dicts
        of Counters keyed by module alias) and ``uncalled_mods`` (aliases
        that never appeared in any node's scope names or ancestors).

    Raises:
        ValueError: if an op handle yields a count whose type is not
            ``int``, ``float``, ``np.float64`` or ``np.int64``.
    """
    cached = self._stats
    if cached is not None:
        return cached

    with warnings.catch_warnings():
        trace_warning_mode = self._warn_trace
        if trace_warning_mode == "none":
            warnings.simplefilter("ignore")
        elif trace_warning_mode == "no_tracer_warning":
            warnings.filterwarnings("ignore", category=TracerWarning)
        graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)

    # Seed an empty counter for every module so that modules which never
    # show up in the traced graph still report zero. _named_modules_with_dup
    # is used instead of self._model.named_modules() because the latter gives
    # slightly different results for some wrapped models. Duplicate entries
    # simply re-seed the same alias with a fresh empty Counter.
    module_aliases = [
        self._aliases[module] for _, module in _named_modules_with_dup(self._model)
    ]
    counts = {alias: Counter() for alias in module_aliases}
    unsupported_ops = {alias: Counter() for alias in module_aliases}

    allowed_count_types = (int, float, np.float64, np.int64)
    seen_scopes = set()
    for node in graph.nodes():
        op_kind = node.kind()
        if op_kind == "prim::PythonOp":
            # A PythonOp's pyname() carries the actual Python-level op name.
            op_kind = op_kind + "." + node.pyname()

        scopes = node.scopeName().split("/")
        seen_scopes.update(scopes)
        if self._ancestor_mode == "caller":
            owners = set(scopes)
        else:
            owners = self._get_all_ancestors(scopes[-1])
            seen_scopes.update(owners)

        if op_kind not in self._op_handles:
            # No handle registered: record it as unsupported unless ignorable.
            if not self._should_ignore_node(node):
                for owner in owners:
                    unsupported_ops[owner][op_kind] += 1
            continue

        inputs = list(node.inputs())
        outputs = list(node.outputs())
        handle = self._op_handles[op_kind]
        op_counts = handle(inputs, outputs)
        if isinstance(op_counts, Number):
            # A bare number is keyed by the simplified op name.
            op_counts = Counter({self._simplify_op_name(op_kind): op_counts})

        # Validate every value before touching any module's totals.
        for count in op_counts.values():
            if isinstance(count, allowed_count_types):
                continue
            raise ValueError(
                f"Invalid type {type(count)} for the flop count! "
                "Please use a wider type to avoid overflow."
            )

        # `owners` is a set in "caller" mode (and presumably also from
        # _get_all_ancestors), so each module receives this op at most once.
        # Counter's in-place `+=` keeps only positive totals.
        for owner in owners:
            counts[owner] += op_counts

    uncalled_mods = set(self._aliases.values()) - seen_scopes
    stats = Statistics(
        counts=counts, unsupported_ops=unsupported_ops, uncalled_mods=uncalled_mods
    )
    self._stats = stats
    # Report unsupported ops for the alias "" (presumably the root model),
    # then any modules that never appeared in the trace.
    self._warn_unsupported_ops(unsupported_ops[""])
    self._warn_uncalled_mods(uncalled_mods)
    return stats