def remove_reparameterization()

in apex/apex/reparameterization/__init__.py [0:0]


def remove_reparameterization(module, reparameterization=Reparameterization,
                              name='', remove_all=False):
    """
    Removes the given reparameterization of a parameter from a module.

    If no parameter name is supplied (``name == ''``) and ``remove_all`` is
    False, every reparameterization of the given type is removed from
    ``module`` and from all of its submodules.

    Args:
        module (nn.Module): containing module
        reparameterization (Reparameterization): reparameterization class to remove
        name (str, optional): name of weight parameter
        remove_all (bool, optional): if True, remove all reparameterizations of
            the given type registered directly on ``module`` (submodules are
            not visited in this mode). Default: False
    Returns:
        ``module``, after the matching reparameterizations have been removed.
    Raises:
        ValueError: if ``name`` is given, ``remove_all`` is False and no
            matching reparameterization is registered on ``module``.
    Example:
        >>> m = apply_reparameterization(nn.Linear(20, 40), WeightNorm)
        >>> remove_reparameterization(m)
    """
    if name == '' and not remove_all:
        # No specific parameter requested: strip every matching hook from the
        # whole module tree. ``module.modules()`` already yields ``module``
        # itself first, so the root does not need to be listed separately.
        for m in module.modules():
            remove_reparameterization(m, reparameterization=reparameterization,
                                      remove_all=True)
        return module

    found = False
    # Iterate over a snapshot because entries are deleted inside the loop.
    for k, hook in list(module._forward_pre_hooks.items()):
        if isinstance(hook, reparameterization) and (hook.name == name or remove_all):
            hook.remove(module)
            # Unregister the hook immediately after its removal succeeded so a
            # failure on a later hook leaves no half-removed entries behind.
            del module._forward_pre_hooks[k]
            found = True

    if not found and not remove_all:
        raise ValueError("reparameterization of '{}' not found in {}"
                         .format(name, module))
    # Always hand the module back, including the remove_all path where nothing
    # matched (previously this returned None implicitly).
    return module