def remove_hooks()

in optimum/amd/brevitas/accelerate_utils.py [0:0]


def remove_hooks(model: torch.nn.Module):
    """Strip offloading hooks from ``model`` and leave it as a plain CPU module.

    Steps, in order:

    1. For every submodule that carries an ``_hf_hook`` attribute, delete its
       ``allocate_params`` / ``offload_params`` attributes when present
       (presumably helpers attached alongside the hook -- TODO confirm where
       they are set).
    2. Remove the hooks themselves via
       ``remove_hook_from_module(model, recurse=True)``.
    3. Move the model to CPU with ``model.cpu()``.
    4. If the model exposes a ``graph`` (an FX-traced model), restore the
       target of every ``call_function`` node whose ``meta`` recorded an
       ``"orig_target"``, drop that metadata entry, lint the graph and
       recompile the model from it.

    Args:
        model: The module to clean up. It is modified in place.
    """
    for module in model.modules():
        if not hasattr(module, "_hf_hook"):
            continue
        for attr in ("allocate_params", "offload_params"):
            if hasattr(module, attr):
                delattr(module, attr)
    remove_hook_from_module(model, recurse=True)
    model.cpu()

    if hasattr(model, "graph"):
        for node in model.graph.nodes:
            if node.op == "call_function" and "orig_target" in node.meta:
                # pop() restores the original target and removes the marker
                # from the node metadata in a single step.
                node.target = node.meta.pop("orig_target")
        # Validate the edited graph before regenerating code from it, so a
        # malformed graph is reported before ``recompile`` runs.
        model.graph.lint()
        model.recompile()