in opacus/validators/module_validator.py [0:0]
def fix(cls, module: nn.Module, **kwargs) -> nn.Module:
    """
    Make the module and sub_modules DP compatible by running registered custom fixers.

    Replacements are applied to the object returned by ``clone_module(module)``
    rather than to ``module`` directly.

    Args:
        module: The root module to be made compatible.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            registered fixer that gets invoked. With no extra arguments each
            fixer is called exactly as ``fixer(sub_module)``.

    Returns:
        Fixed module: the clone of ``module`` with every sub_module whose exact
        type has a registered fixer replaced by that fixer's output.
    """
    module = clone_module(module)
    # Collect the sub_module names into a list first, because the loop below
    # replaces modules inside ``module`` while it runs.
    sub_module_names = [name for name, _ in trainable_modules(module)]
    for sub_module_name in sub_module_names:
        sub_module = get_submodule(module, sub_module_name)
        sub_module_type = type(sub_module)
        # The lookup is keyed on the exact type, so subclasses of a registered
        # type are not matched.
        if sub_module_type not in ModuleValidator.FIXERS:
            continue
        # Get a replacement for sub_module.
        sub_module_fixer = ModuleValidator.FIXERS[sub_module_type]
        new_sub_module = sub_module_fixer(sub_module, **kwargs)
        # A fixer may build the replacement from scratch, in which case it
        # lands on the default device. Put it on the same device as the module
        # it replaces so the fixed model is not split across devices. The
        # guard covers a sub_module without any parameters.
        reference_param = next(sub_module.parameters(), None)
        if reference_param is not None:
            new_sub_module = new_sub_module.to(reference_param.device)
        # Get the root module after replacement.
        # NOTE: `_repalce_sub_module` (sic) is the helper's actual name.
        module = cls._repalce_sub_module(
            root=module,
            sub_module_name=sub_module_name,
            new_sub_module=new_sub_module,
        )
        logger.info(
            "Replaced sub_module %s : %s with %s",
            sub_module_name,
            sub_module,
            new_sub_module,
        )
    return module