in optimum/amd/brevitas/accelerate_utils.py

from typing import Dict

import torch
from accelerate.utils import send_to_device
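
# `find_all_devices` is used below but defined elsewhere in this file. A minimal
# sketch of the behavior assumed here, not the actual implementation: recursively
# walk a nested structure and collect (tensor, device-string) pairs, returning
# None when no tensor is found.
#
#     def find_all_devices(data):
#         if isinstance(data, torch.Tensor):
#             return [(data, str(data.device))]
#         if isinstance(data, (list, tuple)):
#             pairs = []
#             for item in data:
#                 found = find_all_devices(item)
#                 if found is not None:
#                     pairs.extend(found)
#             return pairs if pairs else None
#         return None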


def offload_call_function(model: torch.fx.GraphModule, device_map: Dict):
    """
    Attaches an AlignDevicesHook-like wrapper to the `call_function` nodes of an fx.GraphModule. Although
    accelerate's `dispatch_model` attaches AlignDevicesHook to submodules, it cannot detect `call_function`
    nodes, whose targets are free functions rather than modules.
    """
    # If we only have one device, offloading is not needed
    if len(set(device_map.values())) == 1:
        return
    for node in model.graph.nodes:
        if node.op == "call_function":
            # Wrap the original callable in a closure that aligns all tensor
            # arguments on a common device before calling it. The target is
            # bound as a default argument so each closure keeps its own callable.
            def new_func(*args, old_callable=node.target, **kwargs):
                args = list(args)
                device_mapping = {}
                # Identify the device of every tensor in args and kwargs
                for arg in args:
                    all_devices = find_all_devices(arg)
                    if all_devices is not None:
                        device_mapping.update(dict(all_devices))
                for kwarg in kwargs.values():
                    all_devices = find_all_devices(kwarg)
                    if all_devices is not None:
                        device_mapping.update(dict(all_devices))
                total_devices = [d for d in device_mapping.values() if d is not None]
                # If all tensors already sit on the same device, no re-alignment is necessary
                if len(set(total_devices)) > 1:
                    # Pick the main device, i.e. the first device that is not 'cpu' or 'disk'
                    if set(device_mapping.values()) == {"cpu"} or set(device_mapping.values()) == {"cpu", "disk"}:
                        device = "cpu"
                    else:
                        device = [d for d in device_mapping.values() if d not in ["cpu", "disk"]][0]
                    # Align args and kwargs on the same device
                    args = send_to_device(args, device)
                    kwargs = send_to_device(kwargs, device)
                out = old_callable(*args, **kwargs)
                if len(set(total_devices)) > 1:
                    # Move each tensor back to its original device to avoid memory leaks;
                    # rebinding a loop variable would not move anything, so reassign .data
                    for tensor, orig_device in device_mapping.items():
                        tensor.data = tensor.data.to(orig_device)
                return out
            node.meta["orig_target"] = node.target
            node.target = new_func
    model.recompile()
    model.graph.lint()
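
# Usage sketch showing where this function slots in after accelerate's dispatch.
# The traced model, device_map contents, and module names below are illustrative
# assumptions, not part of this module:
#
#     import torch.fx
#     from accelerate import dispatch_model
#
#     graph_module = torch.fx.symbolic_trace(model)    # `model` is any nn.Module
#     device_map = {"layers.0": 0, "layers.1": "cpu"}  # hypothetical placement
#     graph_module = dispatch_model(graph_module, device_map=device_map)  # hooks on submodules
#     offload_call_function(graph_module, device_map)  # also wrap call_function nodes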