in optimum/amd/brevitas/accelerate_utils.py [0:0]
def offload_params(module):
    """
    Move the module's parameters back to the meta device after refreshing its internal state dict.

    Does nothing when the module's own accelerate hook has ``offload`` set to exactly ``False``.
    Otherwise, the internal state dict is first refreshed with the most recent values via
    ``update_internal_dict``. Then ``post_forward`` is invoked on the ``_hf_hook`` of every entry
    yielded by ``module.modules()`` that carries such a hook. The empty tensor passed to
    ``post_forward`` only stands in for a forward output; the call is made for the hook's side
    effects, and its return value is discarded.
    """
    hook = module._hf_hook
    # Compare by identity against False on purpose: only an explicit False disables offloading.
    if hook.offload is not False:
        update_internal_dict(module)
        # Lazily select hooked submodules so each hasattr check is interleaved with its
        # post_forward call, exactly as in a plain filtered loop.
        hooked_submodules = (sub for sub in module.modules() if hasattr(sub, "_hf_hook"))
        for sub in hooked_submodules:
            sub._hf_hook.post_forward(sub, torch.tensor([]))