def start_flops_count()

in models/src/ptflops/flops_counter.py [0:0]


def start_flops_count(self, **kwargs):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Activates the computation of mean flops consumption per image.
    Call it before you run the network.

    Registers the batch-counter hook on ``self`` and then visits ``self`` and
    every submodule (via ``self.apply``), attaching a FLOPs forward hook to
    each module accepted by ``is_supported_instance``.

    Keyword Args:
        ost: Stream that warnings are printed to. ``None`` (the default)
            makes ``print`` fall back to ``sys.stdout``.
        verbose: If true, print a warning (at most once per module type) for
            every unsupported module that is not an ``nn.Sequential`` /
            ``nn.ModuleList`` and is therefore treated as a zero-op.
            Defaults to False.
        ignore_list: Collection of module types to skip. No hook is attached
            to them, and supported ones get ``__params__`` set to 0.
            ``None`` (the default) means nothing is ignored.
    """
    add_batch_counter_hook_function(self)

    # Module types encountered so far; lets the zero-op warning fire at most
    # once per type.
    seen_types = set()

    def add_flops_counter_hook_function(
        module, ost=None, verbose=False, ignore_list=None
    ):
        module_type = type(module)
        if ignore_list is None:
            ignore_list = ()

        if module_type in ignore_list:
            seen_types.add(module_type)
            if is_supported_instance(module):
                # NOTE(review): presumably consumed by the params summary so
                # ignored modules report zero parameters -- confirm.
                module.__params__ = 0
        elif is_supported_instance(module):
            # A handle is already attached (e.g. this was called again without
            # removing hooks); don't register a second, duplicate hook.
            if hasattr(module, "__flops_handle__"):
                return
            # Entries in CUSTOM_MODULES_MAPPING take precedence over the
            # built-in MODULES_MAPPING.
            if module_type in CUSTOM_MODULES_MAPPING:
                hook = CUSTOM_MODULES_MAPPING[module_type]
            else:
                hook = MODULES_MAPPING[module_type]
            module.__flops_handle__ = module.register_forward_hook(hook)
            seen_types.add(module_type)
        else:
            # Plain containers are skipped in the warning: their children are
            # visited (and hooked) individually by self.apply.
            if (
                verbose
                and module_type not in (nn.Sequential, nn.ModuleList)
                and module_type not in seen_types
            ):
                print(
                    "Warning: module "
                    + module_type.__name__
                    + " is treated as a zero-op.",
                    file=ost,
                )
            seen_types.add(module_type)

    self.apply(partial(add_flops_counter_hook_function, **kwargs))