optimum/habana/accelerate/utils/other.py (49 lines of code) (raw):

import torch def is_compiled_module(module: torch.nn.Module) -> bool: """ Check whether the module was compiled with torch.compile() """ if not hasattr(torch, "_dynamo"): return False return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) def has_compiled_regions(module: torch.nn.Module) -> bool: """ Check whether the module has submodules that were compiled with `torch.compile()`. """ if not hasattr(torch, "_dynamo"): return False if module._modules: for submodule in module.modules(): if isinstance(submodule, torch._dynamo.eval_frame.OptimizedModule): return True return False def is_repeated_blocks(module: torch.nn.Module) -> bool: """ Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This is useful to determine whether we should apply regional compilation to the module. """ return isinstance(module, torch.nn.ModuleList) and all(isinstance(m, module[0].__class__) for m in module) def has_repeated_blocks(module: torch.nn.Module) -> bool: """ Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the module. """ if module._modules: for submodule in module.modules(): if is_repeated_blocks(submodule): return True return False def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module: """ Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. The rest of the model (e.g. model.lm_head) is compiled separately. This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general. See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details. Args: module (`torch.nn.Module`): The model to compile. **compile_kwargs: Additional keyword arguments to pass to `torch.compile()`. Returns: `torch.nn.Module`: A new instance of the model with some compiled regions. Example: ```python >>> from accelerate.utils import compile_regions >>> from transformers import AutoModelForCausalLM >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> compiled_model = compile_regions(model, mode="reduce-overhead") >>> compiled_model.transformer.h[0] OptimizedModule( (_orig_mod): GPT2Block( (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (attn): GPT2Attention( (c_attn): Conv1D(nf=2304, nx=768) (c_proj): Conv1D(nf=768, nx=768) (attn_dropout): Dropout(p=0.1, inplace=False) (resid_dropout): Dropout(p=0.1, inplace=False) ) (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (mlp): GPT2MLP( (c_fc): Conv1D(nf=3072, nx=768) (c_proj): Conv1D(nf=768, nx=3072) (act): NewGELUActivation() (dropout): Dropout(p=0.1, inplace=False) ) ) ) ``` """ def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module: if is_repeated_blocks(module): new_module = torch.nn.ModuleList() for submodule in module: new_module.append(torch.compile(submodule, **compile_kwargs)) elif has_repeated_blocks(module): new_module = module.__class__.__new__(module.__class__) new_module.__dict__.update(module.__dict__) new_module._modules = {} for name, submodule in module.named_children(): new_module.add_module(name, _compile_regions(submodule, **compile_kwargs)) else: new_module = torch.compile(module, **compile_kwargs) return new_module new_module = _compile_regions(module, **compile_kwargs) if "_orig_mod" not in new_module.__dict__: # Keeps a reference to the original module to decompile/unwrap it later new_module.__dict__["_orig_mod"] = module return new_module def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs): """ Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`. Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc that `torch.compile(...)` interferes with, version of trgional compilation uses the inplace `module.compile()` method instead. Args: module (`torch.nn.Module`): The model to compile. **compile_kwargs: Additional keyword arguments to pass to `module.compile()`. """ if is_repeated_blocks(module): for submodule in module: submodule.compile(**compile_kwargs) elif has_repeated_blocks(module): for child in module.children(): compile_regions_deepspeed(child, **compile_kwargs) else: # leaf node module.compile(**compile_kwargs)