in models/src/ptflops/flops_counter.py [0:0]
def start_flops_count(self, **kwargs):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Activates the computation of mean flops consumption per image.
    Call it before you run the network.

    Registers the batch-counter hook on ``self``, then visits every submodule
    via ``self.apply`` and attaches a forward hook to each module that
    ``is_supported_instance`` accepts. The hook comes from
    ``CUSTOM_MODULES_MAPPING`` when the module's type is listed there;
    otherwise it comes from ``MODULES_MAPPING``.

    Keyword Args:
        ost: Stream that "treated as a zero-op" warnings are printed to.
            ``None`` (the default) falls back to ``print``'s default stream.
        verbose: If true, print a warning once per unsupported module type
            (``nn.Sequential`` / ``nn.ModuleList`` containers are never
            reported). Defaults to ``False``.
        ignore_list: Collection of module types to skip. Supported modules of
            these types get ``__params__`` set to 0 and receive no hook.
            Defaults to an empty tuple.
    """
    add_batch_counter_hook_function(self)

    # Module types already handled or reported during this walk, so the
    # zero-op warning is printed at most once per type.
    seen_types = set()

    def add_flops_counter_hook_function(
        module, ost=None, verbose=False, ignore_list=()
    ):
        module_type = type(module)
        if module_type in ignore_list:
            seen_types.add(module_type)
            if is_supported_instance(module):
                # NOTE(review): presumably excludes ignored modules from the
                # parameter tally — confirm against the reporting code.
                module.__params__ = 0
        elif is_supported_instance(module):
            # Never register a second hook on a module that already has one.
            if hasattr(module, "__flops_handle__"):
                return
            # A user-supplied custom hook takes precedence over the built-in one.
            if module_type in CUSTOM_MODULES_MAPPING:
                hook = CUSTOM_MODULES_MAPPING[module_type]
            else:
                hook = MODULES_MAPPING[module_type]
            # Keep the handle on the module so it can be removed later.
            module.__flops_handle__ = module.register_forward_hook(hook)
            seen_types.add(module_type)
        else:
            if (
                verbose
                and module_type not in (nn.Sequential, nn.ModuleList)
                and module_type not in seen_types
            ):
                print(
                    f"Warning: module {module_type.__name__} "
                    "is treated as a zero-op.",
                    file=ost,
                )
            seen_types.add(module_type)

    self.apply(partial(add_flops_counter_hook_function, **kwargs))