in fvcore/nn/print_model_statistics.py [0:0]
def _model_stats_str(model: nn.Module, statistics: Dict[str, Dict[str, str]]) -> str:
    """
    This produces a representation of the model much like 'str(model)'
    would, except the provided statistics are written out as additional
    information for each submodule.

    Args:
        model (nn.Module) : the model to form a representation of.
        statistics (dict(str, dict(str, str))) : the statistics to
            include in the model representations. Organized as a dictionary
            over module names, which are each a dictionary over statistics.
            The root model is looked up under the empty name "", and
            submodules under their dotted path (e.g. "layer1.0.conv").
            The statistics are assumed to be formatted already to the
            desired string format for printing.

    Returns:
        str : the string representation of the model with the statistics
            inserted.
    """
    # Width used both for the lines directly inside a module and for
    # re-indenting nested children, so closing parens line up (matches
    # the two-space layout of nn.Module.__repr__).
    indent_width = 2
    indent = " " * indent_width

    # Adapted from nn.Module._addindent: indent every line but the first.
    def _addindent(s_: str, numSpaces: int) -> str:
        first, *rest = s_.split("\n")
        # don't do anything for single-line stuff
        if not rest:
            return s_
        pad = numSpaces * " "
        return first + "\n" + "\n".join(pad + line for line in rest)

    def print_statistics(name: str) -> str:
        # Modules without an entry in `statistics` get no extra line.
        if name not in statistics:
            return ""
        return ", ".join(
            "{}: {}".format(k, v) for k, v in statistics[name].items()
        )

    # This comes directly from nn.Module.__repr__ with small changes
    # to include the statistics.
    def repr_with_statistics(module: nn.Module, name: str) -> str:
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = module.extra_repr()
        printed_stats = print_statistics(name)
        # empty string would be split into [''], so skip empty values
        if extra_repr:
            extra_lines.extend(extra_repr.split("\n"))
        if printed_stats:
            extra_lines.extend(printed_stats.split("\n"))
        child_lines = []
        for key, submod in module._modules.items():
            if submod is None:
                # nn.Module permits None placeholders in _modules (e.g. via
                # add_module(name, None)); render them the way
                # nn.Module.__repr__ does instead of crashing on
                # None.extra_repr().
                submod_str = repr(submod)
            else:
                submod_name = name + ("." if name else "") + key
                submod_str = repr_with_statistics(submod, submod_name)
            submod_str = _addindent(submod_str, indent_width)
            child_lines.append("(" + key + "): " + submod_str)
        lines = extra_lines + child_lines
        main_str = module._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n" + indent + ("\n" + indent).join(lines) + "\n"
        main_str += ")"
        return main_str

    # The root module is keyed by the empty name.
    return repr_with_statistics(model, "")