optimum/amd/brevitas/accelerate_utils.py (309 lines of code) (raw):

# Copyright 2023 The HuggingFace Team. All rights reserved. # Licensed under the MIT License. import logging from typing import Dict, Mapping, Optional, Union import brevitas.config as config import torch from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_module from accelerate.utils import ( check_tied_parameters_in_config, compute_module_sizes, find_tied_parameters, get_max_layer_size, get_max_memory, send_to_device, ) from accelerate.utils.modeling import named_module_tensors from brevitas.graph.utils import get_module from brevitas.utils.python_utils import recurse_getattr from psutil import virtual_memory logger = logging.getLogger(__name__) def align_input(model, device_map): if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: main_device = "cpu" else: main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] hook = AlignDevicesHook(execution_device=main_device, io_same_device=True, skip_keys=None, tied_params_map=None) add_hook_to_module(model, hook) return model # Adapted from accelerate.utils.modeling.infer_auto_device_map def infer_fx_auto_device_map( model: torch.fx.GraphModule, max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, dtype: Optional[Union[str, torch.dtype]] = None, special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None, verbose: bool = False, ): """ Extends accelerate's infer_auto_device_map function to be compatible with torch.fx.GraphModule. The main modifications are: - Work around the fact that module.__class__.__name__ is Module for everything - We do not need to keep entire blocks together anymore, since we add a functional equivalent of the AlignDeviceHook before every call function. """ # TODO: Why no no_split_module_classes, clean_result parameters? # Get default / clean up max_memory max_memory = get_max_memory(max_memory) devices = list(max_memory.keys()) if "disk" not in devices: devices.append("disk") gpus = [device for device in devices if device not in ["cpu", "disk"]] # Devices that need to keep space for a potential offloaded layer. if "mps" in gpus: main_devices = ["mps"] elif len(gpus) > 0: main_devices = [gpus[0], "cpu"] else: main_devices = ["cpu"] module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) tied_parameters = find_tied_parameters(model) if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: logger.warn( "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." ) device_map = {} current_device = 0 current_memory_used = 0 call_list = [] buffers_attributes = [n for n, _ in list(named_module_tensors(model, recurse=True))] all_modules = [n.target for n in list(model.graph.nodes) if n.op == "call_module"] for node in model.graph.nodes: # If it's a module, we simply offload it or move it to the desired device if node.op == "call_module": name = node.target module = get_module(model, node.target) call_list.append((name, module)) # If it's get_attr, we check what module it is attached to # In case the module is not part of call_module, we specifically allocate the buffer/parameter on some device # NB: This does NOT guarantee that it will be aligned with whatever input tensor it will be combined with # For that, there is a separate function if node.op == "get_attr": target = node.target if target in buffers_attributes: module_name = ".".join(target.split(".")[:-1]) if module_name not in all_modules: module = get_module(model, target) call_list.append((target, module)) # Direct submodules and parameters modules_to_treat = call_list # Initialize maximum largest layer, to know which space to keep in memory max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, []) # Ready ? This is going to be a bit messy. while len(modules_to_treat) > 0: name, module = modules_to_treat.pop(0) if verbose: print(f"\nTreating module {name}.") # Max size in the remaining layers may have changed since we took one, so we maybe update it. max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")] if len(max_layer_names) == 0: max_layer_size, max_layer_names = get_max_layer_size( [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], module_sizes, [] ) # Assess size needed module_size = module_sizes[name] # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module # and the other is not. tied_param_goups = [ tied_group for tied_group in tied_parameters if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) ] if verbose and len(tied_param_goups) > 0: print(f" Found the relevant tied param groups {tied_param_goups}") # Then we keep track of all the parameters that are tied to the current module, but not in the current module tied_params = sum( [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_goups], [] ) if verbose and len(tied_params) > 0: print(f" So those parameters need to be taken into account {tied_params}") device = devices[current_device] current_max_size = max_memory[device] if device != "disk" else None # Reduce max size available by the largest layer. if devices[current_device] in main_devices: current_max_size = current_max_size - max_layer_size # Case 1 -> We're too big! if current_max_size is not None and current_memory_used + module_size > current_max_size: # For FX, we never split a leaf call_module if verbose: print( f"Not enough space on {devices[current_device]} to put {name} (space available " f"{current_max_size-current_memory_used}, module size {module_size})." ) if verbose: print("This module cannot be split, going to the next device.") current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat current_memory_used = 0 # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters. elif len(tied_params) > 0: # First locate all tied modules tied_module_names = [] tied_modules = [] for tied_param in tied_params: tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] tied_module_names.append(modules_to_treat[tied_module_index][0]) tied_modules.append(modules_to_treat[tied_module_index][1]) if verbose: print( f" It looks like {name} is going to fit on {devices[current_device]} but we have tied " f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}" ) # Let's see if it all fits first module_size_with_ties = module_size for tied_param, tied_module_name in zip(tied_params, tied_module_names): module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size: # We really really fit! if verbose: print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.") current_memory_used += module_size_with_ties device_map[name] = devices[current_device] for tied_module_name in tied_module_names: if tied_module_name in [m[0] for m in modules_to_treat]: # The module may have been removed by a previous iteration of this loop. tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][ 0 ] modules_to_treat.pop(tied_module_index) device_map[tied_module_name] = devices[current_device] else: # We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it # smaller or do we need to go on the next device? if verbose: print( f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " f"available {current_max_size-current_memory_used}, needed size {module_size_with_ties})." ) split_happened = False for tied_module_name, tied_module in zip(tied_module_names, tied_modules): tied_module_children = list(tied_module.named_children()) if len(tied_module_children) == 0: # can't break this one. continue if verbose: print(f"Splitting {tied_module_name}.") tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] modules_to_treat = ( [(name, module)] + modules_to_treat[:tied_module_index] + tied_module_children + modules_to_treat[tied_module_index + 1 :] ) # Update the max layer size. max_layer_size, max_layer_names = get_max_layer_size( [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], module_sizes, [], ) split_happened = True break if not split_happened: # If the tied module is not split, we go to the next device if verbose: print("None of the tied module can be split, going to the next device.") current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat current_memory_used = 0 else: if verbose: if current_max_size is None: print(f"Putting {name} (size={module_size}) on {devices[current_device]}.") else: print( f"Putting {name} (size={module_size}) on {devices[current_device]} " f"(available={current_max_size-current_memory_used})." ) current_memory_used += module_size device_map[name] = devices[current_device] # If we have only one device, we simplify the device_map if len(set(device_map.values())) == 1: device_map = {"": list(device_map.values())[0]} return device_map def offload_call_function(model: torch.fx.GraphModule, device_map: Dict): """ Attaches AlignDevicesHook to fx.GraphModule call_function nodes. Although accelerate's `offload_model` attaches hooks to submodules, it is unable to detect call_function. """ # If we only have one device, offloading is not needed if len(set(device_map.values())) == 1: return for node in model.graph.nodes: if node.op == "call_function": def new_func(*args, old_callable=node.target, **kwargs): args = list(args) device_mapping = {} # Identify the device for each tensor in args and kwargs for _, arg in enumerate(args): all_devices = find_all_devices(arg) if all_devices is not None: device_mapping.update(dict(all_devices)) for k, v in kwargs.items(): all_devices = find_all_devices(arg) if all_devices is not None: device_mapping.update(dict(all_devices)) total_devices = [d for d in list(device_mapping.values()) if d is not None] # If there is only one device, no re-alignement is necessary if len(set(total_devices)) > 1: # Pick the main device, i.e. the first device that is not 'cpu' or 'disk' if set(device_mapping.values()) == {"cpu"} or set(device_mapping.values()) == {"cpu", "disk"}: device = "cpu" else: device = [d for d in device_mapping.values() if d not in ["cpu", "disk"]][0] # Align args and kwargs to the same device args = send_to_device(args, device) kwargs = send_to_device(kwargs, device) out = old_callable(*args, **kwargs) if len(set(total_devices)) > 1: # Restore the original device to avoid memory leaks for k, v in device_mapping.items(): k = k.to(v) return out node.meta["orig_target"] = node.target node.target = new_func model.recompile() model.graph.lint() def remove_hooks(model: torch.nn.Module): for module in model.modules(): if hasattr(module, "_hf_hook"): if hasattr(module, "allocate_params"): del module.allocate_params if hasattr(module, "offload_params"): del module.offload_params remove_hook_from_module(model, recurse=True) model.cpu() if hasattr(model, "graph"): for node in model.graph.nodes: if node.op == "call_function": if "orig_target" in node.meta: node.target = node.meta["orig_target"] del node.meta["orig_target"] model.recompile() model.graph.lint() def update_internal_dict(module): prefix = module._hf_hook.weights_map.prefix for key in module.state_dict().keys(): # It might happen that we call an quantization's inner modules, and this cause some parameters to be # already on meta device. This is not a problem for their value but we need to check here curr_device = (recurse_getattr(module, key + ".data")).device if str(curr_device) != "meta": module._hf_hook.weights_map.dataset.state_dict[prefix + key] = ( recurse_getattr(module, key + ".data") ).cpu() def find_all_devices(data): """ Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device). Args: (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of. """ if isinstance(data, Mapping): devices = [] for obj in data.values(): device = find_all_devices(obj) if device is not None: devices.extend(device) return devices elif isinstance(data, (tuple, list)): devices = [] for obj in data: device = find_all_devices(obj) if device is not None: devices.extend(device) return devices elif isinstance(data, torch.Tensor): return [(data, str(data.device))] def calc_gpu_device_map(absolute_mem_margin: float = 2.0 * 1e9, relative_mem_margin: float = 0.3) -> Dict[int, float]: torch.cuda.empty_cache() gpu_device_map = { i: (torch.cuda.mem_get_info(i)[0] - absolute_mem_margin) * (1.0 - relative_mem_margin) for i in range(torch.cuda.device_count()) } return gpu_device_map def calc_cpu_device_map(absolute_mem_margin: float = 2.0 * 1e9, relative_mem_margin: float = 0.3) -> Dict[str, float]: cpu_device_map = {"cpu": (virtual_memory().available - absolute_mem_margin) * (1.0 - relative_mem_margin)} return cpu_device_map def offload_model( model: torch.nn.Module, gpu_device_map: Optional[Dict[int, float]] = None, cpu_device_map: Optional[Dict[str, float]] = None, ) -> torch.nn.Module: """ Wraps accelerate's infer_auto_device_map and dispatch_model. This functions if compatible both with classic nn.Modules, and with torch.fx.GraphModule. """ # FX vs non-FX model need different offloading config._FULL_STATE_DICT = True if gpu_device_map is None: gpu_device_map = calc_gpu_device_map() if cpu_device_map is None: cpu_device_map = calc_cpu_device_map() memory_map = {**cpu_device_map, **gpu_device_map} if isinstance(model, torch.fx.GraphModule): device_map = infer_fx_auto_device_map(model, memory_map) offload_call_function(model, device_map) else: device_map = infer_auto_device_map(model, memory_map, no_split_module_classes=model._no_split_modules) model = dispatch_model(model, device_map) # Fixes an asymetric behavior in Accelerate where hooks are not attached at all when a single device is used. # TODO: Fix directly in accelerate. if len(set(device_map.values())) == 1: model = align_input(model, device_map) config._FULL_STATE_DICT = False if "disk" in model.hf_device_map.values(): raise ValueError("Disk offload is not supported with quantization.") # We attach these functions to the hooked modules for convenience when modifying parameters during PTQ (e.g. SmoothQuant). # Attaching these functions allows use to fix a bug in accelerate with offloading to RAM/disk where even though a submodule parameter is updated, it is actually not updated in the AlignDevicesHook `weights_map` and thus # the update is ignored elsewhere. # TODO: Fix this bug directly in accelerate. https://github.com/huggingface/accelerate/pull/2214 would fix the bug for RAM offliading. def allocate_params(module): """ This function calls the pre_forward function of the _hf_hook, making sure parameters are on the selected device, rather than on the meta device. """ if module._hf_hook.offload is False: return # When quantizing and retrieving parameters (e.g., during GPTQ), we want to recurse through # all the submodules for m in module.modules(): if hasattr(m, "_hf_hook"): m._hf_hook.pre_forward(m) def offload_params(module): """ This functions moves the parameters back to the meta device, after making sure to update the internal state dict with the most recent values. """ if module._hf_hook.offload is False: return update_internal_dict(module) for m in module.modules(): if hasattr(m, "_hf_hook"): m._hf_hook.post_forward(m, torch.tensor([])) for module in model.modules(): if hasattr(module, "_hf_hook"): module.allocate_params = allocate_params module.offload_params = offload_params return model