in models/networks/model_profiling.py [0:0]
def module_profiling(module, input, output, verbose):
    """Record MAC / parameter / timing statistics as attributes on ``module``.

    The ``(module, input, output)`` signature suggests this is used as a
    PyTorch forward hook with ``verbose`` bound by the caller (TODO confirm
    against the registration site). It always returns ``None``.

    Leaf layers (``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Linear``,
    ``nn.AvgPool2d``, ``nn.AdaptiveAvgPool2d``) get ``n_macs`` computed from
    the input/output tensor sizes, ``n_params`` from ``get_params`` (0 for
    pooling), ``n_seconds`` from ``run_forward`` and a display ``name``.
    Every other module gets the sums of its direct children's ``n_macs`` /
    ``n_params`` / ``n_seconds``, which is only correct when children are
    profiled before their parent (depth-first traversal).

    Args:
        module: Module being profiled; statistics are stored on it.
        input: Tuple of positional inputs given to ``module``. If the first
            one is a list, the call is skipped and nothing is recorded.
        output: Output of ``module``. Only read for the leaf layer types
            above, so other modules may return tuples or non-tensors.
        verbose: If true, print one formatted row per profiled leaf layer
            (column widths come from the module-level ``*_space`` values).
    """
    # A list input (e.g. the list output of a preceding module) has no
    # ``.size()``; such calls are deliberately ignored.
    if input and isinstance(input[0], list):
        return

    leaf_types = (
        nn.Conv2d,
        nn.ConvTranspose2d,
        nn.Linear,
        nn.AvgPool2d,
        nn.AdaptiveAvgPool2d,
    )

    if not isinstance(module, leaf_types):
        # Container / unsupported module: roll up the children's totals.
        # This works only in depth-first travel of modules.
        module.n_macs = 0
        module.n_params = 0
        module.n_seconds = 0
        for child in module.children():
            module.n_macs += getattr(child, "n_macs", 0)
            module.n_params += getattr(child, "n_params", 0)
            module.n_seconds += getattr(child, "n_seconds", 0)
        # NOTE: matched with exact ``type(...) in`` rather than isinstance,
        # so subclasses of these types are NOT ignored.
        ignore_zeros_t = (
            nn.BatchNorm2d,
            nn.InstanceNorm2d,
            nn.Dropout2d,
            nn.Dropout,
            nn.Sequential,
            nn.ReLU6,
            nn.ReLU,
            nn.MaxPool2d,
            nn.modules.padding.ZeroPad2d,
            nn.modules.activation.Sigmoid,
        )
        # NOTE(review): despite the "leaf module" wording, this also fires
        # for non-leaf modules whose children sum to zero MACs.
        if (
            not getattr(module, "ignore_model_profiling", False)
            and module.n_macs == 0
            and type(module) not in ignore_zeros_t
        ):
            print(
                "WARNING: leaf module {} has zero n_macs.".format(type(module))
            )
        return

    # Only leaf layers need tensor sizes; computing them here (not before the
    # type dispatch) avoids crashing on modules that return tuples.
    ins = input[0].size()
    outs = output.size()
    if isinstance(module, nn.Conv2d):
        module.n_macs = (
            ins[1]
            * outs[1]
            * module.kernel_size[0]
            * module.kernel_size[1]
            * outs[2]
            * outs[3]
            // module.groups
        ) * outs[0]
        module.n_params = get_params(module)
        module.n_seconds = run_forward(module, input)
        module.name = conv_module_name_filter(module.__repr__())
    elif isinstance(module, nn.ConvTranspose2d):
        # A transposed conv scatters every *input* pixel through the kernel,
        # so the spatial factor is H_in * W_in; using the upsampled output
        # size would over-count by roughly stride**2.
        module.n_macs = (
            ins[1]
            * outs[1]
            * module.kernel_size[0]
            * module.kernel_size[1]
            * ins[2]
            * ins[3]
            // module.groups
        ) * ins[0]
        module.n_params = get_params(module)
        module.n_seconds = run_forward(module, input)
        module.name = conv_module_name_filter(module.__repr__())
    elif isinstance(module, nn.Linear):
        # One MAC per (input feature, output element) pair. Valid for
        # (..., in_features) inputs of any rank; for 2-D inputs this equals
        # ins[1] * outs[1] * outs[0].
        module.n_macs = ins[-1] * output.numel()
        module.n_params = get_params(module)
        module.n_seconds = run_forward(module, input)
        module.name = module.__repr__()
    else:
        # nn.AvgPool2d / nn.AdaptiveAvgPool2d: one add per input element.
        # NOTE: this is correct only when stride == kernel size.
        module.n_macs = input[0].numel()
        module.n_params = 0
        module.n_seconds = run_forward(module, input)
        module.name = module.__repr__()

    if verbose:
        print(
            module.name.ljust(name_space, " ")
            + "{:,}".format(module.n_params).rjust(params_space, " ")
            + "{:,}".format(module.n_macs).rjust(macs_space, " ")
            + "{:,}".format(module.n_seconds).rjust(seconds_space, " ")
        )
    return