optimum/amd/brevitas/accelerate_utils.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
# Licensed under the MIT License.
import logging
from typing import Dict, Mapping, Optional, Union
import brevitas.config as config
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_module
from accelerate.utils import (
check_tied_parameters_in_config,
compute_module_sizes,
find_tied_parameters,
get_max_layer_size,
get_max_memory,
send_to_device,
)
from accelerate.utils.modeling import named_module_tensors
from brevitas.graph.utils import get_module
from brevitas.utils.python_utils import recurse_getattr
from psutil import virtual_memory
logger = logging.getLogger(__name__)
def align_input(model, device_map):
    """Attaches an AlignDevicesHook so that inputs are moved to the main execution device of `device_map`."""
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
hook = AlignDevicesHook(execution_device=main_device, io_same_device=True, skip_keys=None, tied_params_map=None)
add_hook_to_module(model, hook)
return model
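# Illustrative usage sketch for `align_input` (hypothetical module and single-device map; assumes a CUDA
# device): with every entry of the device map on the same device, the hook moves CPU inputs to that device
# for the forward pass and returns the output on the input's device.
#
#     model = torch.nn.Linear(16, 16).to("cuda:0")
#     model = align_input(model, device_map={"": "cuda:0"})
#     out = model(torch.randn(1, 16))  # input is sent to cuda:0 by the hook, output comes back on cpu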
# Adapted from accelerate.utils.modeling.infer_auto_device_map
def infer_fx_auto_device_map(
model: torch.fx.GraphModule,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
):
"""
Extends accelerate's infer_auto_device_map function to be compatible with torch.fx.GraphModule.
The main modifications are:
    - Work around the fact that `module.__class__.__name__` is `Module` for everything.
    - We do not need to keep entire blocks together anymore, since we add a functional equivalent of the AlignDevicesHook
      before every call_function node.
"""
    # TODO: Why are the no_split_module_classes and clean_result parameters not supported here?
# Get default / clean up max_memory
max_memory = get_max_memory(max_memory)
devices = list(max_memory.keys())
if "disk" not in devices:
devices.append("disk")
gpus = [device for device in devices if device not in ["cpu", "disk"]]
# Devices that need to keep space for a potential offloaded layer.
if "mps" in gpus:
main_devices = ["mps"]
elif len(gpus) > 0:
main_devices = [gpus[0], "cpu"]
else:
main_devices = ["cpu"]
module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
tied_parameters = find_tied_parameters(model)
if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:
        logger.warning(
            "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
        )
device_map = {}
current_device = 0
current_memory_used = 0
call_list = []
buffers_attributes = [n for n, _ in list(named_module_tensors(model, recurse=True))]
all_modules = [n.target for n in list(model.graph.nodes) if n.op == "call_module"]
for node in model.graph.nodes:
# If it's a module, we simply offload it or move it to the desired device
if node.op == "call_module":
name = node.target
module = get_module(model, node.target)
call_list.append((name, module))
# If it's get_attr, we check what module it is attached to
# In case the module is not part of call_module, we specifically allocate the buffer/parameter on some device
# NB: This does NOT guarantee that it will be aligned with whatever input tensor it will be combined with
# For that, there is a separate function
if node.op == "get_attr":
target = node.target
if target in buffers_attributes:
module_name = ".".join(target.split(".")[:-1])
if module_name not in all_modules:
module = get_module(model, target)
call_list.append((target, module))
# Direct submodules and parameters
modules_to_treat = call_list
# Initialize maximum largest layer, to know which space to keep in memory
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, [])
    # Ready? This is going to be a bit messy.
while len(modules_to_treat) > 0:
name, module = modules_to_treat.pop(0)
if verbose:
print(f"\nTreating module {name}.")
        # The max size in the remaining layers may have changed since we popped one, so we update it if needed.
max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
if len(max_layer_names) == 0:
max_layer_size, max_layer_names = get_max_layer_size(
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], module_sizes, []
)
# Assess size needed
module_size = module_sizes[name]
# We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
# and the other is not.
        tied_param_groups = [
            tied_group
            for tied_group in tied_parameters
            if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
        ]
        if verbose and len(tied_param_groups) > 0:
            print(f" Found the relevant tied param groups {tied_param_groups}")
        # Then we keep track of all the parameters that are tied to the current module, but not in the current module
        tied_params = sum(
            [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
        )
if verbose and len(tied_params) > 0:
print(f" So those parameters need to be taken into account {tied_params}")
device = devices[current_device]
current_max_size = max_memory[device] if device != "disk" else None
# Reduce max size available by the largest layer.
if devices[current_device] in main_devices:
current_max_size = current_max_size - max_layer_size
# Case 1 -> We're too big!
if current_max_size is not None and current_memory_used + module_size > current_max_size:
# For FX, we never split a leaf call_module
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} (space available "
f"{current_max_size-current_memory_used}, module size {module_size})."
)
if verbose:
print("This module cannot be split, going to the next device.")
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
        # Case 2, it fits! We're not entirely out of the woods though, because we may have some tied parameters.
elif len(tied_params) > 0:
# First locate all tied modules
tied_module_names = []
tied_modules = []
for tied_param in tied_params:
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
tied_module_names.append(modules_to_treat[tied_module_index][0])
tied_modules.append(modules_to_treat[tied_module_index][1])
if verbose:
print(
f" It looks like {name} is going to fit on {devices[current_device]} but we have tied "
f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}"
)
# Let's see if it all fits first
module_size_with_ties = module_size
for tied_param, tied_module_name in zip(tied_params, tied_module_names):
module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]
if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size:
# We really really fit!
if verbose:
print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.")
current_memory_used += module_size_with_ties
device_map[name] = devices[current_device]
for tied_module_name in tied_module_names:
if tied_module_name in [m[0] for m in modules_to_treat]:
# The module may have been removed by a previous iteration of this loop.
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][
0
]
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]
else:
# We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
# smaller or do we need to go on the next device?
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
f"available {current_max_size-current_memory_used}, needed size {module_size_with_ties})."
)
split_happened = False
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
tied_module_children = list(tied_module.named_children())
if len(tied_module_children) == 0:
# can't break this one.
continue
if verbose:
print(f"Splitting {tied_module_name}.")
tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
modules_to_treat = (
[(name, module)]
+ modules_to_treat[:tied_module_index]
+ tied_module_children
+ modules_to_treat[tied_module_index + 1 :]
)
# Update the max layer size.
max_layer_size, max_layer_names = get_max_layer_size(
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
module_sizes,
[],
)
split_happened = True
break
if not split_happened:
                    # If none of the tied modules could be split, we go to the next device
                    if verbose:
                        print("None of the tied modules can be split, going to the next device.")
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
else:
if verbose:
if current_max_size is None:
print(f"Putting {name} (size={module_size}) on {devices[current_device]}.")
else:
print(
f"Putting {name} (size={module_size}) on {devices[current_device]} "
f"(available={current_max_size-current_memory_used})."
)
current_memory_used += module_size
device_map[name] = devices[current_device]
# If we have only one device, we simplify the device_map
if len(set(device_map.values())) == 1:
device_map = {"": list(device_map.values())[0]}
return device_map
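# Minimal sketch of how `infer_fx_auto_device_map` might be called (hypothetical model and memory budget;
# assumes the module can be symbolically traced and at least one CUDA device is visible):
#
#     gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()))
#     device_map = infer_fx_auto_device_map(gm, max_memory={0: int(4e9), "cpu": int(16e9)})
#     # -> e.g. {"": 0} if everything fits on GPU 0, otherwise a dict mapping call_module / get_attr
#     #    targets to devices such as 0, "cpu" or "disk"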
def offload_call_function(model: torch.fx.GraphModule, device_map: Dict):
"""
    Wraps the call_function nodes of an fx.GraphModule with a functional equivalent of AlignDevicesHook.
    Accelerate's `dispatch_model` attaches hooks to submodules, but it cannot handle call_function nodes.
"""
# If we only have one device, offloading is not needed
if len(set(device_map.values())) == 1:
return
for node in model.graph.nodes:
if node.op == "call_function":
def new_func(*args, old_callable=node.target, **kwargs):
args = list(args)
device_mapping = {}
# Identify the device for each tensor in args and kwargs
                for arg in args:
                    all_devices = find_all_devices(arg)
                    if all_devices is not None:
                        device_mapping.update(dict(all_devices))
                for kwarg in kwargs.values():
                    all_devices = find_all_devices(kwarg)
                    if all_devices is not None:
                        device_mapping.update(dict(all_devices))
total_devices = [d for d in list(device_mapping.values()) if d is not None]
                # If there is only one device, no re-alignment is necessary
if len(set(total_devices)) > 1:
# Pick the main device, i.e. the first device that is not 'cpu' or 'disk'
if set(device_mapping.values()) == {"cpu"} or set(device_mapping.values()) == {"cpu", "disk"}:
device = "cpu"
else:
device = [d for d in device_mapping.values() if d not in ["cpu", "disk"]][0]
# Align args and kwargs to the same device
args = send_to_device(args, device)
kwargs = send_to_device(kwargs, device)
out = old_callable(*args, **kwargs)
if len(set(total_devices)) > 1:
# Restore the original device to avoid memory leaks
for k, v in device_mapping.items():
k = k.to(v)
return out
node.meta["orig_target"] = node.target
node.target = new_func
model.recompile()
model.graph.lint()
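# Illustrative sketch of what the wrapper installed by `offload_call_function` does for a single
# call_function node (hypothetical tensors): if a node computes `torch.add(x, y)` with `x` on cuda:0 and
# `y` on cpu, the wrapped target first moves both operands to the main device and only then calls the
# original callable, roughly equivalent to:
#
#     x = torch.randn(2, 2, device="cuda:0")
#     y = torch.randn(2, 2)  # cpu
#     out = torch.add(*send_to_device([x, y], "cuda:0"))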
def remove_hooks(model: torch.nn.Module):
    """Removes all accelerate hooks from `model`, restores original call_function targets and moves it back to CPU."""
for module in model.modules():
if hasattr(module, "_hf_hook"):
if hasattr(module, "allocate_params"):
del module.allocate_params
if hasattr(module, "offload_params"):
del module.offload_params
remove_hook_from_module(model, recurse=True)
model.cpu()
if hasattr(model, "graph"):
for node in model.graph.nodes:
if node.op == "call_function":
if "orig_target" in node.meta:
node.target = node.meta["orig_target"]
del node.meta["orig_target"]
model.recompile()
model.graph.lint()
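# Typical (hypothetical) usage of `remove_hooks`: once calibration / quantization is done, undo the
# dispatch so the model can be handled as a plain nn.Module again:
#
#     model = offload_model(model)
#     ...  # run PTQ / calibration with the dispatched model
#     remove_hooks(model)  # hooks removed, model moved back to CPU, original fx targets restored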
def update_internal_dict(module):
    """Copies the current (possibly updated) parameters of `module` back into its hook's `weights_map`."""
prefix = module._hf_hook.weights_map.prefix
for key in module.state_dict().keys():
        # It might happen that we call a quantized layer's inner modules, which causes some parameters to
        # already be on the meta device. This is not a problem for their values, but we need to check for it here.
curr_device = (recurse_getattr(module, key + ".data")).device
if str(curr_device) != "meta":
module._hf_hook.weights_map.dataset.state_dict[prefix + key] = (
recurse_getattr(module, key + ".data")
).cpu()
def find_all_devices(data):
"""
Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device).
Args:
(nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.
"""
if isinstance(data, Mapping):
devices = []
for obj in data.values():
device = find_all_devices(obj)
if device is not None:
devices.extend(device)
return devices
elif isinstance(data, (tuple, list)):
devices = []
for obj in data:
device = find_all_devices(obj)
if device is not None:
devices.extend(device)
return devices
elif isinstance(data, torch.Tensor):
return [(data, str(data.device))]
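# Example of the return value of `find_all_devices` (hypothetical tensors; assumes a CUDA device):
#
#     x = torch.randn(2, device="cuda:0")
#     y = torch.randn(2)  # cpu
#     find_all_devices({"a": x, "b": [y]})
#     # -> [(x, "cuda:0"), (y, "cpu")]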
def calc_gpu_device_map(absolute_mem_margin: float = 2.0 * 1e9, relative_mem_margin: float = 0.3) -> Dict[int, float]:
torch.cuda.empty_cache()
gpu_device_map = {
i: (torch.cuda.mem_get_info(i)[0] - absolute_mem_margin) * (1.0 - relative_mem_margin)
for i in range(torch.cuda.device_count())
}
return gpu_device_map
def calc_cpu_device_map(absolute_mem_margin: float = 2.0 * 1e9, relative_mem_margin: float = 0.3) -> Dict[str, float]:
cpu_device_map = {"cpu": (virtual_memory().available - absolute_mem_margin) * (1.0 - relative_mem_margin)}
return cpu_device_map
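# Worked example of the default margins (hypothetical numbers): with 24e9 bytes of free GPU memory,
# calc_gpu_device_map reports (24e9 - 2e9) * (1 - 0.3) = 15.4e9 bytes as usable for that device;
# calc_cpu_device_map applies the same formula to the available system RAM for the "cpu" entry.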
def offload_model(
model: torch.nn.Module,
gpu_device_map: Optional[Dict[int, float]] = None,
cpu_device_map: Optional[Dict[str, float]] = None,
) -> torch.nn.Module:
"""
    Wraps accelerate's `infer_auto_device_map` and `dispatch_model`.
    This function is compatible with both standard `nn.Module` instances and `torch.fx.GraphModule`.
"""
    # FX and non-FX models need different offloading logic
config._FULL_STATE_DICT = True
if gpu_device_map is None:
gpu_device_map = calc_gpu_device_map()
if cpu_device_map is None:
cpu_device_map = calc_cpu_device_map()
memory_map = {**cpu_device_map, **gpu_device_map}
if isinstance(model, torch.fx.GraphModule):
device_map = infer_fx_auto_device_map(model, memory_map)
offload_call_function(model, device_map)
else:
device_map = infer_auto_device_map(model, memory_map, no_split_module_classes=model._no_split_modules)
model = dispatch_model(model, device_map)
    # Fixes an asymmetric behavior in Accelerate where hooks are not attached at all when a single device is used.
# TODO: Fix directly in accelerate.
if len(set(device_map.values())) == 1:
model = align_input(model, device_map)
config._FULL_STATE_DICT = False
if "disk" in model.hf_device_map.values():
raise ValueError("Disk offload is not supported with quantization.")
    # We attach these functions to the hooked modules for convenience when modifying parameters during PTQ (e.g. SmoothQuant).
    # Attaching them allows us to work around a bug in accelerate with offloading to RAM/disk where, even though a submodule
    # parameter is updated, the update does not reach the AlignDevicesHook `weights_map` and is thus ignored elsewhere.
    # TODO: Fix this bug directly in accelerate. https://github.com/huggingface/accelerate/pull/2214 would fix the bug for RAM offloading.
def allocate_params(module):
"""
This function calls the pre_forward function of the _hf_hook, making sure parameters are on
the selected device, rather than on the meta device.
"""
if module._hf_hook.offload is False:
return
# When quantizing and retrieving parameters (e.g., during GPTQ), we want to recurse through
# all the submodules
for m in module.modules():
if hasattr(m, "_hf_hook"):
m._hf_hook.pre_forward(m)
def offload_params(module):
"""
        This function moves the parameters back to the meta device, after making sure to update the
internal state dict with the most recent values.
"""
if module._hf_hook.offload is False:
return
update_internal_dict(module)
for m in module.modules():
if hasattr(m, "_hf_hook"):
m._hf_hook.post_forward(m, torch.tensor([]))
for module in model.modules():
if hasattr(module, "_hf_hook"):
module.allocate_params = allocate_params
module.offload_params = offload_params
return model
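# Illustrative end-to-end sketch for `offload_model` (assumes `transformers` is installed, at least one
# CUDA device is available, and the checkpoint actually gets split/offloaded across devices so that hooks
# are attached; the checkpoint name and submodule path are hypothetical):
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#     model = offload_model(model)      # dispatched across GPU(s) / CPU
#     layer = model.model.decoder.layers[0]
#     layer.allocate_params(layer)      # materialize any offloaded parameters
#     ...                               # update weights in place (e.g. SmoothQuant scaling)
#     layer.offload_params(layer)       # sync `weights_map`, move params back to the meta device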