# optimum/habana/accelerate/utils/transformer_engine.py
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os

import torch
from accelerate.utils import str_to_bool
from transformers.utils import is_peft_available


if is_peft_available():
    from peft.tuners import lora


has_transformer_engine = False


def import_te():
    global te, has_transformer_engine
    try:
        import habana_frameworks.torch.hpex.experimental.transformer_engine as te

        has_transformer_engine = True
    except ImportError:
        has_transformer_engine = False


def is_fp8_available():
    if not has_transformer_engine:
        import_te()
    return has_transformer_engine
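
# Callers typically guard TE-dependent code paths on this check. A minimal
# illustrative sketch, using `convert_model` defined below:
#
#     if is_fp8_available():
#         model = convert_model(model)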


def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
    """
    Recursively converts the linear layers of a model to their `transformer_engine` counterparts.
    """
    # Imported locally to avoid pulling in the modeling code (and any circular import) at module load time.
    from ...transformers.models.llama.modeling_llama import ModuleFusedSDPA

    if not is_fp8_available():
        raise ImportError("Using `convert_model` requires transformer_engine to be installed.")

    minimize_memory = str_to_bool(os.getenv("PT_HPU_FP8_MINIMIZE_MEMORY", "false"))
    for name, module in model.named_children():
        if is_peft_available() and isinstance(module, lora.Linear) and to_transformer_engine and _convert_linear:
            # For a LoRA linear module, convert only the base linear layer to fp8 and skip the
            # lora-a and lora-b linear layers. Since lora-a and lora-b are small, there is not
            # much device performance gain from pushing them to fp8, and skipping them avoids
            # the host overhead associated with using TE for these layers.
            for sub_name, lora_module in module.named_children():
                if sub_name == "base_layer":
                    has_bias = lora_module.bias is not None
                    # Initialize the TE linear layer without weights and biases, then
                    # shallow-copy them from the original module.
                    te_module = te.Linear(
                        lora_module.in_features,
                        lora_module.out_features,
                        bias=has_bias,
                        params_dtype=lora_module.weight.dtype,
                        skip_weight_param_allocation=True,
                        minimize_memory=minimize_memory,
                    )
                    te_module.weight = lora_module.weight
                    if has_bias:
                        te_module.bias = lora_module.bias
                    setattr(module, sub_name, te_module)
        elif isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
            has_bias = module.bias is not None
            # Initialize the TE linear layer without weights and biases, then
            # shallow-copy them from the original module.
            te_module = te.Linear(
                module.in_features,
                module.out_features,
                bias=has_bias,
                params_dtype=module.weight.dtype,
                skip_weight_param_allocation=True,
                minimize_memory=minimize_memory,
            )
            te_module.weight = module.weight
            if has_bias:
                te_module.bias = module.bias
            setattr(model, name, te_module)
        elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
            has_bias = module.bias is not None
            new_module = torch.nn.Linear(
                module.in_features,
                module.out_features,
                bias=has_bias,
                dtype=module.weight.dtype,
                device=module.weight.device,
            )
            new_module.weight.copy_(module.weight)
            if has_bias:
                new_module.bias.copy_(module.bias)
            setattr(model, name, new_module)
        elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine:
            from habana_frameworks.torch.hpex.experimental.transformer_engine import (
                FusedAttention as TE_FusedAttention,
            )

            class TE_ModuleFusedSDPA(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self._hpu_kernel_fsdpa = TE_FusedAttention(
                        scale=module.scale,
                        attention_dropout=module.attention_dropout,
                        enable_recompute=module.enable_recompute,
                    )

                def forward(
                    self,
                    query,
                    key,
                    value,
                    attn_mask,
                    dropout_p,
                    is_causal,
                    scale,
                    softmax_mode,
                    recompute_mode,
                    valid_sequence_lengths,
                    padding_side="left",
                ):
                    # TE's FusedAttention only consumes a subset of the arguments; the others
                    # are kept so this module stays call-compatible with ModuleFusedSDPA.
                    return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode)

            setattr(model, name, TE_ModuleFusedSDPA())
        else:
            _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)
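
# The conversion is reversible; a minimal illustrative sketch of the reverse pass
# (`copy_` on parameters must run under `torch.no_grad()`):
#
#     with torch.no_grad():
#         _convert_model(model, to_transformer_engine=False)  # te.Linear -> nn.Linear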


def has_transformer_engine_layers(model):
    """
    Returns whether a given model has some `transformer_engine` layer or not.
    """
    if not is_fp8_available():
        raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
    for m in model.modules():
        if isinstance(m, te.Linear):
            return True
    return False


def convert_model(model):
    """
    Converts `torch.nn.Linear` modules to `transformer_engine` Linear modules.
    Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1303
    """
    if not has_transformer_engine_layers(model):
        with torch.no_grad():
            _convert_model(model)
        model._converted_to_transformer_engine = True
    return model
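
# Example usage (a minimal sketch, assuming an Intel Gaudi/HPU environment with
# `habana_frameworks` installed; `model` is any `torch.nn.Module`):
#
#     model = convert_model(model)
#     assert has_transformer_engine_layers(model)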


def get_fp8_recipe(fp8_recipe_handler):
    """
    Creates a transformer engine FP8 recipe object.
    Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1309
    """
    if not is_fp8_available():
        raise ImportError("Using `get_fp8_recipe` requires transformer_engine to be installed.")
    kwargs = fp8_recipe_handler.to_dict() if fp8_recipe_handler is not None else {}
    if "fp8_format" in kwargs:
        # Map the format name (e.g. "E4M3" or "HYBRID") to the TE enum value.
        kwargs["fp8_format"] = getattr(te.recipe.Format, kwargs["fp8_format"])
    fp8_recipe_handler = te.recipe.DelayedScaling(**kwargs)
    fp8_recipe_handler.backend = "TE"
    return fp8_recipe_handler
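
# Example (a sketch; `FP8RecipeKwargs` is `accelerate`'s FP8 kwargs handler and its
# exact fields may vary across `accelerate` versions):
#
#     from accelerate.utils import FP8RecipeKwargs
#
#     recipe = get_fp8_recipe(FP8RecipeKwargs(fp8_format="HYBRID"))
#     # `recipe` is a `te.recipe.DelayedScaling` instance with `recipe.backend == "TE"`.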


class FP8ContextWrapper:
    """
    Helper class for FP8 context-related operations.
    """

    def __init__(self, ctx, fp8_recipe):
        self.ctx = ctx
        self.fp8_ctx = self.create_fp8_context(fp8_recipe)

    def __enter__(self):
        self.ctx.__enter__()
        self.fp8_ctx.__enter__()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Exit the contexts in the reverse order of entering them.
        self.fp8_ctx.__exit__(exc_type, exc_value, exc_traceback)
        self.ctx.__exit__(exc_type, exc_value, exc_traceback)

    @staticmethod
    def create_fp8_context(fp8_recipe):
        return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
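
    # Typical usage (a sketch; `recipe` would come from `get_fp8_recipe`, and the
    # outer context is assumed to be HPU bf16 autocast):
    #
    #     ctx = torch.autocast(device_type="hpu", dtype=torch.bfloat16)
    #     with FP8ContextWrapper(ctx, recipe):
    #         loss = model(**batch).loss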

    @staticmethod
    def _gradient_checkpointing_wrap(func, *args, **kwargs):
        """
        `_gradient_checkpointing_func` always takes the function to be recomputed as the first argument. The function
        below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context.
        """
        _args = list(args)
        _args[0] = te.distributed.activation_checkpointing()(_args[0])
        args = tuple(_args)
        return func(*args, **kwargs)

    @staticmethod
    def gradient_checkpointing_wrap(model):
        """
        Wrap `_gradient_checkpointing_func` in the model with `transformer_engine`'s `activation_checkpointing`
        context. This context signals to the `transformer_engine` modules whether they have been called with
        activation checkpointing enabled or not.
        """
        if hasattr(model, "gradient_checkpointing") and model.gradient_checkpointing:
            model._gradient_checkpointing_func = functools.partial(
                FP8ContextWrapper._gradient_checkpointing_wrap, model._gradient_checkpointing_func
            )
            return

        for module in model.modules():
            if hasattr(module, "gradient_checkpointing") and module.gradient_checkpointing:
                module._gradient_checkpointing_func = functools.partial(
                    FP8ContextWrapper._gradient_checkpointing_wrap, module._gradient_checkpointing_func
                )
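

# Example (a sketch; `model` is assumed to be a `transformers` model with gradient
# checkpointing turned on):
#
#     model.gradient_checkpointing_enable()
#     FP8ContextWrapper.gradient_checkpointing_wrap(model)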