in optimum/amd/brevitas/accelerate_utils.py [0:0]
def remove_hooks(model: torch.nn.Module):
    """Strip hook-related state from *model* and move it to the CPU.

    Steps, in order:

    1. For every submodule carrying an ``_hf_hook`` attribute, delete the
       ``allocate_params`` and ``offload_params`` attributes when present.
    2. Call ``remove_hook_from_module`` on the whole model, recursively.
    3. Call ``model.cpu()``.
    4. If the model exposes a ``graph`` attribute (presumably a
       ``torch.fx.GraphModule`` -- confirm against callers), restore the
       target of every ``call_function`` node from ``node.meta["orig_target"]``,
       drop that meta entry, then recompile and lint the graph.

    Args:
        model: The module to clean up. It is modified in place.

    Returns:
        None.
    """
    # Attributes removed from submodules that carry an `_hf_hook`.
    hook_added_attrs = ("allocate_params", "offload_params")

    for submodule in model.modules():
        if not hasattr(submodule, "_hf_hook"):
            continue
        for attr_name in hook_added_attrs:
            if hasattr(submodule, attr_name):
                delattr(submodule, attr_name)

    remove_hook_from_module(model, recurse=True)
    model.cpu()

    # Models without a `graph` attribute need no further work.
    if not hasattr(model, "graph"):
        return

    for fx_node in model.graph.nodes:
        node_meta = fx_node.meta
        if fx_node.op == "call_function" and "orig_target" in node_meta:
            # Put back the saved target, then discard the saved copy.
            fx_node.target = node_meta["orig_target"]
            del node_meta["orig_target"]

    # Regenerate the module's code from the (possibly edited) graph, then
    # run the graph's own consistency check.
    model.recompile()
    model.graph.lint()