in fvcore/nn/parameter_count.py [0:0]
def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str:
    """
    Format the parameter count of the model (and its submodules or parameters)
    in a nice table. Each nested row (both the name and the value column) is
    prefixed with one space per level of nesting. It looks like this:

    ::

        | name                            | #elements or shape   |
        |:--------------------------------|:---------------------|
        | model                           | 37.9M                |
        | backbone                        | 31.5M                |
        | backbone.fpn_lateral3           | 0.1M                 |
        | backbone.fpn_lateral3.weight    | (256, 512, 1, 1)     |
        | backbone.fpn_lateral3.bias      | (256,)               |
        | backbone.fpn_output3            | 0.6M                 |
        | backbone.fpn_output3.weight     | (256, 256, 3, 3)     |
        | backbone.fpn_output3.bias       | (256,)               |
        | backbone.fpn_lateral4           | 0.3M                 |
        | backbone.fpn_lateral4.weight    | (256, 1024, 1, 1)    |
        | backbone.fpn_lateral4.bias      | (256,)               |
        | backbone.fpn_output4            | 0.6M                 |
        | backbone.fpn_output4.weight     | (256, 256, 3, 3)     |
        | backbone.fpn_output4.bias       | (256,)               |
        | backbone.fpn_lateral5           | 0.5M                 |
        | backbone.fpn_lateral5.weight    | (256, 2048, 1, 1)    |
        | backbone.fpn_lateral5.bias      | (256,)               |
        | backbone.fpn_output5            | 0.6M                 |
        | backbone.fpn_output5.weight     | (256, 256, 3, 3)     |
        | backbone.fpn_output5.bias       | (256,)               |
        | backbone.top_block              | 5.3M                 |
        | backbone.top_block.p6           | 4.7M                 |
        | backbone.top_block.p7           | 0.6M                 |
        | backbone.bottom_up              | 23.5M                |
        | backbone.bottom_up.stem         | 9.4K                 |
        | backbone.bottom_up.res2         | 0.2M                 |
        | backbone.bottom_up.res3         | 1.2M                 |
        | backbone.bottom_up.res4         | 7.1M                 |
        | backbone.bottom_up.res5         | 14.9M                |
        | ......                          | .....                |

    Args:
        model: a torch module
        max_depth (int): maximum depth to recursively print submodules or
            parameters. Entries whose dotted name contains ``max_depth`` or
            more dots are not listed; the top-level "model" row is always
            printed.

    Returns:
        str: the table to be printed, in tabulate's "pipe" (Markdown) format.
            Rows for parameters show their shape; all other rows show a
            human-readable element count.
    """
    count: typing.DefaultDict[str, int] = parameter_count(model)
    param_shape: typing.Dict[str, typing.Tuple[int, ...]] = {
        k: tuple(v.shape) for k, v in model.named_parameters()
    }
    table: typing.List[typing.Tuple[str, str]] = []

    def format_size(x: int) -> str:
        # Each unit kicks in one decade below its magnitude, so e.g. 0.6M is
        # shown rather than 600.0K.
        if x > 1e8:
            return "{:.1f}G".format(x / 1e9)
        if x > 1e5:
            return "{:.1f}M".format(x / 1e6)
        if x > 1e2:
            return "{:.1f}K".format(x / 1e3)
        return str(x)

    def fill(lvl: int, prefix: str) -> None:
        # Append every entry directly under `prefix` (exactly `lvl` dots in its
        # name), each followed immediately by its own children, so rows come
        # out in depth-first order.
        if lvl >= max_depth:
            return
        for name, v in count.items():
            if name.count(".") == lvl and name.startswith(prefix):
                indent = " " * (lvl + 1)
                if name in param_shape:
                    table.append((indent + name, indent + str(param_shape[name])))
                else:
                    table.append((indent + name, indent + format_size(v)))
                fill(lvl + 1, name + ".")

    # Remove the root ("") entry so fill() does not list it a second time.
    # Default to 0: `dict.pop` ignores a defaultdict's default_factory, so a
    # model without parameters (no "" key) would otherwise raise KeyError.
    table.append(("model", format_size(count.pop("", 0))))
    fill(0, "")

    # Leading spaces encode nesting, and tabulate strips them unless this
    # module-level flag is set. Restore the flag even if tabulate raises, so
    # the global setting does not leak to other callers.
    old_ws = tabulate.PRESERVE_WHITESPACE
    tabulate.PRESERVE_WHITESPACE = True
    try:
        tab = tabulate.tabulate(
            table, headers=["name", "#elements or shape"], tablefmt="pipe"
        )
    finally:
        tabulate.PRESERVE_WHITESPACE = old_ws
    return tab