optimum/habana/accelerate/utils/transformer_engine.py
def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
"""
Recursively converts the linear layer of a model to their `transformers_engine` counterpart.
"""
from ...transformers.models.llama.modeling_llama import ModuleFusedSDPA
if not is_fp8_available():
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
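    # PT_HPU_FP8_MINIMIZE_MEMORY enables TE's memory-saving mode for fp8 linear layers (off by default).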
minimize_memory = str_to_bool(os.getenv("PT_HPU_FP8_MINIMIZE_MEMORY", "false"))
for name, module in model.named_children():
if is_peft_available() and isinstance(module, lora.Linear) and to_transformer_engine and _convert_linear:
            # For a LoRA linear module, convert only the base linear layer to fp8 and skip the
            # lora_A / lora_B linear layers. Since lora_A and lora_B are small, running them in
            # fp8 brings little device-side gain and would add host overhead from dispatching
            # them through TE.
            for sub_name, lora_module in module.named_children():
                if sub_name == "base_layer":
has_bias = lora_module.bias is not None
# Initializing TE linear without weights and biases and shallow copying them from the original module.
te_module = te.Linear(
lora_module.in_features,
lora_module.out_features,
bias=has_bias,
params_dtype=lora_module.weight.dtype,
skip_weight_param_allocation=True,
minimize_memory=minimize_memory,
)
te_module.weight = lora_module.weight
if has_bias:
te_module.bias = lora_module.bias
                    setattr(module, sub_name, te_module)
elif isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
has_bias = module.bias is not None
# Initializing TE linear without weights and biases and shallow copying them from the original module.
te_module = te.Linear(
module.in_features,
module.out_features,
bias=has_bias,
params_dtype=module.weight.dtype,
skip_weight_param_allocation=True,
minimize_memory=minimize_memory,
)
te_module.weight = module.weight
if has_bias:
te_module.bias = module.bias
setattr(model, name, te_module)
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
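            # Reverse conversion: replace a te.Linear with a plain torch.nn.Linear.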
has_bias = module.bias is not None
new_module = torch.nn.Linear(
module.in_features,
module.out_features,
bias=has_bias,
dtype=module.weight.dtype,
device=module.weight.device,
)
            # Copy parameters under no_grad to avoid in-place updates on leaf tensors that require grad.
            with torch.no_grad():
                new_module.weight.copy_(module.weight)
                if has_bias:
                    new_module.bias.copy_(module.bias)
setattr(model, name, new_module)
elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine:
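            # Swap Habana's fused SDPA wrapper for TE's FusedAttention kernel. The wrapper class
            # below closes over this branch's `module`, reusing its scale, dropout and recompute settings.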
from habana_frameworks.torch.hpex.experimental.transformer_engine import (
FusedAttention as TE_FusedAttention,
            )

            class TE_ModuleFusedSDPA(torch.nn.Module):
def __init__(self):
super().__init__()
self._hpu_kernel_fsdpa = TE_FusedAttention(
scale=module.scale,
attention_dropout=module.attention_dropout,
enable_recompute=module.enable_recompute,
                    )

                def forward(
self,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side="left",
):
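                    # Only query/key/value, attn_mask, is_causal and softmax_mode are forwarded to the
                    # TE kernel; the remaining arguments exist only to match ModuleFusedSDPA's signature.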
                    return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode)

            setattr(model, name, TE_ModuleFusedSDPA())
else:
_convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)
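
Illustrative usage sketch (not part of the file above): assuming Habana's Transformer Engine mirrors the usual `fp8_autocast` / `DelayedScaling` recipe API, a converted model would then run its forward pass inside an fp8 autocast context. `MyHpuModel` and `input_ids` are placeholders, and the private helper is called directly only for illustration.

import habana_frameworks.torch.hpex.experimental.transformer_engine as te

model = MyHpuModel().to("hpu")                     # placeholder model
_convert_model(model)                              # nn.Linear / LoRA base layers -> te.Linear

fp8_recipe = te.recipe.DelayedScaling()            # assumed default delayed-scaling recipe
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    outputs = model(input_ids)                     # placeholder inputs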