# offload_call_function
# Extracted from optimum/amd/brevitas/accelerate_utils.py

def offload_call_function(model: torch.fx.GraphModule, device_map: Dict):
    """
    Attaches device-alignment wrappers to fx.GraphModule call_function nodes.

    Although accelerate's `offload_model` attaches hooks to submodules, it is unable
    to detect call_function nodes, so their tensor arguments may end up on different
    devices. This patches each call_function target with a wrapper that moves all
    arguments to a common device before calling the original target.

    Args:
        model: Traced module whose graph is patched in place (then recompiled).
        device_map: Mapping of module names to devices, as produced by accelerate.
    """
    # If we only have one device, offloading is not needed
    if len(set(device_map.values())) == 1:
        return

    for node in model.graph.nodes:
        if node.op == "call_function":

            # `old_callable` is bound as a default argument to avoid the
            # late-binding-closure pitfall across loop iterations.
            def new_func(*args, old_callable=node.target, **kwargs):
                args = list(args)
                device_mapping = {}

                # Identify the device for each tensor in args and kwargs.
                # (Fix: the kwargs loop previously inspected the stale `arg`
                # variable from the args loop instead of each kwarg value.)
                for arg in args:
                    all_devices = find_all_devices(arg)
                    if all_devices is not None:
                        device_mapping.update(dict(all_devices))

                for value in kwargs.values():
                    all_devices = find_all_devices(value)
                    if all_devices is not None:
                        device_mapping.update(dict(all_devices))

                total_devices = [d for d in device_mapping.values() if d is not None]

                # If there is only one device, no re-alignment is necessary.
                needs_alignment = len(set(total_devices)) > 1

                if needs_alignment:
                    # Pick the main device, i.e. the first device that is not
                    # 'cpu' or 'disk'. Select from `total_devices` (None already
                    # filtered out) so we never pick `device=None`.
                    accelerator_devices = [d for d in total_devices if d not in ("cpu", "disk")]
                    device = accelerator_devices[0] if accelerator_devices else "cpu"
                    # Align args and kwargs to the same device
                    args = send_to_device(args, device)
                    kwargs = send_to_device(kwargs, device)

                out = old_callable(*args, **kwargs)

                if needs_alignment:
                    # Restore the original device to avoid memory leaks.
                    # NOTE(review): `Tensor.to` is out-of-place, so rebinding the
                    # loop variable does not move the original tensor — this restore
                    # looks like a no-op unless `find_all_devices` yields objects
                    # with an in-place `.to`. TODO: confirm intent.
                    for tensor, original_device in device_mapping.items():
                        tensor = tensor.to(original_device)

                return out

            # Keep the original target around for debugging / un-patching.
            node.meta["orig_target"] = node.target
            node.target = new_func

    model.recompile()
    model.graph.lint()